Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions test/cpp/jit/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ if(USE_CUDA)
list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_transpose.cpp)
list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_rng.cu)
list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_utils.cpp)
list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_indexing_ops.cpp)
endif()

add_executable(test_jit
Expand Down
32 changes: 32 additions & 0 deletions torch/csrc/jit/codegen/cuda/arith.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,38 @@ TensorView* unaryOp(
return unaryOp(type, cast_v1)->as<TensorView>();
}

// Returns a new TensorView equivalent to indexing `tv` at position `index`
// along axis `dim` (python-style: negative `dim` counts from the end).
// The output's root domain is the input's non-reduction (rfactor) root
// domain with the selected axis removed; a SelectOp records the operation.
TensorView* select(TensorView* tv, int dim, Int* index) {
  auto dom = TensorDomain::noReductions(tv->getMaybeRFactorDomain());
  TORCH_CHECK(dom.size() > 0, "select can not be applied to 0d tensor.");

  std::vector<IterDomain*> new_root;
  new_root.reserve(dom.size() - 1);

  // Normalize a negative axis before validating.
  if (dim < 0) {
    dim += static_cast<int>(dom.size());
  }

  // Cast once so comparisons below are signed/signed (avoids -Wsign-compare
  // between `int dim` and `size_t dom.size()`).
  const auto ndims = static_cast<int>(dom.size());
  TORCH_CHECK(
      dim >= 0 && dim < ndims,
      "Select on invalid axis, received: ",
      dim,
      " however tensor view only has ",
      dom.size(),
      " non-reduction dims.");

  // Copy every axis except the selected one into the output root domain.
  for (auto i : c10::irange(dom.size())) {
    if (static_cast<int>(i) != dim) {
      new_root.emplace_back(dom[i]->cloneWithoutRFactor());
    }
  }

  auto td = IrBuilder::create<TensorDomain>(
      new_root, TensorDomain::getContiguousContiguity(new_root));
  auto out = IrBuilder::create<TensorView>(td, *tv->getDataType());
  IrBuilder::create<SelectOp>(out, tv, dom[dim], index);
  return out;
}

// TENSOR FACTORIES
TensorView* rand(const std::vector<Val*>& shape, DataType dtype) {
auto n = shape.size();
Expand Down
8 changes: 6 additions & 2 deletions torch/csrc/jit/codegen/cuda/arith.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ TORCH_CUDA_CU_API WelfordResult WelfordRaw(
// import IrBuilder just for this one interface.
Int* init_N = nullptr);

TORCH_CUDA_CU_API TensorView* select(TensorView* tv, int dim, Int* index);

// RNG OPERATIONS
TORCH_CUDA_CU_API TensorView* rand(
const std::vector<Val*>& shape,
Expand Down Expand Up @@ -375,12 +377,14 @@ TORCH_CUDA_CU_API Val* atan2(Val* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* atan2(TensorView* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* atan2(Val* v1, TensorView* v2);
TORCH_CUDA_CU_API TensorView* atan2(TensorView* v1, TensorView* v2);
// div
// div: promote to float for integer division, has the same semantics as the
// python's operator /
TORCH_CUDA_CU_API Val* div(Val* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* div(TensorView* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* div(Val* v1, TensorView* v2);
TORCH_CUDA_CU_API TensorView* div(TensorView* v1, TensorView* v2);
// cpp_div: similar to div, but don't promote to float
// cpp_div: similar to div, but don't promote to float, this has the same
// semantics as the C++'s operator /
TORCH_CUDA_CU_API Val* cpp_div(Val* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* cpp_div(TensorView* v1, Val* v2);
TORCH_CUDA_CU_API TensorView* cpp_div(Val* v1, TensorView* v2);
Expand Down
15 changes: 15 additions & 0 deletions torch/csrc/jit/codegen/cuda/dispatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ void Expr::dispatch(T handler, Expr* expr) {
case ExprType::TernaryOp:
ptr(handler)->handle(expr->as<TernaryOp>());
return;
case ExprType::SelectOp:
ptr(handler)->handle(expr->as<SelectOp>());
return;
case ExprType::RNGOp:
ptr(handler)->handle(expr->as<RNGOp>());
return;
Expand Down Expand Up @@ -296,6 +299,9 @@ void Expr::constDispatch(T handler, const Expr* expr) {
case ExprType::TernaryOp:
ptr(handler)->handle(expr->as<TernaryOp>());
return;
case ExprType::SelectOp:
ptr(handler)->handle(expr->as<SelectOp>());
return;
case ExprType::RNGOp:
ptr(handler)->handle(expr->as<RNGOp>());
return;
Expand Down Expand Up @@ -490,6 +496,9 @@ void Expr::mutatorDispatch(T mutator, Expr* expr) {
case ExprType::TernaryOp:
ptr(mutator)->mutate(expr->as<TernaryOp>());
return;
case ExprType::SelectOp:
ptr(mutator)->mutate(expr->as<SelectOp>());
return;
case ExprType::RNGOp:
ptr(mutator)->mutate(expr->as<RNGOp>());
return;
Expand Down Expand Up @@ -749,6 +758,9 @@ void OptOutConstDispatch::handle(const BinaryOp* stmt) {
void OptOutConstDispatch::handle(const TernaryOp* stmt) {
unhandled(stmt);
}
// Default const-dispatch for SelectOp: fall through to the generic
// unhandled() hook, matching the other expression handlers in this file.
void OptOutConstDispatch::handle(const SelectOp* stmt) {
  unhandled(stmt);
}
void OptOutConstDispatch::handle(const RNGOp* stmt) {
unhandled(stmt);
}
Expand Down Expand Up @@ -905,6 +917,9 @@ void OptOutDispatch::handle(BinaryOp* stmt) {
void OptOutDispatch::handle(TernaryOp* stmt) {
unhandled(stmt);
}
// Default mutable-dispatch for SelectOp: fall through to the generic
// unhandled() hook, matching the other expression handlers in this file.
void OptOutDispatch::handle(SelectOp* stmt) {
  unhandled(stmt);
}
void OptOutDispatch::handle(RNGOp* stmt) {
unhandled(stmt);
}
Expand Down
4 changes: 4 additions & 0 deletions torch/csrc/jit/codegen/cuda/dispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ class EyeOp;
class UnaryOp;
class BinaryOp;
class TernaryOp;
class SelectOp;
class RNGOp;
class ReductionOp;
class GroupedReductionOp;
Expand Down Expand Up @@ -149,6 +150,7 @@ class TORCH_CUDA_CU_API OptOutConstDispatch : public PolymorphicBase {
virtual void handle(const UnaryOp* stmt);
virtual void handle(const BinaryOp* stmt);
virtual void handle(const TernaryOp* stmt);
virtual void handle(const SelectOp* stmt);
virtual void handle(const RNGOp* stmt);
virtual void handle(const ReductionOp* stmt);
virtual void handle(const GroupedReductionOp* stmt);
Expand Down Expand Up @@ -216,6 +218,7 @@ class TORCH_CUDA_CU_API OptOutDispatch : public PolymorphicBase {
virtual void handle(UnaryOp* stmt);
virtual void handle(BinaryOp* stmt);
virtual void handle(TernaryOp* stmt);
virtual void handle(SelectOp* stmt);
virtual void handle(RNGOp* stmt);
virtual void handle(ReductionOp* stmt);
virtual void handle(GroupedReductionOp* stmt);
Expand Down Expand Up @@ -324,6 +327,7 @@ class TORCH_CUDA_CU_API OptOutMutator : public PolymorphicBase {
virtual void mutate(UnaryOp*);
virtual void mutate(BinaryOp*);
virtual void mutate(TernaryOp*);
virtual void mutate(SelectOp*);
virtual void mutate(RNGOp*);
virtual void mutate(ReductionOp*);
virtual void mutate(GroupedReductionOp*);
Expand Down
3 changes: 2 additions & 1 deletion torch/csrc/jit/codegen/cuda/fusion_segmenter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2633,7 +2633,8 @@ ScheduleHeuristic SegmentCandidateFinder::deriveHeuristic(
SegmentedGroup* group) {
Fusion* fusion = segmented_fusion_->completeFusion();
auto h = tryMerge(fusion, runtime_info_, group);
TORCH_INTERNAL_ASSERT(h.has_value());
TORCH_INTERNAL_ASSERT(
h.has_value(), "Can not find a scheduler to schedule fusion segment");
return h.value();
}

Expand Down
89 changes: 50 additions & 39 deletions torch/csrc/jit/codegen/cuda/index_compute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,11 @@ Val* getProducerIndexWithHalo(
const TensorView* producer_tv,
size_t producer_axis,
Val* producer_index,
const TensorView* consumer_tv) {
const auto offset =
getProducerHaloOffset(producer_tv, producer_axis, consumer_tv);
const TensorView* consumer_tv,
bool is_overriden_index) {
const auto offset = is_overriden_index
? 0
: getProducerHaloOffset(producer_tv, producer_axis, consumer_tv);

if (offset == 0) {
return producer_index;
Expand Down Expand Up @@ -1460,7 +1462,8 @@ Val* hoistProducerIndex(
std::vector<Val*> Index::getGlobalProducerStridedIndices(
TensorView* producer_tv,
const TensorView* consumer_tv,
const std::vector<kir::ForLoop*>& loops) {
const std::vector<kir::ForLoop*>& loops,
const std::unordered_map<IterDomain*, Val*>& override_index) {
FUSER_PERF_SCOPE("GpuLower::Lower::getGlobalProducerIndex");

// Replay producer to look like consumer so we can index on producer since
Expand Down Expand Up @@ -1545,23 +1548,6 @@ std::vector<Val*> Index::getGlobalProducerStridedIndices(
continue;
}

Val* root_ind = nullptr;
if (producer_indexing.indexMap().find(root_dom[dim]) !=
producer_indexing.indexMap().end()) {
root_ind = producer_indexing.indexMap().at(root_dom[dim]);
} else if (root_dom[dim]->isBroadcast()) {
root_ind = GpuLower::current()->kernel()->zeroVal();
}

TORCH_INTERNAL_ASSERT(
root_ind != nullptr,
"Couldn't find root mapping for ",
producer_tv->toString(),
" dim: ",
dim,
" id: ",
root_dom[dim]->toString());

if (producer_tv->domain()->contiguity()[dim]) {
// If contig, used the stored stride which may be the previous
// dimensions stride * previous dimensions size
Expand Down Expand Up @@ -1591,18 +1577,27 @@ std::vector<Val*> Index::getGlobalProducerStridedIndices(
continue;
}

TORCH_INTERNAL_ASSERT(
Val* root_ind = nullptr;
auto override_it = override_index.find(root_dom[i]);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking about passing the optional map to getTensorIndexFromIdGraph to provide an initial ID-to-index map. That would be more consistent if we would want to allow the same initial map in consumer indexing.

That said, I think this is good enough for now given that the whole indexing code would be redesigned.

Pinging @csarofeen

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed

if (override_it != override_index.end()) {
root_ind = override_it->second;
} else if (
producer_indexing.indexMap().find(root_dom[i]) !=
producer_indexing.indexMap().end(),
"Couldn't find root mapping for TV",
producer_tv->name(),
producer_indexing.indexMap().end()) {
root_ind = producer_indexing.indexMap().at(root_dom[i]);
} else if (root_dom[i]->isBroadcast()) {
root_ind = GpuLower::current()->kernel()->zeroVal();
}

TORCH_INTERNAL_ASSERT(
root_ind != nullptr,
"Couldn't find root mapping for ",
producer_tv->toString(),
" dim: ",
i,
" id: ",
root_dom[i]->toString());

auto root_ind = producer_indexing.indexMap().at(root_dom[i]);

// index hoist must be done before the adjustments for halo
root_ind = hoistProducerIndex(
root_dom[i],
Expand All @@ -1615,7 +1610,12 @@ std::vector<Val*> Index::getGlobalProducerStridedIndices(
loops,
root_ind);

root_ind = getProducerIndexWithHalo(producer_tv, i, root_ind, consumer_tv);
root_ind = getProducerIndexWithHalo(
producer_tv,
i,
root_ind,
consumer_tv,
override_index.count(root_dom[i]));

root_ind = getProducerIndexWithGather(
root_ind,
Expand Down Expand Up @@ -1686,7 +1686,8 @@ std::unordered_map<IterDomain*, IterDomain*> mapAllProducerDomainsToConsumer(
std::vector<Val*> Index::getNonGlobalProducerStridedIndices(
TensorView* producer_tv,
const TensorView* consumer_tv,
const std::vector<kir::ForLoop*>& loops) {
const std::vector<kir::ForLoop*>& loops,
const std::unordered_map<IterDomain*, Val*>& override_index) {
const auto gpu_lower = GpuLower::current();

// Replay producer to look like consumer so we can index on producer since our
Expand Down Expand Up @@ -1827,7 +1828,10 @@ std::vector<Val*> Index::getNonGlobalProducerStridedIndices(
" id: ",
root_dom[i]->toString());

auto root_ind_i = index_map.at(root_dom[i]);
auto override_it = override_index.find(root_dom[i]);
auto root_ind_i =
(override_it != override_index.end() ? override_it->second
: index_map.at(root_dom[i]));

// index hoist must be done before the adjustments for halo
root_ind_i = hoistProducerIndex(
Expand All @@ -1841,8 +1845,12 @@ std::vector<Val*> Index::getNonGlobalProducerStridedIndices(
loops,
root_ind_i);

root_ind_i =
getProducerIndexWithHalo(producer_tv, i, root_ind_i, consumer_tv);
root_ind_i = getProducerIndexWithHalo(
producer_tv,
i,
root_ind_i,
consumer_tv,
override_index.count(root_dom[i]));

root_ind_i = getProducerIndexWithGather(
root_ind_i,
Expand Down Expand Up @@ -2226,7 +2234,8 @@ std::vector<Val*> Index::getNonGlobalConsumerStridedIndices(
std::vector<Val*> Index::getProducerStridedIndices(
TensorView* producer,
const TensorView* consumer,
const std::vector<kir::ForLoop*>& loops) {
const std::vector<kir::ForLoop*>& loops,
const std::unordered_map<IterDomain*, Val*>& override_index) {
FUSER_PERF_SCOPE("GpuLower::Lower::Index::getProducerStridedIndices");
if (producer->domain()->noReductions().size() == 0) {
return std::vector<Val*>(
Expand All @@ -2236,11 +2245,11 @@ std::vector<Val*> Index::getProducerStridedIndices(

std::vector<Val*> strided_indices;
if (producer->getMemoryType() == MemoryType::Global) {
strided_indices =
getGlobalProducerStridedIndices(producer, consumer, loops);
strided_indices = getGlobalProducerStridedIndices(
producer, consumer, loops, override_index);
} else {
strided_indices =
getNonGlobalProducerStridedIndices(producer, consumer, loops);
strided_indices = getNonGlobalProducerStridedIndices(
producer, consumer, loops, override_index);
}

TORCH_INTERNAL_ASSERT(
Expand All @@ -2256,8 +2265,10 @@ std::vector<Val*> Index::getProducerStridedIndices(
kir::TensorIndex* Index::getProducerIndex(
TensorView* producer,
const TensorView* consumer,
const std::vector<kir::ForLoop*>& loops) {
auto strided_indices = getProducerStridedIndices(producer, consumer, loops);
const std::vector<kir::ForLoop*>& loops,
const std::unordered_map<IterDomain*, Val*>& override_index) {
auto strided_indices =
getProducerStridedIndices(producer, consumer, loops, override_index);
return SimplifyingIrBuilder::create<kir::TensorIndex>(
producer, strided_indices);
}
Expand Down
12 changes: 8 additions & 4 deletions torch/csrc/jit/codegen/cuda/index_compute.h
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,8 @@ class Index {
static std::vector<Val*> getNonGlobalProducerStridedIndices(
TensorView* producer,
const TensorView* consumer,
const std::vector<kir::ForLoop*>& loops);
const std::vector<kir::ForLoop*>& loops,
const std::unordered_map<IterDomain*, Val*>& override_index = {});

// Consumer indexing if it's in shared or local memory
static std::vector<Val*> getNonGlobalConsumerStridedIndices(
Expand All @@ -320,7 +321,8 @@ class Index {
static std::vector<Val*> getGlobalProducerStridedIndices(
TensorView* producer,
const TensorView* consumer,
const std::vector<kir::ForLoop*>& loops);
const std::vector<kir::ForLoop*>& loops,
const std::unordered_map<IterDomain*, Val*>& override_index = {});

// Consumer indexing if it's in global memory
static std::vector<Val*> getGlobalConsumerStridedIndices(
Expand All @@ -344,7 +346,8 @@ class Index {
static kir::TensorIndex* getProducerIndex(
TensorView* producer,
const TensorView* consumer,
const std::vector<kir::ForLoop*>& loops);
const std::vector<kir::ForLoop*>& loops,
const std::unordered_map<IterDomain*, Val*>& override_index = {});

// Consumer index dispatch
static kir::TensorIndex* getConsumerIndex(
Expand All @@ -358,7 +361,8 @@ class Index {
static std::vector<Val*> getProducerStridedIndices(
TensorView* producer,
const TensorView* consumer,
const std::vector<kir::ForLoop*>& loops);
const std::vector<kir::ForLoop*>& loops,
const std::unordered_map<IterDomain*, Val*>& override_index = {});

//! Returns a vector of strided indices mapped onto the (rfactor)
//! root domain of a consumer tensor. The size of the returned
Expand Down
1 change: 1 addition & 0 deletions torch/csrc/jit/codegen/cuda/ir_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ IR_BUILDER_INSTANTIATE(EyeOp)
IR_BUILDER_INSTANTIATE(UnaryOp)
IR_BUILDER_INSTANTIATE(BinaryOp)
IR_BUILDER_INSTANTIATE(TernaryOp)
IR_BUILDER_INSTANTIATE(SelectOp)
IR_BUILDER_INSTANTIATE(RNGOp)
IR_BUILDER_INSTANTIATE(ReductionOp)
IR_BUILDER_INSTANTIATE(GroupedReductionOp)
Expand Down
4 changes: 4 additions & 0 deletions torch/csrc/jit/codegen/cuda/ir_cloner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,10 @@ void IrCloner::handle(const TernaryOp* op) {
clone_ = IrBuilder::clone(op, this);
}

// Clone a SelectOp node via IrBuilder::clone, storing the result in
// clone_ like every other op handler in this file.
void IrCloner::handle(const SelectOp* op) {
  clone_ = IrBuilder::clone(op, this);
}

void IrCloner::handle(const RNGOp* op) {
clone_ = IrBuilder::clone(op, this);
}
Expand Down
Loading