Skip to content

Commit

Permalink
Fix canScheduleCompileTime check of transpose scheduler (#1969)
Browse files Browse the repository at this point in the history
  • Loading branch information
zasdfgbnm authored Sep 13, 2022
1 parent b1bd32c commit 306d4a6
Showing 1 changed file with 24 additions and 12 deletions.
36 changes: 24 additions & 12 deletions torch/csrc/jit/codegen/cuda/scheduler/transpose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,18 +45,7 @@ class DomainMap : public pointwise_utils::DomainMap {
return result;
}

static bool hasAtLeastTwoValidGroups(Fusion* fusion) {
FusionGuard fg(fusion);
DomainMap domain_map(fusion);
auto grouped_inputs_outputs = domain_map.groupInputsOutputsByInnerDim();
if (grouped_inputs_outputs.size() < 2) {
return false;
}
return domain_map.findReferenceFor(grouped_inputs_outputs[0]) != nullptr &&
domain_map.findReferenceFor(grouped_inputs_outputs[1]) != nullptr;
}

int getInnerLeafDim(TensorView* tv, IterDomain* root_dim) const {
IterDomain* getMappedRootDimIn(TensorView* tv, IterDomain* root_dim) const {
// Find the root id mapped to `root_dim`
const auto& root_dom = tv->getRootDomain();
IterDomain* mapped_id = nullptr;
Expand All @@ -67,6 +56,29 @@ class DomainMap : public pointwise_utils::DomainMap {
break;
}
}
return mapped_id;
}

static bool hasAtLeastTwoValidGroups(Fusion* fusion) {
FusionGuard fg(fusion);
DomainMap domain_map(fusion);
auto grouped_inputs_outputs = domain_map.groupInputsOutputsByInnerDim();
if (grouped_inputs_outputs.size() < 2) {
return false;
}
auto ref1 = domain_map.findReferenceFor(grouped_inputs_outputs[0]);
auto ref2 = domain_map.findReferenceFor(grouped_inputs_outputs[1]);
if (ref1 == nullptr || ref2 == nullptr) {
return false;
}
// reference 1 is the global reference, so it must have dim mapped the
// innermost dim of both groups
auto innermost2 = scheduler_utils::innerMostRootDim(ref2);
return domain_map.getMappedRootDimIn(ref1, innermost2) != nullptr;
}

int getInnerLeafDim(TensorView* tv, IterDomain* root_dim) const {
auto mapped_id = getMappedRootDimIn(tv, root_dim);
TORCH_INTERNAL_ASSERT(
mapped_id != nullptr,
"Can not find ID mapped to ",
Expand Down

0 comments on commit 306d4a6

Please sign in to comment.