Skip to content

Commit

Permalink
Allow splitting inner-most ID to create virtual innermost ID in trans…
Browse files Browse the repository at this point in the history
…pose scheduler (#1930)
  • Loading branch information
zasdfgbnm authored Sep 6, 2022
1 parent a3ecb33 commit 45e95fd
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 27 deletions.
4 changes: 0 additions & 4 deletions torch/csrc/jit/codegen/cuda/scheduler/pointwise_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,6 @@ class DomainMap {
}
virtual ~DomainMap() = default;

bool areExactMapped(IterDomain* id1, IterDomain* id2) const {
return ca_map_.areMapped(id1, id2, IdMappingMode::EXACT);
}

const ComputeAtMap& getComputeAtMap() const {
return ca_map_;
}
Expand Down
93 changes: 70 additions & 23 deletions torch/csrc/jit/codegen/cuda/scheduler/transpose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,49 @@ class DomainMap : public pointwise_utils::DomainMap {
domain_map.findReferenceFor(grouped_inputs_outputs[1]) != nullptr;
}

int getPosMappedTo(TensorView* tv, IterDomain* id) const {
int getInnerLeafDim(TensorView* tv, IterDomain* root_dim) const {
// Find the root id mapped to `root_dim`
const auto& root_dom = tv->getRootDomain();
IterDomain* mapped_id = nullptr;
for (auto i : c10::irange(root_dom.size())) {
if (ca_map_.idGraph().permissiveNodes().permissiveAreMapped(
root_dom[i], root_dim)) {
mapped_id = root_dom[i];
break;
}
}
TORCH_INTERNAL_ASSERT(
mapped_id != nullptr,
"Can not find ID mapped to ",
root_dim,
" in tensor ",
tv);
// Project the root id to leaf id
while (!mapped_id->uses().empty()) {
TORCH_INTERNAL_ASSERT(mapped_id->uses().size() == 1);
auto expr = mapped_id->uses()[0];
if (expr->isA<Split>()) {
mapped_id = expr->as<Split>()->inner();
} else {
auto merge = expr->as<Merge>();
TORCH_INTERNAL_ASSERT(
mapped_id == merge->inner(),
"Can not find ID mapped to ",
root_dim,
" in tensor ",
tv);
mapped_id = merge->out();
}
}
// Find the position of the leaf id
const auto& dom = tv->domain()->domain();
for (auto i : c10::irange(dom.size())) {
if (areExactMapped(id, tv->axis(i))) {
if (dom[i] == mapped_id) {
return i;
}
}
TORCH_INTERNAL_ASSERT(
false, "Can not find ID mapped to ", id, " in tensor ", tv);
false, "Can not find ID mapped to ", root_dim, " in tensor ", tv);
}

// Group inputs and outputs of a fusion by its inner most domain. For example
Expand Down Expand Up @@ -240,22 +274,37 @@ void maybeBuildVirtualInnerDims(
// both virtual innermost dim.
// 2. The satisfied one did not merge in anything. For example,
// T0[I0{1024*1024}, I1{2}]
// If this is the case, this means that we need to split the large
// inner-most dimension to satisfy the small innermost dimension
int64_t large_dim;
int64_t split_factor;
bool split_inner_most;
if (merged_size1 < params.tile_size1) {
if (params.dims_merged_with_2.empty()) {
// case 2
return;
split_inner_most = true;
large_dim = inner_most2;
split_factor = params.tile_size2;
} else {
// case 1
split_inner_most = false;
large_dim = params.dims_merged_with_2.back();
auto prev_merged_size2 = merged_size2 / shape_in_ref1[large_dim];
split_factor = ceilDiv(params.tile_size2, prev_merged_size2);
}
large_dim = params.dims_merged_with_2.back();
split_factor = ceilDiv(params.tile_size1, merged_size1);
} else {
if (params.dims_merged_with_1.empty()) {
// case 2
return;
split_inner_most = true;
large_dim = inner_most1;
split_factor = params.tile_size1;
} else {
// case 1
split_inner_most = false;
large_dim = params.dims_merged_with_1.back();
auto prev_merged_size1 = merged_size1 / shape_in_ref1[large_dim];
split_factor = ceilDiv(params.tile_size1, prev_merged_size1);
}
large_dim = params.dims_merged_with_1.back();
split_factor = ceilDiv(params.tile_size2, merged_size2);
}
params.split_before_tiling.push_back({large_dim, split_factor});
// adjust all dims to after-split
Expand All @@ -271,12 +320,16 @@ void maybeBuildVirtualInnerDims(
}
// Give the split-out dim to the unsatisfied one, so that both are satisfied.
if (merged_size1 < params.tile_size1) {
params.dims_merged_with_2.pop_back();
params.dims_merged_with_2.push_back(large_dim + 1);
if (!split_inner_most) {
params.dims_merged_with_2.pop_back();
params.dims_merged_with_2.push_back(large_dim + 1);
}
params.dims_merged_with_1.push_back(large_dim);
} else {
params.dims_merged_with_1.pop_back();
params.dims_merged_with_1.push_back(large_dim + 1);
if (!split_inner_most) {
params.dims_merged_with_1.pop_back();
params.dims_merged_with_1.push_back(large_dim + 1);
}
params.dims_merged_with_2.push_back(large_dim);
}
}
Expand Down Expand Up @@ -369,12 +422,6 @@ std::shared_ptr<TransposeParams> getTransposeHeuristics(
if (n_elems < device_multiprocessor_count * kMaxTileSize * kMaxTileSize) {
params->tile_size1 = 8;
params->tile_size2 = 8;
// TODO: I was trying the following but I got silent wrong result
// params->tile_size1 = 8;
// params->tile_size2 = 4;
// This should not happen, because the correctness should be irrevalent to
// schedulers. We don't have to use tile size (8, 4), but we need to fix our
// bug in codegen.
}

// Expand inner-most dims to virtual inner-most dims so that the inner-most
Expand All @@ -383,9 +430,9 @@ std::shared_ptr<TransposeParams> getTransposeHeuristics(
auto inner_most_id2 = scheduler_utils::innerMostRootDim(reference2);

auto inner_most_pos1_in_ref1 =
domain_map.getPosMappedTo(reference1, inner_most_id1);
domain_map.getInnerLeafDim(reference1, inner_most_id1);
auto inner_most_pos2_in_ref1 =
domain_map.getPosMappedTo(reference1, inner_most_id2);
domain_map.getInnerLeafDim(reference1, inner_most_id2);

// See note [Supporting small transpose dimensions]
maybeBuildVirtualInnerDims(
Expand Down Expand Up @@ -643,9 +690,9 @@ void scheduleTranspose(Fusion* fusion, TransposeParams params) {

// merge with inner most dims to get virtual inner most dims
size_t inner_most_pos1_in_ref1 =
domain_map.getPosMappedTo(reference1, inner_most_id1);
domain_map.getInnerLeafDim(reference1, inner_most_id1);
size_t inner_most_pos2_in_ref1 =
domain_map.getPosMappedTo(reference1, inner_most_id2);
domain_map.getInnerLeafDim(reference1, inner_most_id2);
if (merged1.has_value()) {
if (inner_most_pos1_in_ref1 < *merged1) {
reference1->reorder(
Expand Down
31 changes: 31 additions & 0 deletions torch/csrc/jit/codegen/cuda/test/test_gpu_transpose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -932,6 +932,37 @@ TEST_F(NVFuserTest, FusionScheduleTransposeSmallInnerSize3_CUDA) {
testValidate(&fusion, outputs, {input}, {tv_ref}, __LINE__, __FILE__);
}

// x->sin->transpose->cos->y
TEST_F(NVFuserTest, FusionScheduleTranspose2DSmallInnerSize_CUDA) {
std::array<std::vector<int64_t>, 2> shapes{
std::vector<int64_t>{1024 * 1024 * 128, 2},
std::vector<int64_t>{2, 1024 * 1024 * 128}};
for (const auto& shape : shapes) {
Fusion fusion;
FusionGuard fg(&fusion);

auto tv0 = makeContigTensor(2);
fusion.addInput(tv0);
auto tv1 = sin(tv0);
auto tv2 = transpose(tv1, 0, 1);
auto tv3 = cos(tv2);
fusion.addOutput(tv3);

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
at::Tensor input = at::randn(shape, options);

auto lparams = scheduleTranspose(&fusion, {input});

FusionExecutor fe;
fe.compileFusion(&fusion, {input}, lparams);
auto outputs = fe.runFusion({input}, lparams);

auto tv_ref = input.sin().transpose(0, 1).cos();

testValidate(&fusion, outputs, {input}, {tv_ref}, __LINE__, __FILE__);
}
}

} // namespace jit
} // namespace torch
#endif // #if defined(USE_CUDA)

0 comments on commit 45e95fd

Please sign in to comment.