Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add extra configurability to parallelizeAllLike #1831

Merged
merged 7 commits into from
Jul 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions torch/csrc/jit/codegen/cuda/ir_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -652,6 +652,19 @@ std::vector<TensorView*> allTvs(Fusion* fusion) {
return uniqueEntries<TensorView>(all_tvs);
}

// Returns all tensor views used between the fusion's inputs and outputs,
// minus the ones listed in `except`.
std::vector<TensorView*> allTvsExcept(
    Fusion* fusion,
    const std::unordered_set<TensorView*>& except) {
  std::vector<TensorView*> filtered;
  // Walk every TV in the fusion and keep only those not excluded.
  for (auto tv : allTvs(fusion)) {
    if (except.count(tv)) {
      continue;
    }
    filtered.push_back(tv);
  }
  return filtered;
}

std::vector<Expr*> getReductionOps(Fusion* fusion, bool ignore_trivial) {
std::vector<Expr*> red_ops;

Expand Down
6 changes: 6 additions & 0 deletions torch/csrc/jit/codegen/cuda/ir_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,12 @@ TORCH_CUDA_CU_API std::vector<TensorView*> outputTvsOf(
// returns all tensor views in fusion that are used between outputs and inputs.
TORCH_CUDA_CU_API std::vector<TensorView*> allTvs(Fusion* fusion);

// returns all tensor views in fusion that are used between outputs and inputs
// except the specified set.
TORCH_CUDA_CU_API std::vector<TensorView*> allTvsExcept(
Fusion* fusion,
const std::unordered_set<TensorView*>& except);

TORCH_CUDA_CU_API std::vector<Expr*> getReductionOps(
Fusion* fusion,
bool ignore_trivial = true);
Expand Down
52 changes: 23 additions & 29 deletions torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -676,6 +676,7 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
}

int64_t unswitch_pos;
IterDomain* vectorize_id = nullptr;
if (params.break_point) {
// 2D parallelization scheme
TORCH_INTERNAL_ASSERT(rhs_i >= 0 && lhs_i >= 0);
Expand All @@ -692,9 +693,7 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
reference_tv->axis(1)->parallelize(ParallelType::Unswitch);
reference_tv->axis(3)->parallelize(ParallelType::TIDx);

// Aggressively mark with vectorized and cleanup later. That way we
// don't have to manually specify parallelization outside the reference.
reference_tv->axis(4)->parallelize(ParallelType::Vectorize);
vectorize_id = reference_tv->axis(4);

// [outer, Unswitch | i-remainder, TIDx, Vectorization]
// To make consistent with unrolling:
Expand Down Expand Up @@ -797,7 +796,7 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
reference_tv->axis(2)->parallelize(ParallelType::Unswitch);
// Record the vectorizable axis; it is marked as Vectorize only later, and
// only on the tensors that can actually be vectorized.
reference_tv->axis(3)->parallelize(ParallelType::Vectorize);
vectorize_id = reference_tv->axis(3);

//[BIDx, TIDx, Unswitch, Vectorization]
// To make consistent with unrolling:
Expand All @@ -822,37 +821,32 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
TransformPropagator propagator(reference_tv);
MaxRootDomainInfoSpanningTree spanning_tree(reference_tv);
spanning_tree.traverse(&propagator);
scheduler_utils::parallelizeAllLike(reference_tv, all_tvs);
scheduler_utils::parallelizeAllLike(reference_tv);

if (params.vectorize) {
// Grab all tensor views that should be vectorized
auto vectorized_tvs =
auto inputs_outputs =
scheduler_utils::getInputsOutputsWithInnerDim(reference_tv, true);
// Going to move inputs to consumers of inputs, need a copy as we'll modify
// the original.
{
auto vectorized_tvs_copy = vectorized_tvs;
for (auto inp : vectorized_tvs_copy) {
if (!inp->isFusionInput()) {
continue;
}
vectorized_tvs.erase(
std::find(vectorized_tvs.begin(), vectorized_tvs.end(), inp));
auto consumer_tvs = ir_utils::consumerTvsOf(inp);
vectorized_tvs.insert(
vectorized_tvs.end(), consumer_tvs.begin(), consumer_tvs.end());
std::vector<TensorView*> vectorized_tvs;
bool should_vectorize_reference_tv = false;
for (auto tv : inputs_outputs) {
if (!tv->isFusionInput()) {
vectorized_tvs.emplace_back(tv);
continue;
}
}
// Clear vectorize on tensors that shouldn't have it
for (auto tv : all_tvs) {
if (std::find(vectorized_tvs.begin(), vectorized_tvs.end(), tv) ==
vectorized_tvs.end()) {
for (auto id : tv->domain()->domain()) {
if (id->getParallelType() == ParallelType::Vectorize) {
id->parallelize(ParallelType::Serial);
}
}
if (tv == reference_tv) {
should_vectorize_reference_tv = true;
}
// move inputs to consumers of inputs
auto consumer_tvs = ir_utils::consumerTvsOf(tv);
vectorized_tvs.insert(
vectorized_tvs.end(), consumer_tvs.begin(), consumer_tvs.end());
}
vectorize_id->parallelize(ParallelType::Vectorize);
scheduler_utils::parallelizeAllLike(
reference_tv, vectorized_tvs, {ParallelType::Vectorize});
if (!should_vectorize_reference_tv) {
vectorize_id->parallelize(ParallelType::Serial);
}
}

Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/jit/codegen/cuda/scheduler/reduction_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ void multiReductionInliner(
}

// Propagate parallelization
scheduler_utils::parallelizeAllLike(reference_tv, ir_utils::allTvs(fusion));
scheduler_utils::parallelizeAllLike(reference_tv);

// Find iter domains that are mapped to a trivial reduction, these should
// never be inlined.
Expand Down
47 changes: 35 additions & 12 deletions torch/csrc/jit/codegen/cuda/scheduler/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,30 +188,53 @@ size_t mergeNonReduction(

void parallelizeAllLike(
TensorView* reference_tv,
const std::vector<TensorView*>& all_tvs) {
int64_t pos,
std::vector<TensorView*> selected_tvs,
const std::unordered_set<ParallelType>& selected_parallel_types,
bool propagate_padding) {
FusionGuard fg(reference_tv->fusion());

if (pos < 0) {
pos += reference_tv->nDims() + 1;
}
TORCH_CHECK(
pos >= 0 && pos <= reference_tv->nDims(),
"parallelizeAllLike called on an position outside valid range.");

std::unordered_map<IterDomain*, IterDomain*> concrete_to_reference_map;

auto ca_map = ComputeAtMap(FusionGuard::getCurFusion());

for (auto id : reference_tv->domain()->domain()) {
ca_map.getConcreteMappedID(id, IdMappingMode::PERMISSIVE)
->parallelize(id->getParallelType());
if (id->hasPaddingToMultipleOfWarp()) {
ca_map.getConcreteMappedID(id, IdMappingMode::PERMISSIVE)
->padToMultipleOfWarp(id->getMaybeSizeAfterPadding());
}
const auto& reference_dom = reference_tv->domain()->domain();
for (auto it = reference_dom.begin(); it != reference_dom.begin() + pos;
it++) {
auto ca_id = ca_map.getConcreteMappedID(*it, IdMappingMode::PERMISSIVE);
concrete_to_reference_map[ca_id] = *it;
}

for (auto tv : all_tvs) {
if (selected_tvs.empty()) {
selected_tvs = ir_utils::allTvs(reference_tv->fusion());
}
for (auto tv : selected_tvs) {
if (tv->isFusionInput()) {
continue;
}
for (const auto i : c10::irange(tv->domain()->domain().size())) {
auto ca_id =
ca_map.getConcreteMappedID(tv->axis(i), IdMappingMode::PERMISSIVE);
tv->axis(i)->parallelize(ca_id->getParallelType());
if (ca_id->hasPaddingToMultipleOfWarp()) {
tv->axis(i)->padToMultipleOfWarp(ca_id->getMaybeSizeAfterPadding());
if (concrete_to_reference_map.count(ca_id) > 0) {
auto reference_id = concrete_to_reference_map.at(ca_id);
auto reference_parallel_type = reference_id->getParallelType();
if (selected_parallel_types.empty() ||
selected_parallel_types.count(reference_parallel_type)) {
tv->axis(i)->parallelize(reference_parallel_type);
}
if (propagate_padding) {
if (reference_id->hasPaddingToMultipleOfWarp()) {
tv->axis(i)->padToMultipleOfWarp(
reference_id->getMaybeSizeAfterPadding());
}
}
}
}
}
Expand Down
25 changes: 24 additions & 1 deletion torch/csrc/jit/codegen/cuda/scheduler/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,32 @@ size_t mergeNonReduction(
TensorView* tv,
const std::unordered_set<IterDomain*>& dont_merge = {});

// Propagate the parallelization from the selected dimensions of the reference
// tensor to their corresponding dimensions in all selected tensors in the DAG.
// Position `pos` means selecting all the dimensions [0, 1, ..., pos - 1]. pos =
// -1 means selecting all dimensions. `selected_tvs` are selected tensors in the
// DAG. Empty `selected_tvs` means selecting all tensors in the fusion of
// `reference_tv`. `selected_parallel_types` are the selected parallel types.
// Empty `selected_parallel_types` means selecting all parallel types.
TORCH_CUDA_CU_API void parallelizeAllLike(
TensorView* reference_tv,
const std::vector<TensorView*>& all_tvs);
int64_t pos = -1,
Copy link
Collaborator

@naoyam naoyam Jul 16, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a comment. Propagation is done only for the first pos domains, right?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added on top of parallelizeAllLike. And you are right, pos means selecting the first pos IDs.

std::vector<TensorView*> selected_tvs = {},
const std::unordered_set<ParallelType>& selected_parallel_types = {},
bool propagate_padding = true);

// Convenience overload: propagate the parallelization of *all* dimensions of
// `reference_tv` (equivalent to pos = -1 on the primary overload) to the
// selected tensors, optionally restricted to `selected_parallel_types`.
// `propagate_padding` controls whether pad-to-warp information is also
// propagated.
TORCH_CUDA_CU_API inline void parallelizeAllLike(
    TensorView* reference_tv,
    std::vector<TensorView*> selected_tvs,
    const std::unordered_set<ParallelType>& selected_parallel_types = {},
    bool propagate_padding = true) {
  // Forward to the primary overload with pos = -1 (select every dimension).
  parallelizeAllLike(
      reference_tv,
      -1,
      std::move(selected_tvs),
      selected_parallel_types,
      propagate_padding);
}

TORCH_CUDA_CU_API void computeAtInputs(
TensorView* consumer,
Expand Down
28 changes: 14 additions & 14 deletions torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13849,7 +13849,7 @@ TEST_F(NVFuserTest, FusionSimpleVectorizeUnroll_CUDA) {

TransformPropagatorWithCheck propagator(tv3);
MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator);
scheduler_utils::parallelizeAllLike(tv3, ir_utils::allTvs(&fusion));
scheduler_utils::parallelizeAllLike(tv3);

tv0_cache->axis(2)->parallelize(ParallelType::Vectorize);
tv1_cache->axis(2)->parallelize(ParallelType::Vectorize);
Expand Down Expand Up @@ -17152,7 +17152,7 @@ TEST_F(NVFuserTest, FusionPredicateElimination3_CUDA) {

tv4->axis(0)->parallelize(ParallelType::BIDx);
tv4->axis(1)->parallelize(ParallelType::TIDx);
scheduler_utils::parallelizeAllLike(tv4, ir_utils::allTvs(&fusion));
scheduler_utils::parallelizeAllLike(tv4);

GpuLower gpulw(&fusion);

Expand Down Expand Up @@ -17203,7 +17203,7 @@ TEST_F(NVFuserTest, FusionPredicateElimination4_CUDA) {

tv1->axis(0)->parallelize(ParallelType::TIDy);
tv1->axis(1)->parallelize(ParallelType::TIDx);
scheduler_utils::parallelizeAllLike(tv1, ir_utils::allTvs(&fusion));
scheduler_utils::parallelizeAllLike(tv1);

GpuLower gpulw(&fusion);

Expand Down Expand Up @@ -17252,7 +17252,7 @@ TEST_F(NVFuserTest, FusionPredicateElimination5_CUDA) {
auto rtvs2 = tvs2.rFactor({1});

rtvs2.avg->axis(0)->parallelize(ParallelType::TIDx);
scheduler_utils::parallelizeAllLike(rtvs2.avg, ir_utils::allTvs(&fusion));
scheduler_utils::parallelizeAllLike(rtvs2.avg);

GpuLower gpulw(&fusion);

Expand Down Expand Up @@ -20392,7 +20392,7 @@ TEST_F(NVFuserTest, FusionDoubleBuffering1_CUDA) {

tv3->axis(-2)->parallelize(ParallelType::BIDx);
tv3->axis(-1)->parallelize(ParallelType::TIDx);
scheduler_utils::parallelizeAllLike(tv3, ir_utils::allTvs(&fusion));
scheduler_utils::parallelizeAllLike(tv3);

tv1->doubleBuffer();

Expand Down Expand Up @@ -20430,7 +20430,7 @@ TEST_F(NVFuserTest, FusionDoubleBuffering2_CUDA) {

tv3->axis(-2)->parallelize(ParallelType::BIDx);
tv3->axis(-1)->parallelize(ParallelType::TIDx);
scheduler_utils::parallelizeAllLike(tv3, ir_utils::allTvs(&fusion));
scheduler_utils::parallelizeAllLike(tv3);

tv1->doubleBuffer();

Expand Down Expand Up @@ -20479,7 +20479,7 @@ TEST_F(NVFuserTest, FusionDoubleBuffering3_CUDA) {
tv2->doubleBuffer();

tv3->axis(-1)->parallelize(ParallelType::TIDx);
scheduler_utils::parallelizeAllLike(tv3, ir_utils::allTvs(&fusion));
scheduler_utils::parallelizeAllLike(tv3);

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
at::manual_seed(0);
Expand Down Expand Up @@ -20520,7 +20520,7 @@ TEST_F(NVFuserTest, FusionDoubleBuffering4_CUDA) {

tv3->axis(-1)->parallelize(ParallelType::TIDx);
tv3->axis(1)->parallelize(ParallelType::Unswitch);
scheduler_utils::parallelizeAllLike(tv3, ir_utils::allTvs(&fusion));
scheduler_utils::parallelizeAllLike(tv3);

tv2->doubleBuffer();

Expand Down Expand Up @@ -20562,7 +20562,7 @@ TEST_F(NVFuserTest, FusionDoubleBuffering5_CUDA) {

tv2->axis(-1)->parallelize(ParallelType::TIDx);
tv2->axis(1)->parallelize(ParallelType::Unswitch);
scheduler_utils::parallelizeAllLike(tv2, ir_utils::allTvs(&fusion));
scheduler_utils::parallelizeAllLike(tv2);

tv1->doubleBuffer();

Expand Down Expand Up @@ -20684,7 +20684,7 @@ TEST_F(NVFuserTest, FusionDoubleBuffering8_CUDA) {
tv1->computeAt(tv4, 1);

tv4->axis(-1)->parallelize(ParallelType::TIDx);
scheduler_utils::parallelizeAllLike(tv4, ir_utils::allTvs(&fusion));
scheduler_utils::parallelizeAllLike(tv4);

tv2->doubleBuffer();
tv3->doubleBuffer();
Expand Down Expand Up @@ -20728,7 +20728,7 @@ TEST_F(NVFuserTest, FusionDoubleBuffering9_CUDA) {
tv3->computeAt(out, -1);

out->axis(-1)->parallelize(ParallelType::TIDx);
scheduler_utils::parallelizeAllLike(out, ir_utils::allTvs(&fusion));
scheduler_utils::parallelizeAllLike(out);

tv2->doubleBuffer();
tv3->doubleBuffer();
Expand Down Expand Up @@ -20806,7 +20806,7 @@ TEST_F(NVFuserTest, FusionSmemBlockGemmCacheDoubleBuffer_CUDA) {
tv5->axis(-3)->parallelize(ParallelType::TIDy);
tv5->axis(-1)->parallelize(ParallelType::TIDx);

scheduler_utils::parallelizeAllLike(tv5, ir_utils::allTvs(&fusion));
scheduler_utils::parallelizeAllLike(tv5);

tv0_cache_local->doubleBuffer();
tv1_cache_local->doubleBuffer();
Expand Down Expand Up @@ -21170,7 +21170,7 @@ TEST_F(NVFuserTest, FusionIssue1430_CUDA) {

auto rfactor = ir_utils::rfactorHelper(tv3, {1, 4});

scheduler_utils::parallelizeAllLike(rfactor, ir_utils::allTvs(&fusion));
scheduler_utils::parallelizeAllLike(rfactor);

for (auto tv : ir_utils::allTvs(&fusion)) {
if (tv != tv1 || tv != tv3) {
Expand Down Expand Up @@ -23202,7 +23202,7 @@ TEST_F(NVFuserTest, FusionTestReEntrantGridWelford_CUDA) {
TransformPropagatorWithCheck propagator(reduction_tv);
MaxRootDomainInfoSpanningTree(reduction_tv).traverse(&propagator);
auto rfactor_tv = ir_utils::rfactorHelper(reduction_tv, {4});
scheduler_utils::parallelizeAllLike(rfactor_tv, ir_utils::allTvs(&fusion));
scheduler_utils::parallelizeAllLike(rfactor_tv);

tv0->computeAt(tv_avg, 2);
tv0->computeAt(cached_input, -2);
Expand Down
Loading