Skip to content

Commit

Permalink
Nonaffine swizzle formulation ep.2: Loop swizzle variant. (#1826)
Browse files Browse the repository at this point in the history
  • Loading branch information
shmsong committed Jul 31, 2022
1 parent d863d69 commit 501f4aa
Show file tree
Hide file tree
Showing 15 changed files with 1,346 additions and 1,968 deletions.
8 changes: 3 additions & 5 deletions torch/csrc/jit/codegen/cuda/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2414,11 +2414,9 @@ class CudaKernelGenerator : private OptOutConstDispatch {

void handle(const kir::IntPair* int_pair) {
const auto def = int_pair->definition();
if (print_inline_) {
code_ << gen(def);
} else {
code_ << varName(int_pair);
}
TORCH_INTERNAL_ASSERT(
def != nullptr, "no support for un-inlined int pair yet.");
code_ << gen(def);
}

void handle(const kir::PairSelect* pair_select) {
Expand Down
95 changes: 63 additions & 32 deletions torch/csrc/jit/codegen/cuda/index_compute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -527,13 +527,43 @@ void IndexCompute::handle(Swizzle2D* swizzle_2d) {
const auto out_x_ind = out_x_it->second;
const auto out_y_ind = out_y_it->second;

// Actual swizzle operation is handled via IndexSwizzle pass
// all behavior in this pass is directly forward through the
// index and extent.
index_map_[in_x_id] = out_x_ind;
index_map_[in_y_id] = out_y_ind;
extent_map_[in_y_id] = getExtent(out_y_id);
extent_map_[in_x_id] = getExtent(out_x_id);
if (swizzle_mode_ == SwizzleMode::NoSwizzle ||
swizzle_mode_ != swizzle_2d->swizzleMode()) {
// Handle inactive swizzles by just passing through index
// and extend information.

TORCH_INTERNAL_ASSERT(
index_map_.count(in_x_id) == index_map_.count(in_y_id),
"input index should be either both defined or both undefined");
if (index_map_.count(in_x_id)) {
// Only propagate original index through if
// the input index hasn't been computed.
// TODO:
// This part should be cleaner once we remove the
// second index traversal pass.
return;
}
index_map_[in_x_id] = out_x_ind;
index_map_[in_y_id] = out_y_ind;
extent_map_[in_y_id] = getExtent(out_y_id);
extent_map_[in_x_id] = getExtent(out_x_id);
} else {
// Generate integer swizzle math if the
// swizzle is activated. See also
// [Note on swizzle mode].

auto out_pair = IrBuilder::swizzle2DIntExpr(
out_x_ind,
out_y_ind,
getExtent(out_x_id),
getExtent(out_y_id),
swizzle_2d->swizzleType());

index_map_[in_x_id] =
IrBuilder::pairSelectExpr(out_pair, kir::PairSelect::Selection::X);
index_map_[in_y_id] =
IrBuilder::pairSelectExpr(out_pair, kir::PairSelect::Selection::Y);
}
}

void IndexCompute::handle(Expr* e) {
Expand Down Expand Up @@ -616,9 +646,31 @@ IndexCompute::IndexCompute(
reference_halo_extent_map_(std::move(reference_halo_extent_map)) {
FUSER_PERF_SCOPE("GpuLower::Lower::IndexCompute::IndexCompute");
concrete_id_pass_ = true;
swizzle_mode_ = SwizzleMode::Loop;
}

void IndexCompute::run(const LoopIndexing& loop_indexing) {
// Apply loop swizzles if there are any that outputs to
// the loop domains.
// Currently only support loop swizzles that directly output
// to concrete loop domains and these are validated in
// validate swizzle pass.
// TODO:
// will gradually enable replaying and mapping of loop
// swizzles in the IR infrastructure and once that's piped
// through this part of logic will be removed.
std::unordered_set<Expr*> visited;
for (auto loop_id : loop_indexing.loopDomains()) {
auto loop_id_def = loop_id->definition();
if (loop_id_def != nullptr && loop_id_def->isA<Swizzle2D>()) {
if (visited.insert(loop_id_def).second) {
handle(loop_id_def);
}
}
}

// Run through the loop indexing expressions and generate
// the indexing integer math for the concrete ids.
for (auto expr : loop_indexing.getBackwardExprList()) {
handle(expr);
}
Expand Down Expand Up @@ -955,6 +1007,7 @@ void IndexSwizzle::run() {
UpdateLeafIndices update_leaves(td_, indexMap(), extentMap());
index_map_ = update_leaves.indexMap();
extent_map_ = update_leaves.extentMap();
IndexCompute::swizzle_mode_ = SwizzleMode::Data;
IndexCompute::run();
}
}
Expand All @@ -969,7 +1022,8 @@ void IndexSwizzle::handle(Expr* e) {
return swizzled_ids_.find(id) != swizzled_ids_.end();
}) ||
(e->isA<Swizzle2D>() &&
e->as<Swizzle2D>()->swizzleType() != Swizzle2DType::NoSwizzle);
e->as<Swizzle2D>()->swizzleType() != Swizzle2DType::NoSwizzle &&
e->as<Swizzle2D>()->swizzleMode() == SwizzleMode::Data);
if (!needs_update) {
return;
}
Expand All @@ -983,8 +1037,6 @@ void IndexSwizzle::handle(Expr* e) {
void IndexSwizzle::handle(Swizzle2D* swizzle_2d) {
auto out_x_id = swizzle_2d->outX();
auto out_y_id = swizzle_2d->outY();
auto in_x_id = swizzle_2d->inX();
auto in_y_id = swizzle_2d->inY();

auto out_x_it = index_map_.find(out_x_id);
auto out_y_it = index_map_.find(out_y_id);
Expand All @@ -998,28 +1050,7 @@ void IndexSwizzle::handle(Swizzle2D* swizzle_2d) {
out_x_it != index_map_.end() && out_y_it != index_map_.end(),
"Swizzle output indices were not propagated through");

const auto out_x_ind = out_x_it->second;
const auto out_y_ind = out_y_it->second;

// Can propagate zero only for a few
// swizzle types (TODO)

if (swizzle_2d->swizzleType() != Swizzle2DType::NoSwizzle) {
auto out_pair = IrBuilder::swizzle2DIntExpr(
out_x_ind,
out_y_ind,
getExtent(out_x_id),
getExtent(out_y_id),
swizzle_2d->swizzleType());

index_map_[in_x_id] =
IrBuilder::pairSelectExpr(out_pair, kir::PairSelect::Selection::X);
index_map_[in_y_id] =
IrBuilder::pairSelectExpr(out_pair, kir::PairSelect::Selection::Y);

swizzled_ids_.insert(in_x_id);
swizzled_ids_.insert(in_y_id);
}
IndexCompute::handle(swizzle_2d);
}

// Used for local and shared index mapping. Returns a map from loops
Expand Down
7 changes: 7 additions & 0 deletions torch/csrc/jit/codegen/cuda/index_compute.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,13 @@ class IndexCompute : public BackwardVisitor {
// map rather than the actual IDs used in the ID expressions.
bool concrete_id_pass_ = false;

// Mode of swizzle that are activated in this index compute
// instance. Will treat swizzles of different mode as no-op.
// Currently data mode swizzles are handled same as before in IndexSwizzle
// pass, while loop mode swizzles are handled early on in concrete indexing
// pass. See also [Note on swizzle mode]
SwizzleMode swizzle_mode_ = SwizzleMode::NoSwizzle;

public:
const std::unordered_map<IterDomain*, Val*>& indexMap() const {
return index_map_;
Expand Down
6 changes: 5 additions & 1 deletion torch/csrc/jit/codegen/cuda/ir_interface_nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,11 @@ class TORCH_CUDA_CU_API TensorView : public Val {

//! Swizzle the rectangular tile defined by the iterdomains corresponding
//! to the 2 given indices.
TensorView* swizzle(Swizzle2DType swizzle_type, int x, int y);
TensorView* swizzle(
Swizzle2DType swizzle_type,
int x,
int y,
SwizzleMode swizzle_mode = SwizzleMode::Data);

// WARNING: rFactor does not return this TensorView, ir returns a new
// tensorview consumed by this!
Expand Down
63 changes: 58 additions & 5 deletions torch/csrc/jit/codegen/cuda/ir_internal_nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -991,7 +991,8 @@ class TORCH_CUDA_CU_API IterDomain : public Val {
static std::pair<IterDomain*, IterDomain*> swizzle(
Swizzle2DType swizzle_type,
IterDomain* in_x,
IterDomain* in_y);
IterDomain* in_y,
SwizzleMode swizzle_mode = SwizzleMode::Data);

bool isMmaSwizzled() const {
return is_mma_swizzled_;
Expand Down Expand Up @@ -1198,7 +1199,11 @@ class TORCH_CUDA_CU_API TensorDomain : public Val {

//! Applies 2D swizzle on a rectangular tile defined by
//! a pair of iterdomains contained in this domain.
void swizzle(Swizzle2DType swizzle_type, int x, int y);
void swizzle(
Swizzle2DType swizzle_type,
int x,
int y,
SwizzleMode swizzle_mode = SwizzleMode::Data);

// Transform TensorView according to merge and split transformations
TensorDomain* view(
Expand Down Expand Up @@ -1339,7 +1344,8 @@ class TORCH_CUDA_CU_API Swizzle2D : public Expr {
IterDomain* out_y,
IterDomain* in_x,
IterDomain* in_y,
Swizzle2DType swizzle_type = Swizzle2DType::NoSwizzle);
Swizzle2DType swizzle_type = Swizzle2DType::NoSwizzle,
SwizzleMode swizzle_mode = SwizzleMode::Data);

Swizzle2D(const Swizzle2D* src, IrCloner* ir_cloner);

Expand All @@ -1359,10 +1365,14 @@ class TORCH_CUDA_CU_API Swizzle2D : public Expr {
return in_y_;
}

const auto& swizzleType() const {
auto swizzleType() const {
return swizzle_type_;
}

auto swizzleMode() const {
return swizzle_mode_;
}

bool sameAs(const Statement* other) const override;

private:
Expand All @@ -1377,7 +1387,50 @@ class TORCH_CUDA_CU_API Swizzle2D : public Expr {

// The type of predefined 1-to-1 functions
// used for swizzling math.
Swizzle2DType swizzle_type_;
Swizzle2DType swizzle_type_ = Swizzle2DType::NoSwizzle;

// Swizzle mode of this swizzle instance.
// [Note on swizzle mode]
// On the current implementations we support two modes of
// swizzle math, namely, data mode and loop mode.
// `Data` mode swizzling is a swizzle that will change the
// data layout in shared memory, likely in global memory buffers
// as well in the future. see also IndexSwizzle in index_compute.cpp.
//
// Most important use cases are transpose bank conflict removal, and mma
// swizzled shared memory layout. Example illustrated in 1D case:
//
// for (int i = 0; i<I; i++){
// # This is a `Data` mode swizzle.
// Tshared [swizzled(i)] = Tin[i];
// }
// # Now Tshared holds swizzled data, i.e. the data layout of
// Tshared does not map to Tin with affine relationships.
//
// for(int i=0;i<I;i++){
// Tout = Tshared[swizzled(i)];
// }
//
// `Loop` mode swizzling does not affect the data layout of any buffer
// but only permutes the iteration order of serial or parallel loop.
// This is useful when we want to designate non-affine mapping of thread
// to data or we want to generate non-affine loops.
// Exampe illustrated in 1D case:
// for (int i = 0; i<I; i++){
// # This is a `Loop` mode swizzle
// Tshared [swizzled(i)] = Tin[swizzled(i)];
// }
// # Now Tshared holds normal data, i.e. it still has
// the same data layout as if the swizzle wasn't there.
//
// # Consumers of Tshared does not need to know about the
// loop swizzle at previous op if not inlined.
// for(int i=0;i<I;i++){
// Tout = Tshared[i];
// }
// TODO: Loop swizzles eventually will be piped through in all mappings
// and replay of the fusion IR infrastructure.
SwizzleMode swizzle_mode_ = SwizzleMode::Data;
};

//! Integer value which has a special name
Expand Down
22 changes: 15 additions & 7 deletions torch/csrc/jit/codegen/cuda/ir_nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1293,7 +1293,8 @@ std::pair<IterDomain*, IterDomain*> IterDomain::stridedSplit(int factor) {
std::pair<IterDomain*, IterDomain*> IterDomain::swizzle(
Swizzle2DType swizzle_type,
IterDomain* in_x,
IterDomain* in_y) {
IterDomain* in_y,
SwizzleMode swizzle_mode) {
TORCH_CHECK(
!in_x->extent()->isZeroInt() && !in_y->extent()->isZeroInt(),
"Invalid swizzling of a empty dimension.");
Expand All @@ -1319,7 +1320,7 @@ std::pair<IterDomain*, IterDomain*> IterDomain::swizzle(
IterDomain* out_y = IterDomainBuilder(in_y).build();

IrBuilder::create<Swizzle2D>(
in_x->container(), out_x, out_y, in_x, in_y, swizzle_type);
in_x->container(), out_x, out_y, in_x, in_y, swizzle_type, swizzle_mode);

return std::make_pair(out_x, out_y);
}
Expand Down Expand Up @@ -1790,7 +1791,11 @@ std::vector<IterDomain*> TensorDomain::orderedAs(
return reordered_domain;
}

void TensorDomain::swizzle(Swizzle2DType swizzle_type, int x, int y) {
void TensorDomain::swizzle(
Swizzle2DType swizzle_type,
int x,
int y,
SwizzleMode swizzle_mode) {
TORCH_INTERNAL_ASSERT(nDims() > 0, "Tried to do merge on a 0-dim domain");

TORCH_CHECK(
Expand All @@ -1808,7 +1813,7 @@ void TensorDomain::swizzle(Swizzle2DType swizzle_type, int x, int y) {
IterDomain* axis_out_y = nullptr;

std::tie(axis_out_x, axis_out_y) =
IterDomain::swizzle(swizzle_type, axis_x, axis_y);
IterDomain::swizzle(swizzle_type, axis_x, axis_y, swizzle_mode);

domain_.erase(domain_.begin() + x);
domain_.insert(domain_.begin() + x, axis_out_x);
Expand Down Expand Up @@ -2039,13 +2044,15 @@ Swizzle2D::Swizzle2D(
IterDomain* out_y,
IterDomain* in_x,
IterDomain* in_y,
Swizzle2DType swizzle_type)
Swizzle2DType swizzle_type,
SwizzleMode swizzle_mode)
: Expr(passkey, ExprType::Swizzle2D),
out_x_{out_x},
out_y_{out_y},
in_x_{in_x},
in_y_{in_y},
swizzle_type_(swizzle_type) {
swizzle_type_(swizzle_type),
swizzle_mode_(swizzle_mode) {
addOutput(out_x);
addOutput(out_y);
addInput(in_x);
Expand All @@ -2071,7 +2078,8 @@ Swizzle2D::Swizzle2D(const Swizzle2D* src, IrCloner* ir_cloner)
out_y_(ir_cloner->clone(src->out_y_)),
in_x_(ir_cloner->clone(src->in_x_)),
in_y_(ir_cloner->clone(src->in_y_)),
swizzle_type_(src->swizzle_type_) {}
swizzle_type_(src->swizzle_type_),
swizzle_mode_(src->swizzle_mode_) {}

NamedScalar::NamedScalar(
IrBuilderPasskey passkey,
Expand Down
Loading

0 comments on commit 501f4aa

Please sign in to comment.