Some indexing cleanups, Add eye support #1940

Merged 26 commits on Sep 2, 2022
17 changes: 17 additions & 0 deletions torch/csrc/jit/codegen/cuda/arith.cpp
@@ -553,6 +553,23 @@ TensorView* arange(Val* start, Val* end, Val* step, DataType dtype) {
return out;
}

TensorView* eye(Val* rows, Val* cols, DataType dtype) {
TORCH_CHECK(rows->getDataType() == DataType::Int, "rows must have type Int");
TORCH_CHECK(cols->getDataType() == DataType::Int, "cols must have type Int");
auto out = TensorViewBuilder()
.ndims(2)
.dtype(dtype)
.contiguity({true, true})
.shape(std::vector<Val*>{rows, cols})
.build();
IrBuilder::create<EyeOp>(out, dtype);
return out;
}

TensorView* eye(Val* size, DataType dtype) {
return eye(size, size, dtype);
}
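
For reference, a minimal usage sketch of the new op. This is not part of the diff; the fusion boilerplate below is hypothetical and only the eye() overloads come from this PR.

    // using namespace torch::jit::fuser::cuda; headers from this directory assumed.
    Fusion fusion;
    FusionGuard fg(&fusion);
    Val* size = IrBuilder::create<Int>();        // symbolic runtime size, must be Int
    fusion.addInput(size);
    TensorView* tv0 = eye(size, DataType::Float); // square overload added above
    fusion.addOutput(tv0);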

// UNARY OPERATIONS

#define NVFUSER_DEFINE_UNARY_OP(op_name, op_type) \
2 changes: 2 additions & 0 deletions torch/csrc/jit/codegen/cuda/arith.h
@@ -156,6 +156,8 @@ TORCH_CUDA_CU_API TensorView* arange(
Val* end,
Val* step,
DataType dtype = DataType::Int);
TORCH_CUDA_CU_API TensorView* eye(Val* size, DataType dtype);
TORCH_CUDA_CU_API TensorView* eye(Val* rows, Val* cols, DataType dtype);

// UNARY OPERATIONS
// abs
10 changes: 9 additions & 1 deletion torch/csrc/jit/codegen/cuda/codegen.cpp
@@ -566,12 +566,20 @@ class CudaKernelGenerator : private OptOutConstDispatch {
}

void handle(const ARangeOp* aop) final {
auto index = genTensorIndex(aop->getLinearIndex()->as<kir::TensorIndex>());
auto index =
genTensorIndex(aop->getLinearLogicalIndex()->as<kir::TensorIndex>());
indent() << gen(aop->output(0)) << " = arange<" << aop->dtype() << ">";
code_ << "(" << index << ", " << gen(aop->start()) << ", "
<< gen(aop->step()) << ");\n";
}

void handle(const EyeOp* aop) final {
auto index1 = gen(aop->getIndex1());
auto index2 = gen(aop->getIndex2());
indent() << gen(aop->output(0)) << " = (" << aop->dtype() << ")";
code_ << "(" << index1 << " == " << index2 << ");\n";
}
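
The EyeOp handler simply compares the two per-dimension indices and casts the result to the output type. Roughly, for a float output the emitted device code would look like the line below (the tensor and index names are illustrative, not taken from an actual generated kernel):

    // i0, i1: row and column indices produced by the indexing pass; T0: global output
    T0[i0 * T0.stride[0] + i1 * T0.stride[1]] = (float)(i0 == i1);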

void handle(const UnaryOp* uop) final {
bool is_vector_op = false;
size_t vector_word_size = 1;
15 changes: 15 additions & 0 deletions torch/csrc/jit/codegen/cuda/dispatch.cpp
@@ -101,6 +101,9 @@ void Expr::dispatch(T handler, Expr* expr) {
case ExprType::ARangeOp:
ptr(handler)->handle(expr->as<ARangeOp>());
return;
case ExprType::EyeOp:
ptr(handler)->handle(expr->as<EyeOp>());
return;
case ExprType::UnaryOp:
ptr(handler)->handle(expr->as<UnaryOp>());
return;
@@ -290,6 +293,9 @@ void Expr::constDispatch(T handler, const Expr* expr) {
case ExprType::ARangeOp:
ptr(handler)->handle(expr->as<ARangeOp>());
return;
case ExprType::EyeOp:
ptr(handler)->handle(expr->as<EyeOp>());
return;
case ExprType::UnaryOp:
ptr(handler)->handle(expr->as<UnaryOp>());
return;
@@ -487,6 +493,9 @@ void Expr::mutatorDispatch(T mutator, Expr* expr) {
case ExprType::ARangeOp:
ptr(mutator)->mutate(expr->as<ARangeOp>());
return;
case ExprType::EyeOp:
ptr(mutator)->mutate(expr->as<EyeOp>());
return;
case ExprType::UnaryOp:
ptr(mutator)->mutate(expr->as<UnaryOp>());
return;
@@ -749,6 +758,9 @@ void OptOutConstDispatch::handle(const FullOp* stmt) {
void OptOutConstDispatch::handle(const ARangeOp* stmt) {
unhandled(stmt);
}
void OptOutConstDispatch::handle(const EyeOp* stmt) {
unhandled(stmt);
}
void OptOutConstDispatch::handle(const UnaryOp* stmt) {
unhandled(stmt);
}
@@ -908,6 +920,9 @@ void OptOutDispatch::handle(FullOp* stmt) {
void OptOutDispatch::handle(ARangeOp* stmt) {
unhandled(stmt);
}
void OptOutDispatch::handle(EyeOp* stmt) {
unhandled(stmt);
}
void OptOutDispatch::handle(UnaryOp* stmt) {
unhandled(stmt);
}
4 changes: 4 additions & 0 deletions torch/csrc/jit/codegen/cuda/dispatch.h
@@ -70,6 +70,7 @@ class NamedScalar;
// Exprs
class FullOp;
class ARangeOp;
class EyeOp;
class UnaryOp;
class BinaryOp;
class TernaryOp;
@@ -147,6 +148,7 @@ class TORCH_CUDA_CU_API OptOutConstDispatch : public PolymorphicBase {
// Exprs
virtual void handle(const FullOp* stmt);
virtual void handle(const ARangeOp* stmt);
virtual void handle(const EyeOp* stmt);
virtual void handle(const UnaryOp* stmt);
virtual void handle(const BinaryOp* stmt);
virtual void handle(const TernaryOp* stmt);
@@ -215,6 +217,7 @@ class TORCH_CUDA_CU_API OptOutDispatch : public PolymorphicBase {
// Exprs
virtual void handle(FullOp* stmt);
virtual void handle(ARangeOp* stmt);
virtual void handle(EyeOp* stmt);
virtual void handle(UnaryOp* stmt);
virtual void handle(BinaryOp* stmt);
virtual void handle(TernaryOp* stmt);
@@ -324,6 +327,7 @@ class TORCH_CUDA_CU_API OptOutMutator : public PolymorphicBase {
// Exprs
virtual void mutate(FullOp*);
virtual void mutate(ARangeOp*);
virtual void mutate(EyeOp*);
virtual void mutate(UnaryOp*);
virtual void mutate(BinaryOp*);
virtual void mutate(TernaryOp*);
141 changes: 76 additions & 65 deletions torch/csrc/jit/codegen/cuda/index_compute.cpp
@@ -1937,52 +1937,55 @@ std::vector<Val*> Index::getNonGlobalProducerStridedIndices(
return strided_inds;
}

std::vector<Val*> Index::getLinearIndex(
TensorView* consumer_tv,
const std::vector<kir::ForLoop*>& loops) {
template <typename func_t>
[Review comment] Doesn't need to be in this PR. But this function looks more like a domain guard factory.

auto evaluateWithOverridenContiguity(
TensorView* tv,
bool contiguity,
const func_t& functor) -> decltype(functor()) {
// Use domain guard to ignore the contiguity of
// consumer tv.
TensorDomain* consumer_tv_no_contiguity_domain = nullptr;
auto contiguity_vector =
std::vector<bool>(consumer_tv->getMaybeRFactorDomain().size(), true);
if (consumer_tv->hasRFactor()) {
consumer_tv_no_contiguity_domain = IrBuilder::create<TensorDomain>(
consumer_tv->getRootDomain(),
consumer_tv->getRFactorDomain(),
consumer_tv->domain()->domain(),
TensorDomain* domain_with_specified_contiguity = nullptr;
std::vector<bool> contiguity_vector(
tv->getMaybeRFactorDomain().size(), contiguity);
if (tv->hasRFactor()) {
domain_with_specified_contiguity = IrBuilder::create<TensorDomain>(
tv->getRootDomain(),
tv->getRFactorDomain(),
tv->domain()->domain(),
contiguity_vector);
} else {
consumer_tv_no_contiguity_domain = IrBuilder::create<TensorDomain>(
consumer_tv->getRootDomain(),
consumer_tv->domain()->domain(),
contiguity_vector);
domain_with_specified_contiguity = IrBuilder::create<TensorDomain>(
tv->getRootDomain(), tv->domain()->domain(), contiguity_vector);
}

ir_utils::TVDomainGuard domain_guard(
consumer_tv, consumer_tv_no_contiguity_domain);
ir_utils::TVDomainGuard domain_guard(tv, domain_with_specified_contiguity);

// TODO:
// More optimization on the underlying tensor layout
// will be done in a follow up.
return getGlobalConsumerStridedIndices(consumer_tv, loops);
return functor();
}
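
To illustrate the review comment above, a domain-guard-factory formulation might look like the sketch below. This is hypothetical, not part of the diff, and it assumes ir_utils::TVDomainGuard is (or could be made) movable so it can be returned by value:

    // Build a TensorDomain whose every axis has the given contiguity and
    // return a guard that keeps it installed on tv while the guard is alive.
    ir_utils::TVDomainGuard overrideContiguityGuard(TensorView* tv, bool contiguity) {
      std::vector<bool> contiguity_vector(
          tv->getMaybeRFactorDomain().size(), contiguity);
      TensorDomain* domain = tv->hasRFactor()
          ? IrBuilder::create<TensorDomain>(
                tv->getRootDomain(),
                tv->getRFactorDomain(),
                tv->domain()->domain(),
                contiguity_vector)
          : IrBuilder::create<TensorDomain>(
                tv->getRootDomain(), tv->domain()->domain(), contiguity_vector);
      return ir_utils::TVDomainGuard(tv, domain);
    }

Callers would then hold the guard in scope and keep their own control flow instead of passing a functor.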

std::vector<Val*> Index::getGlobalConsumerStridedIndices(
const TensorView* consumer_tv,
std::vector<Val*> Index::getLinearLogicalIndex(
TensorView* consumer_tv,
const std::vector<kir::ForLoop*>& loops) {
FUSER_PERF_SCOPE("GpuLower::Lower::getGlobalConsumerIndex");

auto gpu_lower = GpuLower::current();

auto index_from_id_graph = getTensorIndexFromIdGraph(loops, consumer_tv);
return evaluateWithOverridenContiguity(consumer_tv, true, [&]() {
return getGlobalConsumerStridedIndices(consumer_tv, loops);
});
}

auto consumer_indexing = index_from_id_graph.index;
std::vector<Val*> Index::getPerDimLogicalIndex(
TensorView* consumer_tv,
const std::vector<kir::ForLoop*>& loops) {
return evaluateWithOverridenContiguity(consumer_tv, false, [&]() {
IndexFromIdGraph index_from_id_graph =
getTensorIndexFromIdGraph(loops, consumer_tv);
return getRootIndices(consumer_tv, loops, index_from_id_graph);
});
}

std::vector<Val*> Index::getStrides(const TensorView* tv) {
// Indices should now be mapped onto IterDomains in consumer, so just grab
// and use them.
auto root_dom = consumer_tv->getMaybeRFactorDomain();
auto root_dom = tv->getMaybeRFactorDomain();

// TODO: Abstract stride logic to reuse with producer indexing
std::vector<Val*> strides(
root_dom.size(), GpuLower::current()->kernel()->oneVal());
{
@@ -1993,39 +1996,21 @@ std::vector<Val*> Index::getGlobalConsumerStridedIndices(
continue;
}
std::stringstream ss;
ss << "T" << consumer_tv->name() << ".stride[" << stride_i++ << "]";
ss << "T" << tv->name() << ".stride[" << stride_i++ << "]";
strides[i] =
SimplifyingIrBuilder::create<NamedScalar>(ss.str(), DataType::Int);
}
}

TORCH_INTERNAL_ASSERT(
root_dom.size() == consumer_tv->domain()->contiguity().size());
TORCH_INTERNAL_ASSERT(root_dom.size() == tv->domain()->contiguity().size());
Val* cur_contig_stride = GpuLower::current()->kernel()->oneVal();
for (const auto i : c10::irange(root_dom.size())) {
auto dim = root_dom.size() - i - 1;
if (root_dom[dim]->isReduction() || root_dom[dim]->isStride()) {
continue;
}

Val* root_ind = nullptr;
if (consumer_indexing.indexMap().find(root_dom[dim]) !=
consumer_indexing.indexMap().end()) {
root_ind = consumer_indexing.indexMap().at(root_dom[dim]);
} else if (root_dom[dim]->isBroadcast()) {
root_ind = GpuLower::current()->kernel()->zeroVal();
}

TORCH_INTERNAL_ASSERT(
root_ind != nullptr,
"Couldn't find root mapping for ",
consumer_tv->toString(),
" dim: ",
dim,
" id: ",
root_dom[dim]->toString());

if (consumer_tv->domain()->contiguity()[dim]) {
if (tv->domain()->contiguity()[dim]) {
// If contiguous, use the stored stride, which may be the previous
// dimension's stride * the previous dimension's size
strides[dim] = cur_contig_stride;
@@ -2041,12 +2026,18 @@ std::vector<Val*> Index::getGlobalConsumerStridedIndices(
strides[dim], getHaloExtentOfRootAxis(root_dom[dim]));
}
}
return strides;
}

auto vectorize_shift =
loops.empty() ? nullptr : loops.back()->vectorize_shift();
std::vector<Val*> Index::getRootIndices(
const TensorView* tv,
const std::vector<kir::ForLoop*>& loops,
const IndexFromIdGraph& index_from_id_graph) {
auto gpu_lower = GpuLower::current();
auto root_dom = tv->getMaybeRFactorDomain();
auto indexing = index_from_id_graph.index;

// Global striding
std::vector<Val*> strided_inds(
std::vector<Val*> root_inds(
root_dom.size(), GpuLower::current()->kernel()->zeroVal());
for (const auto i : c10::irange(root_dom.size())) {
// See a comment in indexing to root domains in getGlobalProducerIndex.
@@ -2057,35 +2048,55 @@ std::vector<Val*> Index::getGlobalConsumerStridedIndices(
}

TORCH_INTERNAL_ASSERT(
consumer_indexing.indexMap().find(root_dom[i]) !=
consumer_indexing.indexMap().end(),
indexing.indexMap().find(root_dom[i]) != indexing.indexMap().end(),
"Couldn't find root mapping for ",
consumer_tv->toString(),
tv->toString(),
" dim: ",
i,
" id: ",
root_dom[i]->toString());

auto root_ind = consumer_indexing.indexMap().at(root_dom[i]);
auto root_ind = indexing.indexMap().at(root_dom[i]);

// index hoist must be done before the adjustments for halo
root_ind = hoistConsumerIndex(
root_dom[i],
consumer_tv,
consumer_indexing,
tv,
indexing,
index_from_id_graph.resolved_loop_domains,
index_from_id_graph.initial_concrete_index_map,
loops,
root_ind);

root_ind = SimplifyingIrBuilder::addExpr(
root_ind, getGlobalConsumerOffsetWithPartialSplit(root_dom[i]));
root_inds[i] = root_ind;
}
return root_inds;
}

if (root_ind->isZeroInt()) {
std::vector<Val*> Index::getGlobalConsumerStridedIndices(
const TensorView* consumer_tv,
const std::vector<kir::ForLoop*>& loops) {
FUSER_PERF_SCOPE("GpuLower::Lower::getGlobalConsumerIndex");

auto index_from_id_graph = getTensorIndexFromIdGraph(loops, consumer_tv);
auto consumer_indexing = index_from_id_graph.index;
auto strides = getStrides(consumer_tv);
auto root_inds = getRootIndices(consumer_tv, loops, index_from_id_graph);

// Global striding
auto vectorize_shift =
loops.empty() ? nullptr : loops.back()->vectorize_shift();
std::vector<Val*> strided_inds(
root_inds.size(), GpuLower::current()->kernel()->zeroVal());
for (const auto i : c10::irange(root_inds.size())) {
if (root_inds[i]->isZeroInt()) {
continue;
} else {
auto strided_ind = SimplifyingIrBuilder::mulExpr(root_ind, strides[i]);
if (i == root_dom.size() - 1 && vectorize_shift != nullptr) {
auto strided_ind =
SimplifyingIrBuilder::mulExpr(root_inds[i], strides[i]);
if (i == strides.size() - 1 && vectorize_shift != nullptr) {
strided_inds[i] =
SimplifyingIrBuilder::addExpr(strided_ind, vectorize_shift);
} else {