Fix CUDA driver error: misaligned address for transpose scheduler #1918

Merged
merged 6 commits on Aug 22, 2022
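Summary (as read from the diff below): the transpose scheduler previously applied `ParallelType::Vectorize` unconditionally to the reference tensors' innermost axis and propagated it to every tensor in each group via `parallelizeAllLike`, which could vectorize tensors whose axes are only permissively mapped and produce misaligned vectorized accesses. With this change, `Vectorize` is applied only when the corresponding `vectorize_factor` is greater than 1, and it is propagated only to tensors that have an axis exactly mapped (compute-at map, `IdMappingMode::EXACT`) to the reference's vectorized axis, mirroring the existing handling of `Unroll`. Below is a minimal, self-contained sketch of that selection pattern; `AxisId`, `TensorLike`, `selectVectorizable`, and `areMappedExact` are hypothetical stand-ins for illustration, not the nvfuser API.

```cpp
// Sketch only: placeholder types standing in for nvfuser's TensorView,
// IterDomain, and ComputeAtMap. In the actual change, the predicate is
// ca_map.areMapped(id, reference->axis(-1), IdMappingMode::EXACT).
#include <algorithm>
#include <functional>
#include <vector>

struct AxisId {
  int exact_group; // axes sharing this id are considered exactly mapped
};

struct TensorLike {
  std::vector<AxisId> axes;
};

// Keep only the candidates that have at least one axis exactly mapped to the
// reference's innermost (to-be-vectorized) axis; the caller then applies
// ParallelType::Vectorize to just these, and only if vectorize_factor > 1.
std::vector<TensorLike*> selectVectorizable(
    const std::vector<TensorLike*>& candidates,
    const AxisId& reference_innermost,
    const std::function<bool(const AxisId&, const AxisId&)>& areMappedExact) {
  std::vector<TensorLike*> selected;
  for (TensorLike* tv : candidates) {
    if (std::any_of(tv->axes.begin(), tv->axes.end(), [&](const AxisId& id) {
          return areMappedExact(id, reference_innermost);
        })) {
      selected.push_back(tv);
    }
  }
  return selected;
}

int main() {
  AxisId vectorized_axis{0};
  TensorLike exact{{AxisId{0}, AxisId{1}}};      // shares the exactly mapped axis
  TensorLike permissive{{AxisId{2}, AxisId{3}}}; // no exactly mapped axis
  auto are_mapped = [](const AxisId& a, const AxisId& b) {
    return a.exact_group == b.exact_group;
  };
  auto selected =
      selectVectorizable({&exact, &permissive}, vectorized_axis, are_mapped);
  // Only `exact` is selected; `permissive` keeps non-vectorized access.
  return selected.size() == 1 ? 0 : 1;
}
```

In the real diff this corresponds to building `vectorized_group1_cached_inputs` / `vectorized_group2_cached_inputs` and passing them to `parallelizeAllLike` with `{ParallelType::Vectorize}`.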
52 changes: 44 additions & 8 deletions torch/csrc/jit/codegen/cuda/scheduler/transpose.cpp
@@ -533,7 +533,9 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams& params) {

// parallelize group2 and its cached inputs
{
reference2->axis(-1)->parallelize(ParallelType::Vectorize);
if (params.vectorize_factor2 > 1) {
reference2->axis(-1)->parallelize(ParallelType::Vectorize);
}
reference2->axis(-2)->parallelize(ParallelType::TIDx);
reference2->axis(-3)->parallelize(ParallelType::Unroll);

@@ -542,9 +544,27 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams& params) {
scheduler_utils::parallelizeAllLike(
reference2,
{group2_and_cached_inputs.begin(), group2_and_cached_inputs.end()},
{ParallelType::Vectorize, ParallelType::TIDx});
{ParallelType::TIDx});

// Only unrolled the axes that exactly maps to the unrolled axes
// Only vectorize the axes that exactly maps to the vectorized axes
// on reference as support for permissively mapped axes are not
// yet clearly defined.
std::vector<TensorView*> vectorized_group2_cached_inputs;
for (auto gin : group2_and_cached_inputs) {
if (std::any_of(
gin->domain()->domain().begin(),
gin->domain()->domain().end(),
[&ca_map, reference2](IterDomain* id) {
return ca_map.areMapped(
id, reference2->axis(-1), IdMappingMode::EXACT);
})) {
vectorized_group2_cached_inputs.push_back(gin);
}
}
scheduler_utils::parallelizeAllLike(
reference2, vectorized_group2_cached_inputs, {ParallelType::Vectorize});

// Only unroll the axes that exactly maps to the unrolled axes
// on reference as support for permissively mapped axes are not
// yet clearly defined.
std::vector<TensorView*> unrolled_group2_cached_inputs;
@@ -559,7 +579,6 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams& params) {
unrolled_group2_cached_inputs.push_back(gin);
}
}

scheduler_utils::parallelizeAllLike(
reference2, unrolled_group2_cached_inputs, {ParallelType::Unroll});
}
@@ -571,7 +590,9 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams& params) {
reference1->merge(pos);
reference1->split(pos, params.vectorize_factor1);
reference1->split(pos, kThreadsPerBlock);
reference1->axis(-1)->parallelize(ParallelType::Vectorize);
if (params.vectorize_factor1 > 1) {
reference1->axis(-1)->parallelize(ParallelType::Vectorize);
}
reference1->axis(-2)->parallelize(ParallelType::TIDx);
reference1->axis(-3)->parallelize(ParallelType::Unroll);
// [..., Unroll, TIDx, Vectorize]
@@ -600,10 +621,26 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams& params) {
group1_and_cached_inputs.emplace_back(ir_utils::consumerTvsOf(tv)[0]);
}
}

// Only vectorize the axes that exactly maps to the vectorized axes
// on reference as support for permissively mapped axes are not
// yet clearly defined.
std::vector<TensorView*> vectorized_group1_cached_inputs;
for (auto gin : group1_and_cached_inputs) {
if (std::any_of(
gin->domain()->domain().begin(),
gin->domain()->domain().end(),
[&ca_map, reference1](IterDomain* id) {
return ca_map.areMapped(
id, reference1->axis(-1), IdMappingMode::EXACT);
})) {
vectorized_group1_cached_inputs.push_back(gin);
}
}
scheduler_utils::parallelizeAllLike(
reference1, group1_and_cached_inputs, {ParallelType::Vectorize});
reference1, vectorized_group1_cached_inputs, {ParallelType::Vectorize});

// Only unrolled the axes that exactly maps to the unrolled axes
// Only unroll the axes that exactly maps to the unrolled axes
// on reference as support for permissively mapped axes are not
// yet clearly defined.
std::vector<TensorView*> unrolled_group1_cached_inputs;
@@ -618,7 +655,6 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams& params) {
unrolled_group1_cached_inputs.push_back(gin);
}
}

scheduler_utils::parallelizeAllLike(
reference1, unrolled_group1_cached_inputs, {ParallelType::Unroll});
}
99 changes: 99 additions & 0 deletions torch/csrc/jit/codegen/cuda/test/test_gpu_transpose.cpp
@@ -491,6 +491,36 @@ TEST_F(NVFuserTest, FusionScheduleBroadcastOnly_CUDA) {
}
}

// mermaid graph:
// ```mermaid
// %%{
// init: {
// 'theme': 'base',
// 'themeVariables': { 'fontSize': '30px', 'fontFamily': 'times'}}
// }%%
// graph TD
// T0("T0(M, N, K)")
// T1("T1(N, M, K)")
// T2("T2(M, K, N)")
// T0 --> A("transpose(1, 2)") --> T3("T3(M, K, N)")
// T1 ---> sigmoid --> T5("T5(N, M, K)")
// T5 --> B("transpose(0, 2)") --> T7("T7(K, M, N)")
// T2 ----> C("add")
// T3 --> C --> T6("T6(M, K, N)")
// T6 --> D("transpose(0, 1)") --> T11("T11(K, M, N)")
// T11 --> E("add") -->T12("T12(K, M, N)")
// T7 --> E
// T1 ---> F("transpose(0, 1)") --> T4("T4(M, N, K)")
// T0 --> G("add") --> T8("T8(M, N, K)") --> relu ---> T9("T9(M, N, K)")
// T4 --> G
// T6 ---> sin ---> T10("T10(M, K, N)")
// style T0 fill:lightgreen
// style T1 fill:lightgreen
// style T2 fill:lightgreen
// style T12 fill:lightblue
// style T9 fill:lightblue
// style T10 fill:lightblue
// ```
TEST_F(NVFuserTest, FusionScheduleTransposeComplexDAG1_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
@@ -546,6 +576,36 @@ TEST_F(NVFuserTest, FusionScheduleTransposeComplexDAG1_CUDA) {
__FILE__);
}

// mermaid graph:
// ```mermaid
// %%{
// init: {
// 'theme': 'base',
// 'themeVariables': { 'fontSize': '30px', 'fontFamily': 'times'}}
// }%%
// graph TD
// T0("T0(M, N, K)")
// T1("T1(N, M, K)")
// T2("T2(M, K, N)")
// T0 --> A("transpose(1, 2)") --> T3("T3(M, K, N)")
// T1 ---> sigmoid --> T5("T5(N, M, K)")
// T5 --> B("transpose(0, 2)") --> T7("T7(K, M, N)")
// T2 ----> C("add")
// T3 --> C --> T6("T6(M, K, N)")
// T6 --> D("transpose(0, 1)") --> T11("T11(K, M, N)")
// T11 --> E("add") -->T12("T12(K, M, N)")
// T7 --> E
// T1 ---> F("transpose(0, 1)") --> T4("T4(M, N, K)")
// T0 --> G("add") --> T8("T8(M, N, K)") --> relu ---> T9("T9(M, N, K)")
// T4 --> G
// T6 ---> sin ---> T10("T10(M, K, N)")
// style T0 fill:lightgreen
// style T1 fill:lightgreen
// style T2 fill:lightgreen
// style T12 fill:lightblue
// style T9 fill:lightblue
// style T10 fill:lightblue
// ```
TEST_F(NVFuserTest, FusionManualScheduleTransposeComplexDAG1_CUDA) {
// achieved: 833.526 GB/s on RTX 3090 (theoretical bandwidth: 936 GB/s)
Fusion fusion;
@@ -729,6 +789,45 @@ TEST_F(NVFuserTest, FusionViewNoTranspose_CUDA) {
TORCH_CHECK(!hasAtLeastTwoValidGroups(&fusion));
}

// t0------------.
// t2->broadcast->sub->mul->relu->t6
// t1------------------'
TEST_F(NVFuserTest, FusionScheduleTransposeMissingDim_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);

auto tv0 = makeContigTensor(3);
auto tv1 = makeContigConcreteTensor({1, -1, 1});
auto tv2 = makeContigTensor(1);
fusion.addInput(tv0);
fusion.addInput(tv1);
fusion.addInput(tv2);
auto tv3 = broadcast(tv2, {true, false, true});
auto tv4 = sub(tv0, tv3);
auto tv5 = mul(tv4, tv1);
auto tv6 = relu(tv5);
fusion.addOutput(tv6);

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
at::Tensor input0 = at::randn({512, 1024, 512}, options);
at::Tensor input1 = at::randn({1, 1024, 1}, options);
at::Tensor input2 = at::randn({1024}, options);

auto lparams = scheduleTranspose(&fusion, {input0, input1, input2});

FusionExecutor fe;
fe.compileFusion(&fusion, {input0, input1, input2}, lparams);
auto outputs = fe.runFusion({input0, input1, input2}, lparams);

auto t3 = input2.unsqueeze(0).unsqueeze(-1);
auto t4 = input0 - t3;
auto t5 = t4 * input1;
auto t6 = at::relu(t5);

testValidate(
&fusion, outputs, {input0, input1, input2}, {t6}, __LINE__, __FILE__);
}

} // namespace jit
} // namespace torch
#endif // #if defined(USE_CUDA)