Skip to content

Commit

Permalink
Proper propagation of IterType (csarofeen#1762)
Browse files Browse the repository at this point in the history
  • Loading branch information
naoyam authored Jun 15, 2022
1 parent b263562 commit ec7fa41
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 6 deletions.
51 changes: 45 additions & 6 deletions torch/csrc/jit/codegen/cuda/arith.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,49 @@ Val* newScalar(ValType vtype, DataType dtype) {
" in newScalar.");
}

// Promotes the IterTypes of two iteration domains that are merged by a
// broadcast-compatible binary expression.
//
// Rules:
// - Iteration: Default
// - Reduction: Should not appear here
// - Broadcast: Propagated only if type1 and type2 are Broadcast
// - Gather: Converted to Iteration
// - Stride: Should not appear here
// - VectorComponent: Converted to Iteration
IterType promoteIterType(IterType type1, IterType type2) {
  // Reduction and Stride domains must never reach output-type promotion.
  TORCH_INTERNAL_ASSERT(
      type1 != IterType::Reduction && type1 != IterType::Stride,
      "Invalid IterType: ",
      type1);
  TORCH_INTERNAL_ASSERT(
      type2 != IterType::Reduction && type2 != IterType::Stride,
      "Invalid IterType: ",
      type2);

  // Gather and VectorComponent do not propagate to the output; demote
  // them to plain Iteration before comparing.
  if (type1 == IterType::Gather || type1 == IterType::VectorComponent) {
    type1 = IterType::Iteration;
  }
  if (type2 == IterType::Gather || type2 == IterType::VectorComponent) {
    type2 = IterType::Iteration;
  }

  // At this point, type1 and type2 must be either Iteration or
  // Broadcast
  TORCH_INTERNAL_ASSERT(
      type1 == IterType::Iteration || type1 == IterType::Broadcast,
      "Unexpected IterType: ",
      type1);
  TORCH_INTERNAL_ASSERT(
      type2 == IterType::Iteration || type2 == IterType::Broadcast,
      "Unexpected IterType: ",
      type2);

  // Broadcast yields to the other type; only Broadcast+Broadcast
  // stays Broadcast.
  if (type1 == IterType::Broadcast) {
    return type2;
  } else {
    return type1;
  }
}

TensorView* newOutputTV(const std::vector<Val*>& vals, DataType dtype) {
std::vector<TensorView*> tvs;
for (auto val : vals) {
Expand Down Expand Up @@ -155,12 +198,8 @@ TensorView* newOutputTV(const std::vector<Val*>& vals, DataType dtype) {
}
extent_vals[i] = promoteSize(extent_vals[i], dom[i]->extent());
if (iter_types[i].has_value()) {
// TODO: Enable, see conv tests and gather promotion/gather broadcast
// behavior.
//
// TORCH_INTERNAL_ASSERT(
// iter_types[i].value() == dom[i]->getIterType(),
// "Invalid iter type promotion in newOutputTv for expression.");
iter_types[i] =
promoteIterType(iter_types[i].value(), dom[i]->getIterType());
} else {
iter_types[i] = dom[i]->getIterType();
}
Expand Down
40 changes: 40 additions & 0 deletions torch/csrc/jit/codegen/cuda/test/test_gpu_shift.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5384,6 +5384,46 @@ TEST_F(NVFuserTest, FusionValidateParallelizeShift_CUDA) {
testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
}

// Test IterType promotion with gather
TEST_F(NVFuserTest, FusionGatherIterTypePromotion_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  const int dim0 = 11;
  // Gather window extent; must also match the second axis of the
  // tensor added to the gather output.
  const int window = 3;

  // 1-D tensor to gather from.
  auto inp = makeConcreteTensor({dim0});
  fusion.addInput(inp);
  // 2-D tensor added to the gather output.
  auto bias = makeConcreteTensor({dim0, window});
  fusion.addInput(bias);

  const std::vector<int> window_shape = {window};
  const std::vector<std::vector<int>> padding_width = {{1, 1}};

  auto gathered = gather(inp, window_shape, padding_width);
  auto out = add(gathered, bias);

  fusion.addOutput(out);

  // The window axis of the gather output must be promoted to
  // Iteration when combined with a concrete (non-broadcast) axis.
  TORCH_CHECK(
      out->axis(1)->getIterType() == IterType::Iteration,
      "Invalid IterType promotion: ",
      out->axis(1)->toString());

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({dim0}, options);
  at::Tensor t1 = at::randn({dim0, window}, options);
  std::vector<IValue> inputs = {t0, t1};

  auto ref = gather(t0, window_shape, padding_width) + t1;

  FusionExecutor fe;
  fe.compileFusion(&fusion, inputs);
  auto outputs = fe.runFusion(inputs);

  testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
}

} // namespace jit
} // namespace torch
#endif // #if defined(USE_CUDA)

0 comments on commit ec7fa41

Please sign in to comment.