Welford Scheduling Support #561

Merged 48 commits on Feb 18, 2021
Commits
d321193
introduce MultiScanOp
shmsong Nov 13, 2020
821e176
device-to-device schedule
shmsong Nov 16, 2020
c3b85a3
Merge branch 'multi_output_scan' of https://github.com/csarofeen/pyto…
shmsong Dec 8, 2020
9b0b404
fix codegen
shmsong Dec 8, 2020
227ffa7
swap in welfordOp
shmsong Dec 10, 2020
798d49a
Merge branch '20_12_3_devel' into multi_output_scan
shmsong Dec 10, 2020
ff2758a
format
shmsong Dec 10, 2020
3c16587
convert multiscan to welford
shmsong Dec 16, 2020
e5a42b9
preliminary kernel gen
shmsong Dec 16, 2020
b2c9c45
fix serial welford
shmsong Dec 17, 2020
2b5d469
Merge remote-tracking branch 'origin/20_12_3_devel' into welford_rebase
shmsong Dec 17, 2020
27f7ecc
add initialization
shmsong Dec 17, 2020
1bdc500
format
shmsong Dec 17, 2020
9cbc07b
use independent index lowering
shmsong Dec 17, 2020
4170533
format
shmsong Dec 17, 2020
17da83c
add serial welford test
shmsong Dec 18, 2020
aa997d7
add scheduling primitives
shmsong Jan 4, 2021
2a3cc45
Merge branch '20_12_3_devel' into welford_rebase
shmsong Jan 4, 2021
919931f
fix rfactor indexing
shmsong Jan 5, 2021
c3d3969
remove unwanted changes
shmsong Jan 8, 2021
7fc4926
cleanup && clang-tidy
shmsong Jan 8, 2021
f387381
fix sync_flag allocation
shmsong Jan 8, 2021
65b332e
Merge branch '20_12_3_devel' of https://github.com/csarofeen/pytorch …
shmsong Jan 8, 2021
486bc8f
format
shmsong Jan 8, 2021
5fb554b
refactor allocation
shmsong Jan 8, 2021
fa15a61
refactor alloc
shmsong Jan 8, 2021
0786658
Merge remote-tracking branch 'origin/20_12_3_devel' into welford_rebase2
shmsong Jan 18, 2021
b434d65
change welford API
shmsong Jan 18, 2021
ef2dbfe
revise rfactor interface
shmsong Jan 18, 2021
f843551
revise welford root domain map
shmsong Jan 18, 2021
8126155
add assertions and cleanup conditionals
shmsong Jan 18, 2021
850e768
rename helper function
shmsong Jan 19, 2021
8396373
minor fix
shmsong Jan 19, 2021
06e6bfb
change rfactor interface
shmsong Jan 22, 2021
77094da
add a scheduleReduction Test
shmsong Jan 24, 2021
df05bf9
change schedule
shmsong Jan 24, 2021
3fe883f
minor cleanup
shmsong Feb 11, 2021
4dffcec
Merge remote-tracking branch 'origin/20_12_3_devel' into welford_refa…
shmsong Feb 11, 2021
6ae1010
minor cleanup
shmsong Feb 12, 2021
d79df8c
update kernel summary pass
shmsong Feb 12, 2021
99b648c
fix codegen ; cleanup test
shmsong Feb 12, 2021
ff68685
bug fix
shmsong Feb 15, 2021
c73a9b3
thread_predicate bugfix; cleanup
shmsong Feb 15, 2021
f5e32ae
clang format
shmsong Feb 15, 2021
a049fb3
update comments
shmsong Feb 16, 2021
519fa2f
minor cleanup
shmsong Feb 16, 2021
7872914
Macro Names
shmsong Feb 16, 2021
3d92cf0
minor fix
shmsong Feb 16, 2021
359 changes: 359 additions & 0 deletions test/cpp/jit/test_gpu.cpp
@@ -10852,6 +10852,365 @@ __global__ void kernel1(
TORCH_CHECK(in0.mean(dims).allclose(out_avg, /*rtol*/ 1e-5, /*atol*/ 1e-6));
}

TEST(NVFuserTest, FusionWelfordOp_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);

int M = 64, N = 128;

auto tv0 = makeSymbolicTensor(2);
fusion.addInput(tv0);
auto tv1 = mul(tv0, new Double(1));
auto tvs = Welford(tv1, {1});
auto tv_M2 = tvs.var;
auto tv_avg = tvs.avg;
auto tv_N = tvs.n;
fusion.addOutput(tv_M2);
fusion.addOutput(tv_avg);
fusion.addOutput(tv_N);

tv_avg->split(1, 32);
tv_avg->split(0, 32);
tv_avg->split(0, 4);
tv_avg->reorder({{-1, -3}, {-3, -1}});
tv1->computeAt(tv_avg, -1);

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
auto options_int = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0);
at::manual_seed(0);
at::Tensor t0 = at::randn({M, N}, options);
at::Tensor t_var = at::empty({M}, options);
at::Tensor t_avg = at::empty({M}, options);
at::Tensor t_N = at::empty({M}, options_int);

FusionExecutor fe;
fe.compileFusion(&fusion);
auto outputs = fe.runFusion({t0});

// By default, Welford outputs the sum of squared differences, so divide by
// N to get the variance.
outputs[0] /= N;

testValidate(
&fusion,
outputs,
{t0},
{t0.var({1}, false), t0.mean({1}), at::ones({M}, options_int) * N},
__LINE__,
__FILE__);
}
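
The division by N above reflects how the Welford op reports its result: the M2 output is the running sum of squared differences from the mean, not the variance itself. For reference, here is a minimal standalone sketch of the single-pass Welford recurrence (illustrative only, not part of this PR; the struct and function names are made up):

// Minimal sketch of the single-pass Welford recurrence. The M2 accumulator
// is the running sum of squared differences from the mean, so the variance
// is M2 / n, matching the "outputs[0] /= N" normalization in these tests.
#include <cassert>
#include <cmath>
#include <vector>

struct WelfordResult {
  double avg = 0.0; // running mean
  double M2 = 0.0;  // running sum of squared differences from the mean
  long n = 0;       // running element count
};

WelfordResult welford(const std::vector<double>& xs) {
  WelfordResult w;
  for (double x : xs) {
    w.n += 1;
    double delta = x - w.avg;
    w.avg += delta / w.n;
    w.M2 += delta * (x - w.avg); // note: uses the updated mean
  }
  return w;
}

int main() {
  auto w = welford({1.0, 2.0, 4.0, 7.0});
  double var = w.M2 / w.n; // population variance (unbiased = false)
  assert(std::abs(w.avg - 3.5) < 1e-12);
  assert(std::abs(var - 5.25) < 1e-12);
  return 0;
}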

TEST(NVFuserTest, FusionBlockWelfordOp_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);

int M = 64, N = 128;

auto tv0 = makeSymbolicTensor(2);
fusion.addInput(tv0);
auto tv1 = mul(tv0, new Double(1));
auto tvs = Welford(tv1, {1});
auto tv_M2 = tvs.var;
auto tv_avg = tvs.avg;
auto tv_N = tvs.n;
fusion.addOutput(tv_M2);
fusion.addOutput(tv_avg);
fusion.addOutput(tv_N);

tv_avg->axis(-1)->parallelize(ParallelType::TIDx);

tv1->computeAt(tv_avg, -1);

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
auto options_int = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0);
at::manual_seed(0);
at::Tensor t0 = at::randn({M, N}, options);
at::Tensor t_var = at::empty({M}, options);
at::Tensor t_avg = at::empty({M}, options);
at::Tensor t_N = at::empty({M}, options_int);

FusionExecutor fe;
fe.compileFusion(&fusion);
auto outputs = fe.runFusion({t0});

// By default, Welford outputs the sum of squared differences, so divide by
// N to get the variance.
outputs[0] /= N;

testValidate(
&fusion,
outputs,
{t0},
{t0.var({1}, false), t0.mean({1}), at::ones({M}, options_int) * N},
__LINE__,
__FILE__);
}

TEST(NVFuserTest, FusionGridWelfordOp_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);

int M = 64, N = 128;

auto tv0 = makeSymbolicTensor(2);
fusion.addInput(tv0);
auto tv1 = mul(tv0, new Double(1));
auto tvs = Welford(tv1, {1});
auto tv_M2 = tvs.var;
auto tv_avg = tvs.avg;
auto tv_N = tvs.n;
fusion.addOutput(tv_M2);
fusion.addOutput(tv_avg);
fusion.addOutput(tv_N);

tv_avg->axis(0)->parallelize(ParallelType::TIDx);
tv_avg->axis(-1)->parallelize(ParallelType::BIDx);

tv1->computeAt(tv_avg, -1);

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
auto options_int = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0);
at::manual_seed(0);
at::Tensor t0 = at::randn({M, N}, options);
at::Tensor t_var = at::empty({M}, options);
at::Tensor t_avg = at::empty({M}, options);
at::Tensor t_N = at::empty({M}, options_int);

FusionExecutor fe;
fe.compileFusion(&fusion);
auto outputs = fe.runFusion({t0});

// By default, Welford outputs the sum of squared differences, so divide by
// N to get the variance.
outputs[0] /= N;

testValidate(
&fusion,
outputs,
{t0},
{t0.var({1}, false), t0.mean({1}), at::ones({M}, options_int) * N},
__LINE__,
__FILE__);
}

TEST(NVFuserTest, FusionRfactorWelfordOp_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);

int M = 64, N = 128;

auto tv0 = makeSymbolicTensor(2);
fusion.addInput(tv0);
auto tv1 = mul(tv0, new Double(1));
auto tvs = Welford(tv1, {1});
auto tv_M2 = tvs.var;
auto tv_avg = tvs.avg;
auto tv_N = tvs.n;
fusion.addOutput(tv_M2);
fusion.addOutput(tv_avg);
fusion.addOutput(tv_N);

tv_avg->split(1, 4);
auto rtvs = tvs.rFactor({2});
tv1->computeAt(tv_avg, -1);

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
auto options_int = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0);
at::manual_seed(0);
at::Tensor t0 = at::randn({M, N}, options);
at::Tensor t_var = at::empty({M}, options);
at::Tensor t_avg = at::empty({M}, options);
at::Tensor t_N = at::empty({M}, options_int);

FusionExecutor fe;
fe.compileFusion(&fusion);
auto outputs = fe.runFusion({t0});

// By default, Welford outputs the sum of squared differences, so divide by
// N to get the variance.
outputs[0] /= N;

testValidate(
&fusion,
outputs,
{t0},
{t0.var({1}, false), t0.mean({1}), at::ones({M}, options_int) * N},
__LINE__,
__FILE__);
}

TEST(NVFuserTest, FusionWelfordSchedule_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);

int M = 64, N = 128;

auto tv0 = makeSymbolicTensor(2);
fusion.addInput(tv0);
auto tv1 = mul(tv0, new Double(1));
auto tvs = Welford(tv1, {1});
auto tv_M2 = tvs.var;
auto tv_avg = tvs.avg;
auto tv_N = tvs.n;
fusion.addOutput(tv_M2);
fusion.addOutput(tv_N);
fusion.addOutput(tv_avg);

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
auto options_int = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0);
at::manual_seed(0);
at::Tensor t0 = at::randn({M, N}, options);
auto red_params = getReductionHeuristics(&fusion, {t0}, tv_avg);

tv_avg->split(1, 4);
tv_avg->split(1, NamedScalar::getParallelDim(ParallelType::TIDx));
tv_avg->split(0, NamedScalar::getParallelDim(ParallelType::TIDy));

auto rtvs = tvs.rFactor({-3, -1});

rtvs.avg->computeAt(tv_avg, -1);

rtvs.avg->axis(-1)->parallelize(ParallelType::Unroll);

tv_avg->axis(0)->parallelize(ParallelType::BIDx);
tv_avg->axis(1)->parallelize(ParallelType::TIDy);
tv_avg->axis(-1)->parallelize(ParallelType::TIDx);

tv1->computeAt(rtvs.avg, -1);

FusionExecutor fe;
fe.compileFusion(&fusion);
auto outputs = fe.runFusion({t0}, red_params.value().lparams);

// By default, Welford outputs the sum of squared differences, so divide by
// N to get the variance.
outputs[0] /= N;

auto at_var = t0.var({1}, false);
auto at_avg = t0.mean({1});
auto at_n = at::ones({M}, options_int) * N;

testValidate(
&fusion,
outputs,
{t0},
{at_var, at_n, at_avg},
__LINE__,
__FILE__,
"validate welford",
red_params.value().lparams);
}

namespace {
void testWelford(DataType dtype, int red_axis, int odim, int rdim) {
const int axis = red_axis;
at::ScalarType aten_dtype = data_type_to_aten(dtype);

Fusion fusion;
FusionGuard fg(&fusion);
TensorView* tv0 = makeSymbolicTensor(2, dtype);
bool is_fp16 = dtype == DataType::Half;
TensorView* tv0_cast = tv0;
if (is_fp16) {
tv0_cast = castOp(DataType::Float, tv0);
}
fusion.addInput(tv0);
auto tv1 = mul(tv0_cast, new Double(1));
auto tvs = Welford(tv1, {axis});
auto tv_M2 = tvs.var;
auto tv_avg = tvs.avg;
auto tv_N = tvs.n;

TensorView* avg_cast = tv_avg;
TensorView* M2_cast = tv_M2;

if (is_fp16) {
avg_cast = castOp(DataType::Half, tv_avg);
M2_cast = castOp(DataType::Half, tv_M2);
}

fusion.addOutput(M2_cast);
fusion.addOutput(tv_N);
fusion.addOutput(avg_cast);

auto options = at::TensorOptions().dtype(aten_dtype).device(at::kCUDA, 0);
auto options_int = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0);
at::manual_seed(0);
std::vector<TensorView*> outputs_of_red;
at::Tensor aten_input =
(axis ? at::randn({odim, rdim}, options)
: at::randn({rdim, odim}, options));

if (is_fp16) {
outputs_of_red.push_back(avg_cast);
outputs_of_red.push_back(M2_cast);
}

auto reduction_params = getReductionHeuristics(&fusion, {aten_input}, tv_avg);
scheduleReduction(&fusion, reduction_params.value(), tv_avg, outputs_of_red);

auto lparams = reduction_params.value().lparams;

FusionExecutor fe;
fe.compileFusion(&fusion);
auto outputs = fe.runFusion({aten_input}, lparams);

// By default, Welford outputs the sum of squared differences, so divide
// by the reduction size to get the variance.

outputs[0] /= rdim;

auto at_var = aten_input.var({axis}, false);
auto at_avg = aten_input.mean({axis});
auto at_n =
(axis ? at::ones({odim, rdim}, options)
: at::ones({rdim, odim}, options));
at_n = at_n.sum({axis});

testValidate(
&fusion,
outputs,
{aten_input},
{at_var, at_n, at_avg},
__LINE__,
__FILE__,
"validate welford",
reduction_params.value().lparams);
}
} // namespace

TEST(NVFuserTest, FusionWelfordShmoo_CUDA) {
std::vector<DataType> dtypes = {
DataType::Double, DataType::Float, DataType::Half};
std::vector<int> red_axis = {1, 0};
std::vector<int> output_dims = {160, 320};
std::vector<int> red_dims;

// Cut down the number of iterations by only testing every other
// power of 2.
for (int i = 1; i <= 1024 * 1024; i <<= 2) {
red_dims.push_back(i);
}

for (auto dtype : dtypes) {
for (auto& axis : red_axis) {
for (auto& odim : output_dims) {
for (auto& rdim : red_dims) {
// TODO: the original Welford algorithm keeps a running sum of squared
// differences, i.e. M_{2,n} in the notation of
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance,
// which can reach inf for large reduction sizes with half precision.
// Skipping overly large volumes for half for now; may need further
// numerical experiments to re-design this.
if (rdim > 32768 && dtype == DataType::Half) {
continue;
}

testWelford(dtype, axis, odim, rdim);
}
}
}
}
}
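
The TODO above skips large reduction sizes for half precision because the M2 accumulator grows with the reduction size. A rough back-of-envelope check (illustrative only, not part of this PR) of why rdim > 32768 is excluded for Half:

// Rough sanity check: for roughly unit-variance inputs, the running M2 is
// on the order of the reduction size, so it overflows the half-precision
// range (largest finite value ~65504) well before rdim reaches 1M.
#include <cstdio>

int main() {
  const double fp16_max = 65504.0; // largest finite IEEE half value
  for (long rdim : {1024L, 32768L, 65536L, 1048576L}) {
    // For standard-normal data, E[M2] = (rdim - 1) * variance ~= rdim.
    double expected_m2 = static_cast<double>(rdim);
    std::printf(
        "rdim=%8ld  expected M2 ~ %10.0f  (%s fp16 range)\n",
        rdim,
        expected_m2,
        expected_m2 > fp16_max ? "exceeds" : "within");
  }
  return 0;
}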

TEST(NVFuserTest, FusionTranspose1_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);