Skip to content

Commit 87fee1d

Browse files
committed
refactor
1 parent 11a56e6 commit 87fee1d

File tree

14 files changed

+87
-69
lines changed

14 files changed

+87
-69
lines changed

torch/csrc/jit/codegen/cuda/arith.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -470,7 +470,7 @@ TensorView* select(TensorView* tv, int dim, Int* index) {
470470
auto td = IrBuilder::create<TensorDomain>(
471471
new_root, TensorDomain::getContiguousContiguity(new_root));
472472
auto out = IrBuilder::create<TensorView>(td, *tv->getDataType());
473-
IrBuilder::create<SelectOp>(out, tv, dim, index);
473+
IrBuilder::create<SelectOp>(out, tv, dom[dim], index);
474474
return out;
475475
}
476476

torch/csrc/jit/codegen/cuda/codegen.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1038,10 +1038,6 @@ class CudaKernelGenerator : private OptOutConstDispatch {
10381038
}
10391039
}
10401040

1041-
void handle(const SelectOp* sop) final {
1042-
indent() << gen(sop->output(0)) << " = " << gen(sop->input(0)) << ";\n";
1043-
}
1044-
10451041
std::string genArchString(MmaOptions::MacroType macro) {
10461042
std::stringstream ss;
10471043
if (isVolta(macro)) {

torch/csrc/jit/codegen/cuda/index_compute.cpp

Lines changed: 24 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1460,7 +1460,8 @@ Val* hoistProducerIndex(
14601460
std::vector<Val*> Index::getGlobalProducerStridedIndices(
14611461
TensorView* producer_tv,
14621462
const TensorView* consumer_tv,
1463-
const std::vector<kir::ForLoop*>& loops) {
1463+
const std::vector<kir::ForLoop*>& loops,
1464+
const std::unordered_map<IterDomain*, Val*>& override_index) {
14641465
FUSER_PERF_SCOPE("GpuLower::Lower::getGlobalProducerIndex");
14651466

14661467
// Replay producer to look like consumer so we can index on producer since
@@ -1536,13 +1537,6 @@ std::vector<Val*> Index::getGlobalProducerStridedIndices(
15361537
}
15371538
}
15381539

1539-
IterDomain* selected_id = nullptr;
1540-
Val* selected_index = nullptr;
1541-
if (auto sop = dynamic_cast<SelectOp*>(consumer_tv->definition())) {
1542-
selected_id = TensorDomain::noReductions(root_dom)[sop->getDim()];
1543-
selected_index = sop->input(1);
1544-
}
1545-
15461540
TORCH_INTERNAL_ASSERT(
15471541
root_dom.size() == producer_tv->domain()->contiguity().size());
15481542
Val* cur_contig_stride = GpuLower::current()->kernel()->oneVal();
@@ -1582,13 +1576,15 @@ std::vector<Val*> Index::getGlobalProducerStridedIndices(
15821576
}
15831577

15841578
Val* root_ind = nullptr;
1585-
if (producer_indexing.indexMap().find(root_dom[i]) !=
1579+
auto override_it = override_index.find(root_dom[i]);
1580+
if (override_it != override_index.end()) {
1581+
root_ind = override_it->second;
1582+
} else if (
1583+
producer_indexing.indexMap().find(root_dom[i]) !=
15861584
producer_indexing.indexMap().end()) {
15871585
root_ind = producer_indexing.indexMap().at(root_dom[i]);
15881586
} else if (root_dom[i]->isBroadcast()) {
15891587
root_ind = GpuLower::current()->kernel()->zeroVal();
1590-
} else if (root_dom[i] == selected_id) {
1591-
root_ind = selected_index;
15921588
}
15931589

15941590
TORCH_INTERNAL_ASSERT(
@@ -1612,7 +1608,7 @@ std::vector<Val*> Index::getGlobalProducerStridedIndices(
16121608
loops,
16131609
root_ind);
16141610

1615-
if (root_dom[i] != selected_id) {
1611+
if (!override_index.count(root_dom[i])) {
16161612
root_ind =
16171613
getProducerIndexWithHalo(producer_tv, i, root_ind, consumer_tv);
16181614
}
@@ -1686,7 +1682,8 @@ std::unordered_map<IterDomain*, IterDomain*> mapAllProducerDomainsToConsumer(
16861682
std::vector<Val*> Index::getNonGlobalProducerStridedIndices(
16871683
TensorView* producer_tv,
16881684
const TensorView* consumer_tv,
1689-
const std::vector<kir::ForLoop*>& loops) {
1685+
const std::vector<kir::ForLoop*>& loops,
1686+
const std::unordered_map<IterDomain*, Val*>& override_index) {
16901687
const auto gpu_lower = GpuLower::current();
16911688

16921689
// Replay producer to look like consumer so we can index on producer since our
@@ -1794,13 +1791,6 @@ std::vector<Val*> Index::getNonGlobalProducerStridedIndices(
17941791
// and use them.
17951792
auto root_dom = producer_tv->getMaybeRFactorDomain();
17961793

1797-
IterDomain* selected_id = nullptr;
1798-
Val* selected_index = nullptr;
1799-
if (auto sop = dynamic_cast<SelectOp*>(consumer_tv->definition())) {
1800-
selected_id = TensorDomain::noReductions(root_dom)[sop->getDim()];
1801-
selected_index = sop->input(1);
1802-
}
1803-
18041794
// Figure out which root axes we don't need to index
18051795
std::unordered_set<IterDomain*> skip_indexing;
18061796

@@ -1834,9 +1824,10 @@ std::vector<Val*> Index::getNonGlobalProducerStridedIndices(
18341824
" id: ",
18351825
root_dom[i]->toString());
18361826

1827+
auto override_it = override_index.find(root_dom[i]);
18371828
auto root_ind_i =
1838-
(selected_id == root_dom[i] ? selected_index
1839-
: index_map.at(root_dom[i]));
1829+
(override_it != override_index.end() ? override_it->second
1830+
: index_map.at(root_dom[i]));
18401831

18411832
// index hoist must be done before the adjustments for halo
18421833
root_ind_i = hoistProducerIndex(
@@ -1850,7 +1841,7 @@ std::vector<Val*> Index::getNonGlobalProducerStridedIndices(
18501841
loops,
18511842
root_ind_i);
18521843

1853-
if (root_dom[i] != selected_id) {
1844+
if (!override_index.count(root_dom[i])) {
18541845
root_ind_i =
18551846
getProducerIndexWithHalo(producer_tv, i, root_ind_i, consumer_tv);
18561847
}
@@ -2237,7 +2228,8 @@ std::vector<Val*> Index::getNonGlobalConsumerStridedIndices(
22372228
std::vector<Val*> Index::getProducerStridedIndices(
22382229
TensorView* producer,
22392230
const TensorView* consumer,
2240-
const std::vector<kir::ForLoop*>& loops) {
2231+
const std::vector<kir::ForLoop*>& loops,
2232+
const std::unordered_map<IterDomain*, Val*>& override_index) {
22412233
FUSER_PERF_SCOPE("GpuLower::Lower::Index::getProducerStridedIndices");
22422234
if (producer->domain()->noReductions().size() == 0) {
22432235
return std::vector<Val*>(
@@ -2247,11 +2239,11 @@ std::vector<Val*> Index::getProducerStridedIndices(
22472239

22482240
std::vector<Val*> strided_indices;
22492241
if (producer->getMemoryType() == MemoryType::Global) {
2250-
strided_indices =
2251-
getGlobalProducerStridedIndices(producer, consumer, loops);
2242+
strided_indices = getGlobalProducerStridedIndices(
2243+
producer, consumer, loops, override_index);
22522244
} else {
2253-
strided_indices =
2254-
getNonGlobalProducerStridedIndices(producer, consumer, loops);
2245+
strided_indices = getNonGlobalProducerStridedIndices(
2246+
producer, consumer, loops, override_index);
22552247
}
22562248

22572249
TORCH_INTERNAL_ASSERT(
@@ -2267,8 +2259,10 @@ std::vector<Val*> Index::getProducerStridedIndices(
22672259
kir::TensorIndex* Index::getProducerIndex(
22682260
TensorView* producer,
22692261
const TensorView* consumer,
2270-
const std::vector<kir::ForLoop*>& loops) {
2271-
auto strided_indices = getProducerStridedIndices(producer, consumer, loops);
2262+
const std::vector<kir::ForLoop*>& loops,
2263+
const std::unordered_map<IterDomain*, Val*>& override_index) {
2264+
auto strided_indices =
2265+
getProducerStridedIndices(producer, consumer, loops, override_index);
22722266
return SimplifyingIrBuilder::create<kir::TensorIndex>(
22732267
producer, strided_indices);
22742268
}

torch/csrc/jit/codegen/cuda/index_compute.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,8 @@ class Index {
309309
static std::vector<Val*> getNonGlobalProducerStridedIndices(
310310
TensorView* producer,
311311
const TensorView* consumer,
312-
const std::vector<kir::ForLoop*>& loops);
312+
const std::vector<kir::ForLoop*>& loops,
313+
const std::unordered_map<IterDomain*, Val*>& override_index = {});
313314

314315
// Consumer indexing if it's in shared or local memory
315316
static std::vector<Val*> getNonGlobalConsumerStridedIndices(
@@ -320,7 +321,8 @@ class Index {
320321
static std::vector<Val*> getGlobalProducerStridedIndices(
321322
TensorView* producer,
322323
const TensorView* consumer,
323-
const std::vector<kir::ForLoop*>& loops);
324+
const std::vector<kir::ForLoop*>& loops,
325+
const std::unordered_map<IterDomain*, Val*>& override_index = {});
324326

325327
// Consumer indexing if it's in global memory
326328
static std::vector<Val*> getGlobalConsumerStridedIndices(
@@ -344,7 +346,8 @@ class Index {
344346
static kir::TensorIndex* getProducerIndex(
345347
TensorView* producer,
346348
const TensorView* consumer,
347-
const std::vector<kir::ForLoop*>& loops);
349+
const std::vector<kir::ForLoop*>& loops,
350+
const std::unordered_map<IterDomain*, Val*>& override_index = {});
348351

349352
// Consumer index dispatch
350353
static kir::TensorIndex* getConsumerIndex(
@@ -358,7 +361,8 @@ class Index {
358361
static std::vector<Val*> getProducerStridedIndices(
359362
TensorView* producer,
360363
const TensorView* consumer,
361-
const std::vector<kir::ForLoop*>& loops);
364+
const std::vector<kir::ForLoop*>& loops,
365+
const std::unordered_map<IterDomain*, Val*>& override_index = {});
362366

363367
//! Returns a vector of strided indices mapped onto the (rfactor)
364368
//! root domain of a consumer tensor. The size of the returned

torch/csrc/jit/codegen/cuda/ir_internal_nodes.h

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,20 +55,29 @@ class TORCH_CUDA_CU_API FullOp : public Expr {
5555

5656
class TORCH_CUDA_CU_API SelectOp : public Expr {
5757
public:
58-
SelectOp(IrBuilderPasskey, Val* out, Val* in, int dim, Val* index);
58+
SelectOp(
59+
IrBuilderPasskey,
60+
Val* out,
61+
Val* in,
62+
IterDomain* select_id,
63+
Val* index);
5964

6065
SelectOp(const SelectOp* src, IrCloner* ir_cloner);
6166

6267
Expr* shallowCopy() const override;
6368

6469
bool sameAs(const Statement* other) const override;
6570

66-
int getDim() const {
67-
return dim_;
71+
std::unordered_map<IterDomain*, Val*> getIndexOverridingMap() const {
72+
return {{select_id_, input(1)}};
73+
}
74+
75+
IterDomain* getSelectAxis() const {
76+
return select_id_;
6877
}
6978

7079
private:
71-
int dim_;
80+
IterDomain* select_id_;
7281
};
7382

7483
class TORCH_CUDA_CU_API ARangeOp : public Expr {

torch/csrc/jit/codegen/cuda/ir_iostream.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -506,7 +506,8 @@ void IrPrinter::handle(const RNGOp* rop) {
506506

507507
void IrPrinter::handle(const SelectOp* sop) {
508508
indent() << sop->output(0) << "\n";
509-
indent() << " = select( " << sop->input(0) << ", dim = " << sop->getDim()
509+
indent() << " = select( " << sop->input(0)
510+
<< ", axis = " << sop->getSelectAxis()
510511
<< ", index = " << sop->input(1) << " )\n";
511512
}
512513

torch/csrc/jit/codegen/cuda/ir_nodes.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -227,20 +227,20 @@ SelectOp::SelectOp(
227227
IrBuilderPasskey passkey,
228228
Val* out,
229229
Val* in,
230-
int dim,
230+
IterDomain* select_id,
231231
Val* index)
232-
: Expr(passkey, ExprType::SelectOp), dim_(dim) {
232+
: Expr(passkey, ExprType::SelectOp), select_id_(select_id) {
233233
addInput(in);
234234
addInput(index);
235235
addOutput(out);
236236
}
237237

238238
SelectOp::SelectOp(const SelectOp* src, IrCloner* ir_cloner)
239-
: Expr(src, ir_cloner), dim_(src->dim_) {}
239+
: Expr(src, ir_cloner), select_id_(ir_cloner->clone(src->select_id_)) {}
240240

241241
Expr* SelectOp::shallowCopy() const {
242242
auto result =
243-
IrBuilder::create<SelectOp>(output(0), input(0), dim_, input(1));
243+
IrBuilder::create<SelectOp>(output(0), input(0), select_id_, input(1));
244244
result->copyPredicatesFrom(this);
245245
return result;
246246
}
@@ -253,7 +253,7 @@ bool SelectOp::sameAs(const Statement* other) const {
253253
return false;
254254
}
255255
const auto other_op = other->as<SelectOp>();
256-
if (dim_ != other_op->dim_) {
256+
if (!select_id_->sameAs(other_op->select_id_)) {
257257
return false;
258258
}
259259
return Expr::sameAs(other);

torch/csrc/jit/codegen/cuda/ir_utils.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,11 @@ struct SubstituteInExpr : public OptInDispatch {
276276
? substitute_
277277
: select_expr->output(0);
278278
expr_ = IrBuilder::create<SelectOp>(
279-
select_expr->container(), out, input, select_expr->getDim(), index);
279+
select_expr->container(),
280+
out,
281+
input,
282+
select_expr->getSelectAxis(),
283+
index);
280284
}
281285

282286
void handle(RNGOp* rng_expr) final {

torch/csrc/jit/codegen/cuda/lower_index.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,14 @@ namespace jit {
1313
namespace fuser {
1414
namespace cuda {
1515

16-
Val* IndexLowering::lowerSrcIndex(Val* src, Val* dst) const {
16+
Val* IndexLowering::lowerSrcIndex(
17+
Val* src,
18+
Val* dst,
19+
const std::unordered_map<IterDomain*, Val*>& override_index) const {
1720
if (auto tv = dynamic_cast<TensorView*>(src)) {
1821
TORCH_INTERNAL_ASSERT(dst->isA<TensorView>());
19-
return Index::getProducerIndex(tv, dst->as<TensorView>(), for_loops_);
22+
return Index::getProducerIndex(
23+
tv, dst->as<TensorView>(), for_loops_, override_index);
2024
} else {
2125
return src;
2226
}
@@ -193,10 +197,10 @@ void IndexLowering::handle(const TernaryOp* top) {
193197
}
194198

195199
void IndexLowering::handle(const SelectOp* sop) {
196-
const auto input = lowerSrcIndex(sop->input(0), sop->output(0));
200+
const auto input = lowerSrcIndex(
201+
sop->input(0), sop->output(0), sop->getIndexOverridingMap());
197202
const auto out = lowerDstIndex(sop->output(0));
198-
pushBack(
199-
IrBuilder::create<SelectOp>(out, input, sop->getDim(), sop->input(1)));
203+
pushBack(IrBuilder::create<UnaryOp>(UnaryOpType::Set, out, input));
200204
GpuLower::current()->propagateExprInfo(sop, back());
201205
}
202206

torch/csrc/jit/codegen/cuda/lower_index.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,10 @@ class TORCH_CUDA_CU_API IndexLowering : private OptOutConstDispatch {
6666

6767
void generate(const std::vector<Expr*>& exprs);
6868

69-
Val* lowerSrcIndex(Val* val, Val* dst) const;
69+
Val* lowerSrcIndex(
70+
Val* val,
71+
Val* dst,
72+
const std::unordered_map<IterDomain*, Val*>& override_index = {}) const;
7073

7174
Val* lowerDstIndex(Val* dst) const;
7275

0 commit comments

Comments
 (0)