Skip to content

Commit

Permalink
Merge all dims in pointwise scheduler (#1872)
Browse files Browse the repository at this point in the history
  • Loading branch information
zasdfgbnm authored Jul 29, 2022
1 parent 172fb36 commit a48270a
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 8 deletions.
8 changes: 0 additions & 8 deletions torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -528,10 +528,6 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
int rhs_i = -1;
for (int i = (int)reference_tv->nDims(); i > (int)params.break_point; i--) {
auto axis_i = i - 1;
if (reference_tv->axis(axis_i)->isBroadcast() ||
reference_tv->axis(axis_i)->isReduction()) {
continue;
}
if (rhs_i == -1) {
rhs_i = axis_i;
} else {
Expand All @@ -548,10 +544,6 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
int lhs_i = -1;
for (int i = (int)params.break_point; i > 0; i--) {
auto axis_i = i - 1;
if (reference_tv->axis(axis_i)->isBroadcast() ||
reference_tv->axis(axis_i)->isReduction()) {
continue;
}
if (lhs_i == -1) {
lhs_i = axis_i;
} else {
Expand Down
34 changes: 34 additions & 0 deletions torch/csrc/jit/codegen/cuda/test/test_gpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24847,6 +24847,40 @@ TEST_F(NVFuserTest, FusionInsertMagicZero1_CUDA) {
tv2->toString());
}

TEST_F(
NVFuserTest,
FusionPointwiseScheduleWithBroadcastAndTrivialReduction_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);

auto tv0 = makeContigTensor(3);
auto tv1 = makeContigTensor(2);
fusion.addInput(tv0);
fusion.addInput(tv1);
auto tv2 = broadcast(tv0, {false, true, false, true, false, true});
auto tv3 = sin(tv2);
auto tv4 = add(tv3, tv1);
auto tv5 = sum(tv4, {1});
fusion.addOutput(tv5);

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
at::Tensor t0 = at::randn({100, 100, 10}, options);
at::Tensor t1 = at::randn({10, 20}, options);

auto aten_output = (t0.view({100, 1, 100, 1, 10, 1}).sin() + t1).squeeze(1);

std::vector<IValue> aten_inputs = {t0, t1};

auto lparams = schedulePointwise(&fusion, aten_inputs);

FusionExecutor fe;
fe.compileFusion(&fusion, aten_inputs, lparams);
auto cg_outputs = fe.runFusion(aten_inputs, lparams);

testValidate(
&fusion, cg_outputs, aten_inputs, {aten_output}, __LINE__, __FILE__);
}

TEST_F(NVFuserTest, FusionInlinePropagatorMismatchedDims1_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
Expand Down

0 comments on commit a48270a

Please sign in to comment.