Skip to content

Commit ca1387a

Browse files
authored
Add support for select op (#2179)
1 parent c9f8c1d commit ca1387a

33 files changed

+546
-57
lines changed

test/cpp/jit/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ if(USE_CUDA)
111111
list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_transpose.cpp)
112112
list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_rng.cu)
113113
list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_utils.cpp)
114+
list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_indexing_ops.cpp)
114115
endif()
115116

116117
add_executable(test_jit

torch/csrc/jit/codegen/cuda/arith.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,38 @@ TensorView* unaryOp(
442442
return unaryOp(type, cast_v1)->as<TensorView>();
443443
}
444444

445+
// Builds the output TensorView of a select operation on `tv`.
//
// The output's root domain is the non-reduction (maybe-rfactor) domain of
// `tv` with the axis `dim` left out; a SelectOp is created that ties the
// output to `tv`, the dropped IterDomain and `index`.
//
// tv:    input tensor view; must have at least one non-reduction dim.
// dim:   axis to drop; a negative value is offset by the number of
//        non-reduction dims before it is validated.
// index: forwarded unchanged to the SelectOp.
// Returns the newly created output TensorView.
// Fails a TORCH_CHECK if `tv` has no non-reduction dims or `dim` is out of
// range.
TensorView* select(TensorView* tv, int dim, Int* index) {
446+
// Only non-reduction domains are considered when counting/selecting axes.
auto dom = TensorDomain::noReductions(tv->getMaybeRFactorDomain());
447+
TORCH_CHECK(dom.size() > 0, "select can not be applied to 0d tensor.");
448+
449+
std::vector<IterDomain*> new_root;
450+
// Safe from underflow: dom.size() > 0 was checked above.
new_root.reserve(dom.size() - 1);
451+
452+
// Negative axes are shifted by the number of non-reduction dims.
if (dim < 0) {
453+
dim += dom.size();
454+
}
455+
456+
// NOTE(review): `dim` has already been shifted here, so for a negative
// input the message reports the shifted value, not what the caller passed.
// NOTE(review): `dim < dom.size()` mixes int with what is presumably an
// unsigned size; the preceding `dim >= 0` keeps the comparison meaningful.
TORCH_CHECK(
457+
dim >= 0 && dim < dom.size(),
458+
"Select on invalid axis, received: ",
459+
dim,
460+
" however tensor view only has ",
461+
dom.size(),
462+
" non-reduction dims.");
463+
464+
// Every axis except the selected one is cloned into the output root domain.
for (auto i : c10::irange(dom.size())) {
465+
if (i != dim) {
466+
new_root.emplace_back(dom[i]->cloneWithoutRFactor());
467+
}
468+
}
469+
470+
// Contiguity of the output comes from getContiguousContiguity(new_root).
auto td = IrBuilder::create<TensorDomain>(
471+
new_root, TensorDomain::getContiguousContiguity(new_root));
472+
// The output uses the same data type as the input.
auto out = IrBuilder::create<TensorView>(td, *tv->getDataType());
473+
// Created for its side effect; the returned expression is not kept here.
IrBuilder::create<SelectOp>(out, tv, dom[dim], index);
474+
return out;
475+
}
476+
445477
// TENSOR FACTORIES
446478
TensorView* rand(const std::vector<Val*>& shape, DataType dtype) {
447479
auto n = shape.size();

torch/csrc/jit/codegen/cuda/arith.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,8 @@ TORCH_CUDA_CU_API WelfordResult WelfordRaw(
143143
// import IrBuilder just for this one interface.
144144
Int* init_N = nullptr);
145145

146+
TORCH_CUDA_CU_API TensorView* select(TensorView* tv, int dim, Int* index);
147+
146148
// RNG OPERATIONS
147149
TORCH_CUDA_CU_API TensorView* rand(
148150
const std::vector<Val*>& shape,
@@ -375,12 +377,14 @@ TORCH_CUDA_CU_API Val* atan2(Val* v1, Val* v2);
375377
TORCH_CUDA_CU_API TensorView* atan2(TensorView* v1, Val* v2);
376378
TORCH_CUDA_CU_API TensorView* atan2(Val* v1, TensorView* v2);
377379
TORCH_CUDA_CU_API TensorView* atan2(TensorView* v1, TensorView* v2);
378-
// div
380+
// div: promotes to float for integer division; has the same semantics as
381+
// Python's operator /
379382
TORCH_CUDA_CU_API Val* div(Val* v1, Val* v2);
380383
TORCH_CUDA_CU_API TensorView* div(TensorView* v1, Val* v2);
381384
TORCH_CUDA_CU_API TensorView* div(Val* v1, TensorView* v2);
382385
TORCH_CUDA_CU_API TensorView* div(TensorView* v1, TensorView* v2);
383-
// cpp_div: similar to div, but don't promote to float
386+
// cpp_div: similar to div, but doesn't promote to float; this has the same
387+
// semantics as C++'s operator /
384388
TORCH_CUDA_CU_API Val* cpp_div(Val* v1, Val* v2);
385389
TORCH_CUDA_CU_API TensorView* cpp_div(TensorView* v1, Val* v2);
386390
TORCH_CUDA_CU_API TensorView* cpp_div(Val* v1, TensorView* v2);

torch/csrc/jit/codegen/cuda/dispatch.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,9 @@ void Expr::dispatch(T handler, Expr* expr) {
110110
case ExprType::TernaryOp:
111111
ptr(handler)->handle(expr->as<TernaryOp>());
112112
return;
113+
case ExprType::SelectOp:
114+
ptr(handler)->handle(expr->as<SelectOp>());
115+
return;
113116
case ExprType::RNGOp:
114117
ptr(handler)->handle(expr->as<RNGOp>());
115118
return;
@@ -296,6 +299,9 @@ void Expr::constDispatch(T handler, const Expr* expr) {
296299
case ExprType::TernaryOp:
297300
ptr(handler)->handle(expr->as<TernaryOp>());
298301
return;
302+
case ExprType::SelectOp:
303+
ptr(handler)->handle(expr->as<SelectOp>());
304+
return;
299305
case ExprType::RNGOp:
300306
ptr(handler)->handle(expr->as<RNGOp>());
301307
return;
@@ -490,6 +496,9 @@ void Expr::mutatorDispatch(T mutator, Expr* expr) {
490496
case ExprType::TernaryOp:
491497
ptr(mutator)->mutate(expr->as<TernaryOp>());
492498
return;
499+
case ExprType::SelectOp:
500+
ptr(mutator)->mutate(expr->as<SelectOp>());
501+
return;
493502
case ExprType::RNGOp:
494503
ptr(mutator)->mutate(expr->as<RNGOp>());
495504
return;
@@ -749,6 +758,9 @@ void OptOutConstDispatch::handle(const BinaryOp* stmt) {
749758
void OptOutConstDispatch::handle(const TernaryOp* stmt) {
750759
unhandled(stmt);
751760
}
761+
// Handler for const SelectOp nodes in OptOutConstDispatch: passes the
// statement to unhandled().
void OptOutConstDispatch::handle(const SelectOp* stmt) {
762+
unhandled(stmt);
763+
}
752764
void OptOutConstDispatch::handle(const RNGOp* stmt) {
753765
unhandled(stmt);
754766
}
@@ -905,6 +917,9 @@ void OptOutDispatch::handle(BinaryOp* stmt) {
905917
void OptOutDispatch::handle(TernaryOp* stmt) {
906918
unhandled(stmt);
907919
}
920+
// Handler for SelectOp nodes in OptOutDispatch: passes the statement to
// unhandled().
void OptOutDispatch::handle(SelectOp* stmt) {
921+
unhandled(stmt);
922+
}
908923
void OptOutDispatch::handle(RNGOp* stmt) {
909924
unhandled(stmt);
910925
}

torch/csrc/jit/codegen/cuda/dispatch.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ class EyeOp;
7474
class UnaryOp;
7575
class BinaryOp;
7676
class TernaryOp;
77+
class SelectOp;
7778
class RNGOp;
7879
class ReductionOp;
7980
class GroupedReductionOp;
@@ -149,6 +150,7 @@ class TORCH_CUDA_CU_API OptOutConstDispatch : public PolymorphicBase {
149150
virtual void handle(const UnaryOp* stmt);
150151
virtual void handle(const BinaryOp* stmt);
151152
virtual void handle(const TernaryOp* stmt);
153+
virtual void handle(const SelectOp* stmt);
152154
virtual void handle(const RNGOp* stmt);
153155
virtual void handle(const ReductionOp* stmt);
154156
virtual void handle(const GroupedReductionOp* stmt);
@@ -216,6 +218,7 @@ class TORCH_CUDA_CU_API OptOutDispatch : public PolymorphicBase {
216218
virtual void handle(UnaryOp* stmt);
217219
virtual void handle(BinaryOp* stmt);
218220
virtual void handle(TernaryOp* stmt);
221+
virtual void handle(SelectOp* stmt);
219222
virtual void handle(RNGOp* stmt);
220223
virtual void handle(ReductionOp* stmt);
221224
virtual void handle(GroupedReductionOp* stmt);
@@ -324,6 +327,7 @@ class TORCH_CUDA_CU_API OptOutMutator : public PolymorphicBase {
324327
virtual void mutate(UnaryOp*);
325328
virtual void mutate(BinaryOp*);
326329
virtual void mutate(TernaryOp*);
330+
virtual void mutate(SelectOp*);
327331
virtual void mutate(RNGOp*);
328332
virtual void mutate(ReductionOp*);
329333
virtual void mutate(GroupedReductionOp*);

torch/csrc/jit/codegen/cuda/fusion_segmenter.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2633,7 +2633,8 @@ ScheduleHeuristic SegmentCandidateFinder::deriveHeuristic(
26332633
SegmentedGroup* group) {
26342634
Fusion* fusion = segmented_fusion_->completeFusion();
26352635
auto h = tryMerge(fusion, runtime_info_, group);
2636-
TORCH_INTERNAL_ASSERT(h.has_value());
2636+
TORCH_INTERNAL_ASSERT(
2637+
h.has_value(), "Can not find a scheduler to schedule fusion segment");
26372638
return h.value();
26382639
}
26392640

torch/csrc/jit/codegen/cuda/index_compute.cpp

Lines changed: 50 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,11 @@ Val* getProducerIndexWithHalo(
6969
const TensorView* producer_tv,
7070
size_t producer_axis,
7171
Val* producer_index,
72-
const TensorView* consumer_tv) {
73-
const auto offset =
74-
getProducerHaloOffset(producer_tv, producer_axis, consumer_tv);
72+
const TensorView* consumer_tv,
73+
bool is_overriden_index) {
74+
const auto offset = is_overriden_index
75+
? 0
76+
: getProducerHaloOffset(producer_tv, producer_axis, consumer_tv);
7577

7678
if (offset == 0) {
7779
return producer_index;
@@ -1460,7 +1462,8 @@ Val* hoistProducerIndex(
14601462
std::vector<Val*> Index::getGlobalProducerStridedIndices(
14611463
TensorView* producer_tv,
14621464
const TensorView* consumer_tv,
1463-
const std::vector<kir::ForLoop*>& loops) {
1465+
const std::vector<kir::ForLoop*>& loops,
1466+
const std::unordered_map<IterDomain*, Val*>& override_index) {
14641467
FUSER_PERF_SCOPE("GpuLower::Lower::getGlobalProducerIndex");
14651468

14661469
// Replay producer to look like consumer so we can index on producer since
@@ -1545,23 +1548,6 @@ std::vector<Val*> Index::getGlobalProducerStridedIndices(
15451548
continue;
15461549
}
15471550

1548-
Val* root_ind = nullptr;
1549-
if (producer_indexing.indexMap().find(root_dom[dim]) !=
1550-
producer_indexing.indexMap().end()) {
1551-
root_ind = producer_indexing.indexMap().at(root_dom[dim]);
1552-
} else if (root_dom[dim]->isBroadcast()) {
1553-
root_ind = GpuLower::current()->kernel()->zeroVal();
1554-
}
1555-
1556-
TORCH_INTERNAL_ASSERT(
1557-
root_ind != nullptr,
1558-
"Couldn't find root mapping for ",
1559-
producer_tv->toString(),
1560-
" dim: ",
1561-
dim,
1562-
" id: ",
1563-
root_dom[dim]->toString());
1564-
15651551
if (producer_tv->domain()->contiguity()[dim]) {
15661552
// If contig, used the stored stride which may be the previous
15671553
// dimensions stride * previous dimensions size
@@ -1591,18 +1577,27 @@ std::vector<Val*> Index::getGlobalProducerStridedIndices(
15911577
continue;
15921578
}
15931579

1594-
TORCH_INTERNAL_ASSERT(
1580+
Val* root_ind = nullptr;
1581+
auto override_it = override_index.find(root_dom[i]);
1582+
if (override_it != override_index.end()) {
1583+
root_ind = override_it->second;
1584+
} else if (
15951585
producer_indexing.indexMap().find(root_dom[i]) !=
1596-
producer_indexing.indexMap().end(),
1597-
"Couldn't find root mapping for TV",
1598-
producer_tv->name(),
1586+
producer_indexing.indexMap().end()) {
1587+
root_ind = producer_indexing.indexMap().at(root_dom[i]);
1588+
} else if (root_dom[i]->isBroadcast()) {
1589+
root_ind = GpuLower::current()->kernel()->zeroVal();
1590+
}
1591+
1592+
TORCH_INTERNAL_ASSERT(
1593+
root_ind != nullptr,
1594+
"Couldn't find root mapping for ",
1595+
producer_tv->toString(),
15991596
" dim: ",
16001597
i,
16011598
" id: ",
16021599
root_dom[i]->toString());
16031600

1604-
auto root_ind = producer_indexing.indexMap().at(root_dom[i]);
1605-
16061601
// index hoist must be done before the adjustments for halo
16071602
root_ind = hoistProducerIndex(
16081603
root_dom[i],
@@ -1615,7 +1610,12 @@ std::vector<Val*> Index::getGlobalProducerStridedIndices(
16151610
loops,
16161611
root_ind);
16171612

1618-
root_ind = getProducerIndexWithHalo(producer_tv, i, root_ind, consumer_tv);
1613+
root_ind = getProducerIndexWithHalo(
1614+
producer_tv,
1615+
i,
1616+
root_ind,
1617+
consumer_tv,
1618+
override_index.count(root_dom[i]));
16191619

16201620
root_ind = getProducerIndexWithGather(
16211621
root_ind,
@@ -1686,7 +1686,8 @@ std::unordered_map<IterDomain*, IterDomain*> mapAllProducerDomainsToConsumer(
16861686
std::vector<Val*> Index::getNonGlobalProducerStridedIndices(
16871687
TensorView* producer_tv,
16881688
const TensorView* consumer_tv,
1689-
const std::vector<kir::ForLoop*>& loops) {
1689+
const std::vector<kir::ForLoop*>& loops,
1690+
const std::unordered_map<IterDomain*, Val*>& override_index) {
16901691
const auto gpu_lower = GpuLower::current();
16911692

16921693
// Replay producer to look like consumer so we can index on producer since our
@@ -1827,7 +1828,10 @@ std::vector<Val*> Index::getNonGlobalProducerStridedIndices(
18271828
" id: ",
18281829
root_dom[i]->toString());
18291830

1830-
auto root_ind_i = index_map.at(root_dom[i]);
1831+
auto override_it = override_index.find(root_dom[i]);
1832+
auto root_ind_i =
1833+
(override_it != override_index.end() ? override_it->second
1834+
: index_map.at(root_dom[i]));
18311835

18321836
// index hoist must be done before the adjustments for halo
18331837
root_ind_i = hoistProducerIndex(
@@ -1841,8 +1845,12 @@ std::vector<Val*> Index::getNonGlobalProducerStridedIndices(
18411845
loops,
18421846
root_ind_i);
18431847

1844-
root_ind_i =
1845-
getProducerIndexWithHalo(producer_tv, i, root_ind_i, consumer_tv);
1848+
root_ind_i = getProducerIndexWithHalo(
1849+
producer_tv,
1850+
i,
1851+
root_ind_i,
1852+
consumer_tv,
1853+
override_index.count(root_dom[i]));
18461854

18471855
root_ind_i = getProducerIndexWithGather(
18481856
root_ind_i,
@@ -2226,7 +2234,8 @@ std::vector<Val*> Index::getNonGlobalConsumerStridedIndices(
22262234
std::vector<Val*> Index::getProducerStridedIndices(
22272235
TensorView* producer,
22282236
const TensorView* consumer,
2229-
const std::vector<kir::ForLoop*>& loops) {
2237+
const std::vector<kir::ForLoop*>& loops,
2238+
const std::unordered_map<IterDomain*, Val*>& override_index) {
22302239
FUSER_PERF_SCOPE("GpuLower::Lower::Index::getProducerStridedIndices");
22312240
if (producer->domain()->noReductions().size() == 0) {
22322241
return std::vector<Val*>(
@@ -2236,11 +2245,11 @@ std::vector<Val*> Index::getProducerStridedIndices(
22362245

22372246
std::vector<Val*> strided_indices;
22382247
if (producer->getMemoryType() == MemoryType::Global) {
2239-
strided_indices =
2240-
getGlobalProducerStridedIndices(producer, consumer, loops);
2248+
strided_indices = getGlobalProducerStridedIndices(
2249+
producer, consumer, loops, override_index);
22412250
} else {
2242-
strided_indices =
2243-
getNonGlobalProducerStridedIndices(producer, consumer, loops);
2251+
strided_indices = getNonGlobalProducerStridedIndices(
2252+
producer, consumer, loops, override_index);
22442253
}
22452254

22462255
TORCH_INTERNAL_ASSERT(
@@ -2256,8 +2265,10 @@ std::vector<Val*> Index::getProducerStridedIndices(
22562265
kir::TensorIndex* Index::getProducerIndex(
22572266
TensorView* producer,
22582267
const TensorView* consumer,
2259-
const std::vector<kir::ForLoop*>& loops) {
2260-
auto strided_indices = getProducerStridedIndices(producer, consumer, loops);
2268+
const std::vector<kir::ForLoop*>& loops,
2269+
const std::unordered_map<IterDomain*, Val*>& override_index) {
2270+
auto strided_indices =
2271+
getProducerStridedIndices(producer, consumer, loops, override_index);
22612272
return SimplifyingIrBuilder::create<kir::TensorIndex>(
22622273
producer, strided_indices);
22632274
}

torch/csrc/jit/codegen/cuda/index_compute.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,8 @@ class Index {
309309
static std::vector<Val*> getNonGlobalProducerStridedIndices(
310310
TensorView* producer,
311311
const TensorView* consumer,
312-
const std::vector<kir::ForLoop*>& loops);
312+
const std::vector<kir::ForLoop*>& loops,
313+
const std::unordered_map<IterDomain*, Val*>& override_index = {});
313314

314315
// Consumer indexing if it's in shared or local memory
315316
static std::vector<Val*> getNonGlobalConsumerStridedIndices(
@@ -320,7 +321,8 @@ class Index {
320321
static std::vector<Val*> getGlobalProducerStridedIndices(
321322
TensorView* producer,
322323
const TensorView* consumer,
323-
const std::vector<kir::ForLoop*>& loops);
324+
const std::vector<kir::ForLoop*>& loops,
325+
const std::unordered_map<IterDomain*, Val*>& override_index = {});
324326

325327
// Consumer indexing if it's in global memory
326328
static std::vector<Val*> getGlobalConsumerStridedIndices(
@@ -344,7 +346,8 @@ class Index {
344346
static kir::TensorIndex* getProducerIndex(
345347
TensorView* producer,
346348
const TensorView* consumer,
347-
const std::vector<kir::ForLoop*>& loops);
349+
const std::vector<kir::ForLoop*>& loops,
350+
const std::unordered_map<IterDomain*, Val*>& override_index = {});
348351

349352
// Consumer index dispatch
350353
static kir::TensorIndex* getConsumerIndex(
@@ -358,7 +361,8 @@ class Index {
358361
static std::vector<Val*> getProducerStridedIndices(
359362
TensorView* producer,
360363
const TensorView* consumer,
361-
const std::vector<kir::ForLoop*>& loops);
364+
const std::vector<kir::ForLoop*>& loops,
365+
const std::unordered_map<IterDomain*, Val*>& override_index = {});
362366

363367
//! Returns a vector of strided indices mapped onto the (rfactor)
364368
//! root domain of a consumer tensor. The size of the returned

torch/csrc/jit/codegen/cuda/ir_builder.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ IR_BUILDER_INSTANTIATE(EyeOp)
6666
IR_BUILDER_INSTANTIATE(UnaryOp)
6767
IR_BUILDER_INSTANTIATE(BinaryOp)
6868
IR_BUILDER_INSTANTIATE(TernaryOp)
69+
IR_BUILDER_INSTANTIATE(SelectOp)
6970
IR_BUILDER_INSTANTIATE(RNGOp)
7071
IR_BUILDER_INSTANTIATE(ReductionOp)
7172
IR_BUILDER_INSTANTIATE(GroupedReductionOp)

torch/csrc/jit/codegen/cuda/ir_cloner.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,10 @@ void IrCloner::handle(const TernaryOp* op) {
112112
clone_ = IrBuilder::clone(op, this);
113113
}
114114

115+
// Clones a SelectOp via IrBuilder::clone (passing this cloner) and stores
// the result in clone_.
void IrCloner::handle(const SelectOp* op) {
116+
clone_ = IrBuilder::clone(op, this);
117+
}
118+
115119
void IrCloner::handle(const RNGOp* op) {
116120
clone_ = IrBuilder::clone(op, this);
117121
}

0 commit comments

Comments
 (0)