Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Proper propagation of IterType #1762

Merged
merged 1 commit into from
Jun 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 45 additions & 6 deletions torch/csrc/jit/codegen/cuda/arith.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,49 @@ Val* newScalar(ValType vtype, DataType dtype) {
" in newScalar.");
}

// Computes the IterType of an output axis produced by combining two
// input axes (e.g. in a broadcasted binary op). Promotion rules:
//
// Iteration: Default
// Reduction: Should not appear here
// Broadcast: Propagated only if type1 and type2 are Broadcast
// Gather: Converted to Iteration
// Stride: Should not appear here
// VectorComponent: Converted to Iteration
IterType promoteIterType(IterType type1, IterType type2) {
  // Reduction and Stride axes must never reach output-type promotion.
  TORCH_INTERNAL_ASSERT(
      type1 != IterType::Reduction && type1 != IterType::Stride,
      "Invalid IterType: ",
      type1);
  TORCH_INTERNAL_ASSERT(
      type2 != IterType::Reduction && type2 != IterType::Stride,
      "Invalid IterType: ",
      type2);

  // Do not propagate Gather and VectorComponent
  if (type1 == IterType::Gather || type1 == IterType::VectorComponent) {
    type1 = IterType::Iteration;
  }
  if (type2 == IterType::Gather || type2 == IterType::VectorComponent) {
    type2 = IterType::Iteration;
  }

  // At this point, type1 and type2 must be either Iteration or
  // Broadcast
  TORCH_INTERNAL_ASSERT(
      type1 == IterType::Iteration || type1 == IterType::Broadcast,
      "Unexpected IterType: ",
      type1);
  TORCH_INTERNAL_ASSERT(
      type2 == IterType::Iteration || type2 == IterType::Broadcast,
      "Unexpected IterType: ",
      type2);

  // Broadcast yields to the other operand's type; Iteration wins.
  if (type1 == IterType::Broadcast) {
    return type2;
  } else {
    return type1;
  }
}

TensorView* newOutputTV(const std::vector<Val*>& vals, DataType dtype) {
std::vector<TensorView*> tvs;
for (auto val : vals) {
Expand Down Expand Up @@ -155,12 +198,8 @@ TensorView* newOutputTV(const std::vector<Val*>& vals, DataType dtype) {
}
extent_vals[i] = promoteSize(extent_vals[i], dom[i]->extent());
if (iter_types[i].has_value()) {
// TODO: Enable, see conv tests and gather promotion/gather broadcast
// behavior.
//
// TORCH_INTERNAL_ASSERT(
// iter_types[i].value() == dom[i]->getIterType(),
// "Invalid iter type promotion in newOutputTv for expression.");
iter_types[i] =
promoteIterType(iter_types[i].value(), dom[i]->getIterType());
} else {
iter_types[i] = dom[i]->getIterType();
}
Expand Down
40 changes: 40 additions & 0 deletions torch/csrc/jit/codegen/cuda/test/test_gpu_shift.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5384,6 +5384,46 @@ TEST_F(NVFuserTest, FusionValidateParallelizeShift_CUDA) {
testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
}

// Test IterType promotion with gather: adding a gathered tensor to a
// regular (non-broadcast) tensor should yield an Iteration axis, not a
// Gather axis, on the window dimension of the output.
TEST_F(NVFuserTest, FusionGatherIterTypePromotion_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Concrete extents: gathered input length and window width.
  constexpr int s1 = 11;
  constexpr int s2 = 3;

  auto in0 = makeConcreteTensor({s1});
  fusion.addInput(in0);
  auto in1 = makeConcreteTensor({s1, s2});
  fusion.addInput(in1);

  const std::vector<int> win_shape = {3};
  const std::vector<std::vector<int>> pad_width = {{1, 1}};

  auto gathered = gather(in0, win_shape, pad_width);
  auto out = add(gathered, in1);

  fusion.addOutput(out);

  // The window axis must have been promoted to Iteration by the add.
  TORCH_CHECK(
      out->axis(1)->getIterType() == IterType::Iteration,
      "Invalid IterType promotion: ",
      out->axis(1)->toString());

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({s1}, options);
  at::Tensor t1 = at::randn({s1, s2}, options);
  std::vector<IValue> inputs = {t0, t1};

  // Reference computed eagerly with ATen.
  auto ref = gather(t0, win_shape, pad_width) + t1;

  FusionExecutor fe;
  fe.compileFusion(&fusion, inputs);
  auto outputs = fe.runFusion(inputs);

  testValidate(&fusion, outputs, inputs, {ref}, __LINE__, __FILE__);
}

} // namespace jit
} // namespace torch
#endif // #if defined(USE_CUDA)