
scheduler_utils::parallelizeAllLike is changing parallelization that it should not change #1828

Closed
zasdfgbnm opened this issue Jul 15, 2022 · 3 comments · Fixed by #1831

@zasdfgbnm
Collaborator

🐛 Describe the bug

I am trying to schedule a transpose kernel manually using the following code:

// Copied from Shiming's PR
struct MatmulSelector : public MaxInfoSpanningTree::Selector {
 public:
  explicit MatmulSelector(std::unordered_set<TensorView*> selected_tvs)
      : selected_tvs_(selected_tvs) {}

  explicit MatmulSelector() : propagate_all_(true) {}

  bool allowC2P(TensorView* from, TensorView* to) final {
    return propagate_all_ || selected_tvs_.count(to);
  }
  bool allowP2C(TensorView* from, TensorView* to) final {
    return propagate_all_ || selected_tvs_.count(to);
  }
  bool allowSibling(TensorView* from, TensorView* to) final {
    return propagate_all_ || selected_tvs_.count(to);
  }

 private:
  std::unordered_set<TensorView*> selected_tvs_;
  bool propagate_all_ = false;
};

// Copied from Shiming's PR, modified
void propagateFrom(
    TensorView* from_tv,
    int pos,
    std::unordered_set<TensorView*> included_tvs,
    bool propagate_transformation,
    bool propagate_parallel_type,
    bool propagate_inline_position) {
  MatmulSelector selector(included_tvs);
  MaxRootDomainInfoSpanningTree spanning_tree(from_tv, &selector);

  if (propagate_transformation) {
    TransformPropagator propagator(from_tv, pos);
    spanning_tree.traverse(&propagator);
  }

  if (propagate_parallel_type) {
    scheduler_utils::parallelizeAllLike(
        from_tv, {included_tvs.begin(), included_tvs.end()});
  }

  if (propagate_inline_position) {
    InlinePropagator inline_propagator(from_tv, pos, included_tvs);
    spanning_tree.traverse(&inline_propagator);
  }
}

TEST_F(NVFuserTest, Transpose2) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeSymbolicTensor(3);
  fusion.addInput(tv0);
  auto tv1 = sin(tv0);
  auto tv2 = transpose(tv1, 1, 2);
  auto tv3 = cos(tv2);
  fusion.addOutput(tv3);

  auto tv4 = tv0->cacheAfter();
  auto tv5 = tv3->cacheBefore();

  tv1->setMemoryType(MemoryType::Shared);

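  // Tile tv1's inner two dims into 32x32 tiles, then bind the resulting axes
  // to block/thread dims below.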
  tv1->split(1, 32);
  tv1->reorder({{2, -1}});
  tv1->split(2, 32);
  tv1->reorder({{3, -1}});

  tv1->axis(0)->parallelize(ParallelType::BIDz);
  tv1->axis(1)->parallelize(ParallelType::BIDy);
  tv1->axis(2)->parallelize(ParallelType::BIDx);
  tv1->axis(3)->parallelize(ParallelType::TIDy);
  tv1->axis(4)->parallelize(ParallelType::TIDx);

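  // Tile tv2 the same way; its root domain is the transposed (i1, i3, i2), so
  // the TIDy/TIDx tiles end up coming from the opposite logical dims than tv1.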
  tv2->split(1, 32);
  tv2->reorder({{2, -1}});
  tv2->split(2, 32);
  tv2->reorder({{3, -1}});
  tv2->reorder({{1, 2}});

  tv2->axis(0)->parallelize(ParallelType::BIDz);
  tv2->axis(1)->parallelize(ParallelType::BIDy);
  tv2->axis(2)->parallelize(ParallelType::BIDx);
  tv2->axis(3)->parallelize(ParallelType::TIDy);
  tv2->axis(4)->parallelize(ParallelType::TIDx);

  fusion.print();

  propagateFrom(tv1, -1, {tv1, tv0, tv4}, true, true, false);
  propagateFrom(tv1, 3, {tv1, tv0, tv4}, false, false, true);

  fusion.print();

  propagateFrom(tv2, -1, {tv2, tv3, tv5}, true, true, false);
  propagateFrom(tv2, 3, {tv2, tv3, tv5}, false, false, true);

  fusion.printKernel();
}

And I am seeing strange behavior: at the first print, tv1 and tv2 have:

T1_s[ iblockIdx.z3{i1}, iblockIdx.y18{( ceilDiv(i2, 32) )}, iblockIdx.x20{( ceilDiv(i3, 32) )}, ithreadIdx.y19{32}, ithreadIdx.x21{32} ]
 root domain : (iblockIdx.z3{i1},iS4{i2},iS5{i3})
  Split: iS4{i2} by factor 32 -> iblockIdx.y18{( ceilDiv(i2, 32) )}, ithreadIdx.y19{32}, start offset: 0, stop offset: 0
  Split: iS5{i3} by factor 32 -> iblockIdx.x20{( ceilDiv(i3, 32) )}, ithreadIdx.x21{32}, start offset: 0, stop offset: 0
T2_l[ iblockIdx.z6{i1}, iblockIdx.y24{( ceilDiv(i2, 32) )}, iblockIdx.x22{( ceilDiv(i3, 32) )}, ithreadIdx.y23{32}, ithreadIdx.x25{32} ]
 root domain : (iblockIdx.z6{i1},iS7{i3},iS8{i2})
  Split: iS8{i2} by factor 32 -> iblockIdx.y24{( ceilDiv(i2, 32) )}, ithreadIdx.x25{32}, start offset: 0, stop offset: 0
  Split: iS7{i3} by factor 32 -> iblockIdx.x22{( ceilDiv(i3, 32) )}, ithreadIdx.y23{32}, start offset: 0, stop offset: 0

which is expected. But at the second print, they become:

T1_s[ iblockIdx.z3{i1}, iblockIdx.y18{( ceilDiv(i2, 32) )}, iblockIdx.x20{( ceilDiv(i3, 32) )}, ithreadIdx.y19{32}, ithreadIdx.x21{32} ] ca_pos( 3 ) produce_pos( 3)
 root domain : (iblockIdx.z3{i1},iS4{i2},iS5{i3})
  Split: iS4{i2} by factor 32 -> iblockIdx.y18{( ceilDiv(i2, 32) )}, ithreadIdx.y19{32}, start offset: 0, stop offset: 0
  Split: iS5{i3} by factor 32 -> iblockIdx.x20{( ceilDiv(i3, 32) )}, ithreadIdx.x21{32}, start offset: 0, stop offset: 0
T2_l[ iblockIdx.z6{i1}, iblockIdx.y24{( ceilDiv(i2, 32) )}, iblockIdx.x22{( ceilDiv(i3, 32) )}, ithreadIdx.x23{32}, ithreadIdx.y25{32} ] produce_pos( 3)
 root domain : (iblockIdx.z6{i1},iS7{i3},iS8{i2})
  Split: iS8{i2} by factor 32 -> iblockIdx.y24{( ceilDiv(i2, 32) )}, ithreadIdx.y25{32}, start offset: 0, stop offset: 0
  Split: iS7{i3} by factor 32 -> iblockIdx.x22{( ceilDiv(i3, 32) )}, ithreadIdx.x23{32}, start offset: 0, stop offset: 0

This is not what I expect: tv2 is not in the selected set of tensors that I want to change, yet the parallelization of its last two dims has been swapped.

The root cause seems to be this:

ca_map.getConcreteMappedID(id, IdMappingMode::PERMISSIVE)
->parallelize(id->getParallelType());

I don't understand:

  • Why do we want to change the concrete mapped ID's parallelization regardless of which tensor it belongs to?
  • If the dim I am currently looking at is not a broadcast dim, isn't it already concrete? Why are we still looking for a concretely mapped dim?
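For reference, here is a simplified sketch of what I believe the surrounding loop does (illustrative only, not the exact nvfuser code; `reference_tv` stands for the tensor passed as the first argument):

for (auto id : reference_tv->domain()->domain()) {
  // The parallel type of each reference leaf ID is written directly onto the
  // permissively-mapped *concrete* IterDomain. That concrete IterDomain is a
  // real ID that may belong to a tensor outside the all_tvs argument
  // (apparently tv2's IDs in this repro), so that tensor is mutated as a side
  // effect; the tensors in all_tvs then read their parallel types back from
  // the concrete IDs.
  ca_map.getConcreteMappedID(id, IdMappingMode::PERMISSIVE)
      ->parallelize(id->getParallelType());
}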

Versions

devel

@naoyam
Collaborator

naoyam commented Jul 15, 2022

Why do we want to change the concrete mapped ID's parallelization regardless of which tensor it belongs to?

My guess is that it's just because the propagation algorithm was designed to do so. If that's not really convenient, we should reconsider the design.

If the dim I am currently looking at is not a broadcast dim, isn't it already concrete? Why are we still looking for a concretely mapped dim?

There's only one concrete ID among those mapped together, so there can be non-broadcast non-concrete IDs. This propagation algorithm is meant to propagate the parallel types of the reference tensor to those IDs.
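For example (a hypothetical fusion just to illustrate the terminology, not from this issue):

Fusion fusion;
FusionGuard fg(&fusion);
auto tv0 = makeSymbolicTensor(2);          // [I0, I1]
auto tv1 = makeSymbolicTensor(1);          // [I1]
fusion.addInput(tv0);
fusion.addInput(tv1);
auto tv2 = broadcast(tv1, {true, false});  // [B, I1]
auto tv3 = add(tv0, tv2);                  // [I0, I1]
fusion.addOutput(tv3);
// tv0->axis(1), tv1->axis(0), tv2->axis(1), and tv3->axis(1) are all
// permissively mapped, but getConcreteMappedID returns only one of them.
// The others are non-broadcast yet non-concrete, and parallelizeAllLike is
// what carries the reference parallel type over to them.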

@zasdfgbnm
Collaborator Author

zasdfgbnm commented Jul 15, 2022

There's only one concrete ID among those mapped together, so there can be non-broadcast non-concrete IDs.

I see. Thanks for the info! It is good to know that.

This propagation algorithm is meant to propagate the parallel types of the reference tensor to those IDs.
My guess is that it's just because the propagation algorithm was designed to do so. If that's not really convenient, we should reconsider the design.

So currently parallelizeAllLike is really designed to propagate parallelization to all TVs, not to a selected set of TVs? (So the all_tvs parameter really means "all TVs in the fusion", not "all TVs you want to propagate to"?) I think it won't be hard to restrict the propagation to the selected TVs; I just need to maintain a concrete-ID-to-reference-ID map instead of changing the parallelization of the concrete ID. I will open a PR for that.
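Roughly, the restriction I have in mind would look something like this (just a sketch with illustrative names such as `reference_tv` and `selected_tvs`, not necessarily what the PR will end up doing):

// Record which reference leaf ID corresponds to each concrete ID, instead of
// mutating the concrete ID itself.
std::unordered_map<IterDomain*, IterDomain*> concrete_to_reference;
for (auto id : reference_tv->domain()->domain()) {
  concrete_to_reference.emplace(
      ca_map.getConcreteMappedID(id, IdMappingMode::PERMISSIVE), id);
}
// Apply the reference parallel types only to the explicitly selected tensors.
for (auto tv : selected_tvs) {
  for (auto id : tv->domain()->domain()) {
    auto it = concrete_to_reference.find(
        ca_map.getConcreteMappedID(id, IdMappingMode::PERMISSIVE));
    if (it != concrete_to_reference.end()) {
      id->parallelize(it->second->getParallelType());
    }
  }
}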

@naoyam
Collaborator

naoyam commented Jul 15, 2022

Yes, as long as IDs are mapped with the permissive map, they are propagated. That doesn't work very well with something like transpose, as you noticed.
