Add an example on how to manually schedule transpose (#1889)
zasdfgbnm authored Aug 10, 2022
1 parent 83dbf56 commit 8a45dbf
Showing 1 changed file with 170 additions and 0 deletions.
torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
@@ -25593,6 +25593,176 @@ TEST_F(NVFuserTest, FusionPrint_CUDA) {
  }
}

TEST_F(NVFuserTest, FusionManualScheduleTransposeComplexDAG1_CUDA) {
  // achieved: 833.526 GB/s on RTX 3090 (theoretical bandwidth: 936 GB/s)
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeContigTensor(3);
  auto tv1 = makeContigTensor(3);
  auto tv2 = makeContigTensor(3);
  fusion.addInput(tv0);
  fusion.addInput(tv1);
  fusion.addInput(tv2);
  auto tv3 = transpose(tv0, 1, 2);
  auto tv4 = transpose(tv1, 0, 1);
  auto tv5 = sigmoid(tv1);
  auto tv6 = add(tv2, tv3);
  auto tv7 = transpose(tv5, 0, 2);
  auto tv8 = add(tv4, tv0);
  auto tv9 = relu(tv8);
  fusion.addOutput(tv9);
  auto tv10 = sin(tv6);
  fusion.addOutput(tv10);
  auto tv11 = transpose(tv6, 0, 1);
  auto tv12 = add(tv7, tv11);
  fusion.addOutput(tv12);

  // group 1: tv0, tv1, *tv9, innermost dim K
  // group 2: tv2, *tv10, tv12, innermost dim N
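  // (Reading of the groups above: the * marks each group's reference tensor,
  // tv9 for group 1 and tv10 for group 2. After the transposes the two
  // groups have different innermost dimensions, K vs. N, so each group is
  // vectorized along its own innermost dim; the shared-memory staging set
  // up in Step 2 bridges the layout mismatch within each 32x32 tile.)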

  // cache inputs and outputs
  auto tv0_cache = tv0->cacheAfter();
  auto tv1_cache = tv1->cacheAfter();
  auto tv2_cache = tv2->cacheAfter();
  auto tv9_cache = tv9->cacheBefore();
  auto tv10_cache = tv10->cacheBefore();
  auto tv12_cache = tv12->cacheBefore();
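  // (cacheAfter stages an input through an extra copy so its loads can be
  // scheduled and vectorized independently; cacheBefore does the same for an
  // output's stores. The group-2 caches are moved to shared memory in
  // Step 2.)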

  // Step 1: Make 32x32 tiles, schedule outer dimensions
  {
    // Pick an arbitrary tensor as a reference tensor for this step. There is
    // no requirement on which group this reference tensor should belong to.
    // Here we pick tv9, which belongs to group 1.

    // Make 32x32 tile:
    // [M, N, K]
    tv9->split(1, 32);
    tv9->reorder({{2, -1}});
    tv9->split(2, 32);
    tv9->reorder({{3, -1}});
    // [M, N/32, K/32, 32(N), 32(K)]
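    // Intermediate domains between the four calls above:
    //   split(1, 32):       [M, N/32, 32(N), K]
    //   reorder({{2, -1}}): [M, N/32, K, 32(N)]
    //   split(2, 32):       [M, N/32, K/32, 32(K), 32(N)]
    //   reorder({{3, -1}}): [M, N/32, K/32, 32(N), 32(K)]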

    // merge outer dims, parallelize on BIDx, and unswitch
    tv9->merge(0);
    tv9->merge(0);
    tv9->split(0, 1);
    // [M * N/32 * K/32, 1, 32(N), 32(K)]
    tv9->axis(0)->parallelize(ParallelType::BIDx);
    tv9->axis(1)->parallelize(ParallelType::Unswitch);
    // [BIDx, Unswitch, 32(N), 32(K)]

    // propagate to the entire DAG
    MaxRootDomainInfoSpanningTree entire_dag(tv9);
    TransformPropagator tp(tv9);
    entire_dag.traverse(&tp);
    scheduler_utils::parallelizeAllLike(tv9);
  }

  constexpr int threads_per_block = 128;

  // Step 2, schedule group 2
  {
    // group 2: tv2, *tv10, tv12, innermost dim N

    tv2_cache->setMemoryType(MemoryType::Shared);
    tv10_cache->setMemoryType(MemoryType::Shared);
    tv12_cache->setMemoryType(MemoryType::Shared);
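    // (Shared memory is what makes the in-tile transpose cheap: one thread
    // writes an element of the tile under one group's layout and a different
    // thread reads it back under the other group's layout, with a
    // block-level sync in between.)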

    // pick tv10 as reference tensor for group 2
    // [BIDx, Unswitch, 32(N), 32(K)]
    tv10->reorder({{-1, -2}});
    // [BIDx, Unswitch, 32(K), 32(N)]
    tv10->merge(2);
    tv10->split(2, 4);
    tv10->split(2, threads_per_block);
    tv10->axis(-1)->parallelize(ParallelType::Vectorize);
    tv10->axis(-2)->parallelize(ParallelType::TIDx);
    tv10->axis(-3)->parallelize(ParallelType::Unroll);
    // [BIDx, Unswitch, Unroll, TIDx, Vectorize]
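    // The 32x32 tile holds 1024 elements, which is exactly
    // Unroll(2) * TIDx(128) * Vectorize(4):
    //   merge(2):                    [BIDx, Unswitch, 1024]
    //   split(2, 4):                 [BIDx, Unswitch, 256, 4]
    //   split(2, threads_per_block): [BIDx, Unswitch, 2, 128, 4]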

    // Propagate to group 2 and its cache. Note that group 2 and its cache
    // are not connected, so we need to borrow other tensors of the DAG to be
    // able to propagate. The transformations on borrowed tensors will be
    // overwritten in the next step. We cannot borrow the reference tensor of
    // group 1.
    auto all_tvs_except_ref1 = ir_utils::allTvsExcept(&fusion, {tv9});
    auto all_tvs_except_ref1_set = std::unordered_set<TensorView*>(
        all_tvs_except_ref1.begin(), all_tvs_except_ref1.end());
    SetSelector selector(all_tvs_except_ref1_set);
    MaxRootDomainInfoSpanningTree tree(tv10, &selector);
    TransformPropagator tp(tv10);
    tree.traverse(&tp);
    scheduler_utils::parallelizeAllLike(
        tv10, {tv2_cache, tv10, tv12}, {ParallelType::TIDx});
    scheduler_utils::parallelizeAllLike(
        tv10,
        {tv2_cache, tv10, tv12},
        {ParallelType::Vectorize, ParallelType::Unroll});
  }

  // Step 3, schedule group 1
  {
    // group 1: tv0, tv1, *tv9, innermost dim K
    // [BIDx, Unswitch, 32(N), 32(K)]
    tv9->merge(2);
    tv9->split(2, 4);
    tv9->split(2, threads_per_block);
    tv9->axis(-1)->parallelize(ParallelType::Vectorize);
    tv9->axis(-2)->parallelize(ParallelType::TIDx);
    tv9->axis(-3)->parallelize(ParallelType::Unroll);
    // [BIDx, Unswitch, Unroll, TIDx, Vectorize]
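    // (Unlike group 2, no reorder is needed before the merge: group 1's
    // vectorizable dim is K, and 32(K) is already the innermost tile axis.)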

    // Propagate to the entire DAG except for group 2 and its cached inputs
    auto all_tvs_except2 =
        ir_utils::allTvsExcept(&fusion, {tv2, tv2_cache, tv10, tv12});
    auto all_tvs_except2_set = std::unordered_set<TensorView*>(
        all_tvs_except2.begin(), all_tvs_except2.end());
    SetSelector selector(all_tvs_except2_set);
    MaxRootDomainInfoSpanningTree tree(tv9, &selector);
    TransformPropagator tp(tv9);
    tree.traverse(&tp);
    scheduler_utils::parallelizeAllLike(
        tv9, all_tvs_except2, {ParallelType::TIDx});
    scheduler_utils::parallelizeAllLike(
        tv9,
        {tv0_cache, tv1_cache, tv9},
        {ParallelType::Vectorize, ParallelType::Unroll});
  }

  // inline
  MaxRootDomainInfoSpanningTree entire_dag(tv9);
  InlinePropagator inline_propagator(tv9, -1, ComputeAtMode::MostInlined);
  entire_dag.traverse(&inline_propagator);
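  // (MostInlined computes each producer as deep inside its consumer's loop
  // nest as legality allows, minimizing intermediate storage.)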

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor input0 = at::randn({512, 1024, 256}, options);
  at::Tensor input1 = at::randn({1024, 512, 256}, options);
  at::Tensor input2 = at::randn({512, 256, 1024}, options);

  FusionExecutor fe;
  fe.compileFusion(&fusion, {input0, input1, input2});
  auto outputs = fe.runFusion({input0, input1, input2});

  auto t3 = input0.transpose(1, 2);
  auto t4 = input1.transpose(0, 1);
  auto t5 = input1.sigmoid();
  auto t6 = input2 + t3;
  auto t7 = t5.transpose(0, 2);
  auto t8 = t4 + input0;
  auto t9 = t8.relu();
  auto t10 = t6.sin();
  auto t11 = t6.transpose(0, 1);
  auto t12 = t7 + t11;

  testValidate(
      &fusion,
      outputs,
      {input0, input1, input2},
      {t9, t10, t12},
      __LINE__,
      __FILE__);
}

TEST_F(NVFuserTest, FusionCheckedSymbolicShape_CUDA) {
  const auto options =
      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
