diff --git a/torch/csrc/jit/codegen/cuda/scheduler/pointwise_utils.h b/torch/csrc/jit/codegen/cuda/scheduler/pointwise_utils.h
index 7947a27f48360..6cc4b1b8b93bd 100644
--- a/torch/csrc/jit/codegen/cuda/scheduler/pointwise_utils.h
+++ b/torch/csrc/jit/codegen/cuda/scheduler/pointwise_utils.h
@@ -20,10 +20,6 @@ class DomainMap {
   }
   virtual ~DomainMap() = default;
 
-  bool areExactMapped(IterDomain* id1, IterDomain* id2) const {
-    return ca_map_.areMapped(id1, id2, IdMappingMode::EXACT);
-  }
-
   const ComputeAtMap& getComputeAtMap() const {
     return ca_map_;
   }
diff --git a/torch/csrc/jit/codegen/cuda/scheduler/transpose.cpp b/torch/csrc/jit/codegen/cuda/scheduler/transpose.cpp
index 1bdd1d34a0a9a..bc8c3b4c71c99 100644
--- a/torch/csrc/jit/codegen/cuda/scheduler/transpose.cpp
+++ b/torch/csrc/jit/codegen/cuda/scheduler/transpose.cpp
@@ -58,15 +58,49 @@ class DomainMap : public pointwise_utils::DomainMap {
         domain_map.findReferenceFor(grouped_inputs_outputs[1]) != nullptr;
   }
 
-  int getPosMappedTo(TensorView* tv, IterDomain* id) const {
+  int getInnerLeafDim(TensorView* tv, IterDomain* root_dim) const {
+    // Find the root id mapped to `root_dim`
+    const auto& root_dom = tv->getRootDomain();
+    IterDomain* mapped_id = nullptr;
+    for (auto i : c10::irange(root_dom.size())) {
+      if (ca_map_.idGraph().permissiveNodes().permissiveAreMapped(
+              root_dom[i], root_dim)) {
+        mapped_id = root_dom[i];
+        break;
+      }
+    }
+    TORCH_INTERNAL_ASSERT(
+        mapped_id != nullptr,
+        "Can not find ID mapped to ",
+        root_dim,
+        " in tensor ",
+        tv);
+    // Project the root id to leaf id
+    while (!mapped_id->uses().empty()) {
+      TORCH_INTERNAL_ASSERT(mapped_id->uses().size() == 1);
+      auto expr = mapped_id->uses()[0];
+      if (expr->isA<Split>()) {
+        mapped_id = expr->as<Split>()->inner();
+      } else {
+        auto merge = expr->as<Merge>();
+        TORCH_INTERNAL_ASSERT(
+            mapped_id == merge->inner(),
+            "Can not find ID mapped to ",
+            root_dim,
+            " in tensor ",
+            tv);
+        mapped_id = merge->out();
+      }
+    }
+    // Find the position of the leaf id
     const auto& dom = tv->domain()->domain();
     for (auto i : c10::irange(dom.size())) {
-      if (areExactMapped(id, tv->axis(i))) {
+      if (dom[i] == mapped_id) {
         return i;
       }
     }
     TORCH_INTERNAL_ASSERT(
-        false, "Can not find ID mapped to ", id, " in tensor ", tv);
+        false, "Can not find ID mapped to ", root_dim, " in tensor ", tv);
   }
 
   // Group inputs and outputs of a fusion by its inner most domain. For example
@@ -240,22 +274,37 @@ void maybeBuildVirtualInnerDims(
   //    both virtual innermost dim.
   // 2. The satisfied one did not merge in anything. For example,
   //    T0[I0{1024*1024}, I1{2}]
+  // If this is the case, this means that we need to split the large
+  // inner-most dimension to satisfy the small innermost dimension
   int64_t large_dim;
   int64_t split_factor;
+  bool split_inner_most;
   if (merged_size1 < params.tile_size1) {
     if (params.dims_merged_with_2.empty()) {
       // case 2
-      return;
+      split_inner_most = true;
+      large_dim = inner_most2;
+      split_factor = params.tile_size2;
+    } else {
+      // case 1
+      split_inner_most = false;
+      large_dim = params.dims_merged_with_2.back();
+      auto prev_merged_size2 = merged_size2 / shape_in_ref1[large_dim];
+      split_factor = ceilDiv(params.tile_size2, prev_merged_size2);
     }
-    large_dim = params.dims_merged_with_2.back();
-    split_factor = ceilDiv(params.tile_size1, merged_size1);
   } else {
     if (params.dims_merged_with_1.empty()) {
       // case 2
-      return;
+      split_inner_most = true;
+      large_dim = inner_most1;
+      split_factor = params.tile_size1;
+    } else {
+      // case 1
+      split_inner_most = false;
+      large_dim = params.dims_merged_with_1.back();
+      auto prev_merged_size1 = merged_size1 / shape_in_ref1[large_dim];
+      split_factor = ceilDiv(params.tile_size1, prev_merged_size1);
     }
-    large_dim = params.dims_merged_with_1.back();
-    split_factor = ceilDiv(params.tile_size2, merged_size2);
   }
   params.split_before_tiling.push_back({large_dim, split_factor});
   // adjust all dims to after-split
@@ -271,12 +320,16 @@ void maybeBuildVirtualInnerDims(
   }
   // Give the split-out dim to the unsatisfied one, so that both are satisfied.
   if (merged_size1 < params.tile_size1) {
-    params.dims_merged_with_2.pop_back();
-    params.dims_merged_with_2.push_back(large_dim + 1);
+    if (!split_inner_most) {
+      params.dims_merged_with_2.pop_back();
+      params.dims_merged_with_2.push_back(large_dim + 1);
+    }
     params.dims_merged_with_1.push_back(large_dim);
   } else {
-    params.dims_merged_with_1.pop_back();
-    params.dims_merged_with_1.push_back(large_dim + 1);
+    if (!split_inner_most) {
+      params.dims_merged_with_1.pop_back();
+      params.dims_merged_with_1.push_back(large_dim + 1);
+    }
     params.dims_merged_with_2.push_back(large_dim);
   }
 }
@@ -369,12 +422,6 @@ std::shared_ptr<TransposeParams> getTransposeHeuristics(
   if (n_elems < device_multiprocessor_count * kMaxTileSize * kMaxTileSize) {
     params->tile_size1 = 8;
     params->tile_size2 = 8;
-    // TODO: I was trying the following but I got silent wrong result
-    // params->tile_size1 = 8;
-    // params->tile_size2 = 4;
-    // This should not happen, because the correctness should be irrevalent to
-    // schedulers. We don't have to use tile size (8, 4), but we need to fix our
-    // bug in codegen.
   }
 
   // Expand inner-most dims to virtual inner-most dims so that the inner-most
@@ -383,9 +430,9 @@ std::shared_ptr<TransposeParams> getTransposeHeuristics(
   auto inner_most_id2 = scheduler_utils::innerMostRootDim(reference2);
 
   auto inner_most_pos1_in_ref1 =
-      domain_map.getPosMappedTo(reference1, inner_most_id1);
+      domain_map.getInnerLeafDim(reference1, inner_most_id1);
   auto inner_most_pos2_in_ref1 =
-      domain_map.getPosMappedTo(reference1, inner_most_id2);
+      domain_map.getInnerLeafDim(reference1, inner_most_id2);
 
   // See note [Supporting small transpose dimensions]
   maybeBuildVirtualInnerDims(
@@ -643,9 +690,9 @@ void scheduleTranspose(Fusion* fusion, TransposeParams params) {
 
   // merge with inner most dims to get virtual inner most dims
   size_t inner_most_pos1_in_ref1 =
-      domain_map.getPosMappedTo(reference1, inner_most_id1);
+      domain_map.getInnerLeafDim(reference1, inner_most_id1);
   size_t inner_most_pos2_in_ref1 =
-      domain_map.getPosMappedTo(reference1, inner_most_id2);
+      domain_map.getInnerLeafDim(reference1, inner_most_id2);
   if (merged1.has_value()) {
     if (inner_most_pos1_in_ref1 < *merged1) {
       reference1->reorder(
diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu_transpose.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu_transpose.cpp
index 5e8b6bc1bda69..12c2593b63086 100644
--- a/torch/csrc/jit/codegen/cuda/test/test_gpu_transpose.cpp
+++ b/torch/csrc/jit/codegen/cuda/test/test_gpu_transpose.cpp
@@ -932,6 +932,37 @@ TEST_F(NVFuserTest, FusionScheduleTransposeSmallInnerSize3_CUDA) {
   testValidate(&fusion, outputs, {input}, {tv_ref}, __LINE__, __FILE__);
 }
 
+// x->sin->transpose->cos->y
+TEST_F(NVFuserTest, FusionScheduleTranspose2DSmallInnerSize_CUDA) {
+  std::array<std::vector<int64_t>, 2> shapes{
+      std::vector<int64_t>{1024 * 1024 * 128, 2},
+      std::vector<int64_t>{2, 1024 * 1024 * 128}};
+  for (const auto& shape : shapes) {
+    Fusion fusion;
+    FusionGuard fg(&fusion);
+
+    auto tv0 = makeContigTensor(2);
+    fusion.addInput(tv0);
+    auto tv1 = sin(tv0);
+    auto tv2 = transpose(tv1, 0, 1);
+    auto tv3 = cos(tv2);
+    fusion.addOutput(tv3);
+
+    auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+    at::Tensor input = at::randn(shape, options);
+
+    auto lparams = scheduleTranspose(&fusion, {input});
+
+    FusionExecutor fe;
+    fe.compileFusion(&fusion, {input}, lparams);
+    auto outputs = fe.runFusion({input}, lparams);
+
+    auto tv_ref = input.sin().transpose(0, 1).cos();
+
+    testValidate(&fusion, outputs, {input}, {tv_ref}, __LINE__, __FILE__);
+  }
+}
+
 } // namespace jit
 } // namespace torch
 #endif // #if defined(USE_CUDA)
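
Note (illustrative, not part of the patch): the "case 2" branch added to maybeBuildVirtualInnerDims can be traced with plain arithmetic. The standalone sketch below walks that branch for the first shape exercised by the new FusionScheduleTranspose2DSmallInnerSize_CUDA test; the 32x32 tile sizes, the local ceilDiv helper, and the variable names are stand-ins chosen here for illustration, not the scheduler's actual parameters or API.

// Standalone sketch: traces the new "case 2" split decision for
// T0[I0{1024*1024*128}, I1{2}]. Assumes 32x32 tiles for illustration.
#include <cstdint>
#include <iostream>
#include <vector>

int64_t ceilDiv(int64_t a, int64_t b) {
  return (a + b - 1) / b;
}

int main() {
  const std::vector<int64_t> shape{1024 * 1024 * 128, 2}; // [I0, I1]
  const int64_t tile_size1 = 32; // assumed; the heuristic may pick 8
  const int64_t tile_size2 = 32; // assumed

  // Group 1's innermost dim is I1{2}. There is nothing left to merge into it,
  // so merged_size1 == 2 < tile_size1 while dims_merged_with_2 is empty:
  // this is the "case 2" described in the patch.
  const int64_t merged_size1 = shape[1];
  if (merged_size1 < tile_size1) {
    // New behavior: split group 2's large innermost dim (I0) by tile_size2
    // instead of giving up (the old code simply returned here).
    const int64_t large_dim = 0;             // inner_most2, i.e. I0
    const int64_t split_factor = tile_size2; // keep one full tile for group 2
    const int64_t outer_extent = ceilDiv(shape[large_dim], split_factor);
    std::cout << "split I0{" << shape[large_dim] << "} -> I0o{" << outer_extent
              << "} * I0i{" << split_factor << "}\n";
    // I0i remains group 2's innermost tile; I0o is handed to group 1, whose
    // virtual innermost size becomes 2 * I0o, comfortably >= tile_size1.
    std::cout << "virtual innermost size for group 1: "
              << merged_size1 * outer_extent << "\n";
  }
  return 0;
}

This mirrors why the scheduler can now tile both sides of a [1024*1024*128, 2] transpose: the split leaves a full tile on the large side and donates the outer piece to the 2-wide side.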