Swizzle op formulation for non-affine swizzles (#1441)
shmsong authored Jul 8, 2022
1 parent 3ed8330 commit acd5ed4
Showing 41 changed files with 1,789 additions and 19 deletions.
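
For context, "non-affine" in the commit title means the swizzled coordinates are not a linear-plus-offset function of the input coordinates, so the mapping cannot be expressed with the existing affine split/merge iterdomain transforms and needs a dedicated op. A minimal standalone sketch of one such mapping, an XOR swizzle on a 4x4 tile (illustrative only, not code from this commit):

#include <cstdio>

int main() {
  // XOR swizzle on a 4x4 tile: (x, y) -> (x, y ^ x).
  // The result stays in range because y ^ x < 4 for 0 <= x, y < 4,
  // but the mapping is not an affine function of (x, y).
  const int size = 4;
  for (int x = 0; x < size; ++x) {
    for (int y = 0; y < size; ++y) {
      std::printf("(%d,%d) -> (%d,%d)\n", x, y, x, y ^ x);
    }
  }
  return 0;
}
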
1 change: 1 addition & 0 deletions build_variables.bzl
@@ -32,6 +32,7 @@ libtorch_nvfuser_runtime_sources = [
"torch/csrc/jit/codegen/cuda/runtime/helpers.cu",
"torch/csrc/jit/codegen/cuda/runtime/index_utils.cu",
"torch/csrc/jit/codegen/cuda/runtime/tensorcore.cu",
"torch/csrc/jit/codegen/cuda/runtime/swizzle.cu",
"torch/csrc/jit/codegen/cuda/runtime/memory.cu",
"torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu",
"torch/csrc/jit/codegen/cuda/runtime/tensor.cu",
48 changes: 48 additions & 0 deletions torch/csrc/jit/codegen/cuda/codegen.cpp
@@ -377,6 +377,8 @@ class CudaKernelGenerator : private OptOutConstDispatch {
std::stringstream name;
if (val->isA<TensorView>()) {
name << "T";
} else if (val->isA<kir::IntPair>()) {
name << "ip";
} else {
name << typePrefix(val->dtype());
}
@@ -2379,6 +2381,52 @@ class CudaKernelGenerator : private OptOutConstDispatch {
indent() << "NVFUSER_UPDATE_MAGIC_ZERO\n";
}

void handle(const kir::Swizzle2DInt* swizzle_2d) {
TORCH_INTERNAL_ASSERT(print_inline_);
TORCH_INTERNAL_ASSERT(
swizzle_2d->swizzleType() != Swizzle2DType::NoSwizzle,
"Swizzle type undefined.");
if (print_inline_) {
code_ << swizzle_2d->swizzleType() << "({" << gen(swizzle_2d->inX())
<< "," << gen(swizzle_2d->inY()) << "} , "
<< "{" << gen(swizzle_2d->extentX()) << ","
<< gen(swizzle_2d->extentY()) << "})";
}
}

void handle(const kir::IntPair* int_pair) {
const auto def = int_pair->definition();
if (print_inline_) {
code_ << gen(def);
} else {
code_ << varName(int_pair);
}
}

void handle(const kir::PairSelect* pair_select) {
if (print_inline_) {
code_ << gen(pair_select->in());
} else {
indent() << gen(pair_select->out()) << " = " << gen(pair_select->in());
}

switch (pair_select->selection()) {
case kir::PairSelect::Selection::X:
code_ << ".x";
break;
case kir::PairSelect::Selection::Y:
code_ << ".y";
break;
default:
        TORCH_INTERNAL_ASSERT(false, "unknown select");
break;
}

if (!print_inline_) {
code_ << ";\n";
}
}

private:
std::stringstream code_;
const kir::Kernel* kernel_;
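
Going by the Swizzle2DInt and PairSelect handlers above, the generated kernel code inlines an expression of the form <SwizzleTypeName>({in_x, in_y}, {extent_x, extent_y}), and PairSelect then appends .x or .y to pick one component of the resulting pair. A rough host-side sketch of a helper matching that calling convention (the name ZShape, the struct layout, and the zig-zag formula are assumptions for illustration, not taken from runtime/swizzle.cu):

#include <cstdio>

// Hypothetical pair type standing in for the device-side equivalent of
// kir::IntPair; in the real runtime this would be a __device__ helper.
struct IntPair {
  int x;
  int y;
};

IntPair ZShape(IntPair in, IntPair extent) {
  // One possible non-affine map: reverse the inner coordinate on odd rows.
  return {in.x, (in.x % 2 == 0) ? in.y : (extent.y - 1 - in.y)};
}

int main() {
  // Roughly what an emitted expression like "ZShape({i, j}, {4, 8}).y"
  // evaluates to for i = 1, j = 2:
  std::printf("%d\n", ZShape({1, 2}, {4, 8}).y);  // prints 5
  return 0;
}
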
34 changes: 34 additions & 0 deletions torch/csrc/jit/codegen/cuda/compute_at_map.cpp
@@ -33,6 +33,28 @@ IterDomainGraph::IterDomainGraph(Fusion* fusion) {
build(fusion);
}

//! Map the corresponding inputs and outputs of a swizzle op together
//! in the given disjoint set if the given id is an output
//! of a swizzle operator.
//!
//! The current usage of the swizzle operator is local to each tensor,
//! so it should not affect exact or permissive mappings
//! between iterdomains on different tensor domains.
//! TODO:
//! Exact-mapping-based index hoisting of swizzled iterdomains
//! is currently disabled and will be re-enabled in the next
//! few build-out steps.
void mapMaybeSwizzleOp(
DisjointSets<IterDomain*>& disjoint_sets,
IterDomain* id) {
if (auto swizzle_2d = dynamic_cast<Swizzle2D*>(id->definition())) {
// Map each input to its corresponding output on the given
// disjoint set.
disjoint_sets.mapEntries(swizzle_2d->inX(), swizzle_2d->outX());
disjoint_sets.mapEntries(swizzle_2d->inY(), swizzle_2d->outY());
}
}

void IterDomainGraph::build(Fusion* fusion) {
// Initialize a node for every iteration domain
for (auto tv : ir_utils::allTvs(fusion)) {
@@ -178,6 +200,12 @@ void IterDomainGraph::build(Fusion* fusion) {
exact_nodes_.mapEntries(c_id, p_id);
consumers_.at(p_id).pushBack(c_id);
producers_.at(c_id).pushBack(p_id);

// Also add the swizzle inputs to the same
// disjoint set if either c_id or p_id
// is a swizzle output.
mapMaybeSwizzleOp(exact_nodes_, p_id);
mapMaybeSwizzleOp(exact_nodes_, c_id);
}

for (auto entry : permissive_c2p_map) {
@@ -189,6 +217,12 @@ void IterDomainGraph::build(Fusion* fusion) {
permissive_nodes_.mapEntries(c_id, p_id);
consumers_.at(p_id).pushBack(c_id);
producers_.at(c_id).pushBack(p_id);

// Also add the swizzle inputs to the same
// disjoint set if either c_id or p_id
// is a swizzle output.
mapMaybeSwizzleOp(permissive_nodes_, p_id);
mapMaybeSwizzleOp(permissive_nodes_, c_id);
}

// Make sure we always get root mapping for the permissive map. Because
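
The mapMaybeSwizzleOp helper added above relies on DisjointSets::mapEntries putting the swizzle's input and output iterdomains into the same equivalence class, so analyses that look up a swizzle output also see its input. DisjointSets is nvfuser's own container; the following standalone union-find sketch (with strings standing in for IterDomain pointers) only mimics the assumed mapEntries behavior:

#include <cassert>
#include <string>
#include <unordered_map>

// Minimal union-find keyed by strings, standing in for IterDomain*.
struct TinyDisjointSets {
  std::unordered_map<std::string, std::string> parent;

  std::string find(const std::string& a) {
    if (!parent.count(a)) parent[a] = a;
    if (parent[a] != a) parent[a] = find(parent[a]);
    return parent[a];
  }

  // Analogous to DisjointSets::mapEntries: merge the classes of a and b.
  void mapEntries(const std::string& a, const std::string& b) {
    parent[find(a)] = find(b);
  }
};

int main() {
  TinyDisjointSets sets;
  // Producer/consumer root mapping, as in the exact/permissive maps.
  sets.mapEntries("p_id", "c_id");
  // Map swizzle outputs back to their inputs (mapMaybeSwizzleOp).
  sets.mapEntries("swizzle_inX", "swizzle_outX");
  sets.mapEntries("swizzle_inY", "swizzle_outY");
  // Input and output of the swizzle now share one equivalence class.
  assert(sets.find("swizzle_inX") == sets.find("swizzle_outX"));
  return 0;
}
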
7 changes: 7 additions & 0 deletions torch/csrc/jit/codegen/cuda/contiguity.h
@@ -89,6 +89,13 @@ class ContigIDs : public OptInDispatch {

void handle(Merge* merge) override;

// TODO:
// Currently not propagating any contiguity information,
// as contiguity is generally not preserved after swizzles.
// In follow-ups we could gradually add back a few special
// cases, depending on the specific swizzle type and axes.
void handle(Swizzle2D* swizzle) override {}

IterDomain* getCAIndexConcreteId(IterDomain* id) const;

//! True if an ID is indexable.
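
The TODO above notes that contiguity is generally not preserved after swizzles. A quick way to see why: after a swizzle, logically adjacent elements no longer land at adjacent linear offsets, so the unit-stride pattern a contiguity analysis depends on is broken. A small illustration using an assumed zig-zag style swizzle (reverse the column index on odd rows):

#include <cstdio>

int main() {
  const int rows = 2, cols = 4;
  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < cols; ++c) {
      // Row-major linear offset of logical (r, c) after the swizzle.
      int swizzled_c = (r % 2 == 0) ? c : (cols - 1 - c);
      int offset = r * cols + swizzled_c;
      std::printf("logical (%d,%d) -> offset %d\n", r, c, offset);
    }
  }
  // Row 1 prints offsets 7, 6, 5, 4: logically adjacent elements are no
  // longer at adjacent offsets, so unit-stride contiguity cannot be assumed.
  return 0;
}
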
60 changes: 60 additions & 0 deletions torch/csrc/jit/codegen/cuda/dispatch.cpp
@@ -83,6 +83,9 @@ void Val::dispatch(T handler, Val* val) {
case ValType::TensorIndex:
ptr(handler)->handle(val->as<kir::TensorIndex>());
return;
case ValType::IntPair:
ptr(handler)->handle(val->as<kir::IntPair>());
return;
default:
break;
}
@@ -126,6 +129,9 @@ void Expr::dispatch(T handler, Expr* expr) {
case ExprType::Merge:
ptr(handler)->handle(expr->as<Merge>());
return;
case ExprType::Swizzle2D:
ptr(handler)->handle(expr->as<Swizzle2D>());
return;
case ExprType::TransposeOp:
ptr(handler)->handle(expr->as<TransposeOp>());
return;
@@ -184,6 +190,12 @@ void Expr::dispatch(T handler, Expr* expr) {
case ExprType::AllocateFusedReduction:
ptr(handler)->handle(expr->as<kir::AllocateFusedReduction>());
return;
case ExprType::Swizzle2DInt:
ptr(handler)->handle(expr->as<kir::Swizzle2DInt>());
return;
case ExprType::PairSelect:
ptr(handler)->handle(expr->as<kir::PairSelect>());
return;
default:
TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!");
}
@@ -242,6 +254,9 @@ void Val::constDispatch(T handler, const Val* val) {
case ValType::TensorIndex:
ptr(handler)->handle(val->as<kir::TensorIndex>());
return;
case ValType::IntPair:
ptr(handler)->handle(val->as<kir::IntPair>());
return;
default:
break;
}
@@ -285,6 +300,9 @@ void Expr::constDispatch(T handler, const Expr* expr) {
case ExprType::Merge:
ptr(handler)->handle(expr->as<Merge>());
return;
case ExprType::Swizzle2D:
ptr(handler)->handle(expr->as<Swizzle2D>());
return;
case ExprType::TransposeOp:
ptr(handler)->handle(expr->as<TransposeOp>());
return;
@@ -343,6 +361,12 @@ void Expr::constDispatch(T handler, const Expr* expr) {
case ExprType::AllocateFusedReduction:
ptr(handler)->handle(expr->as<kir::AllocateFusedReduction>());
return;
case ExprType::Swizzle2DInt:
ptr(handler)->handle(expr->as<kir::Swizzle2DInt>());
return;
case ExprType::PairSelect:
ptr(handler)->handle(expr->as<kir::PairSelect>());
return;
default:
TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!");
}
@@ -409,6 +433,9 @@ void Val::mutatorDispatch(T mutator, Val* val) {
case ValType::TensorIndex:
ptr(mutator)->mutate(val->as<kir::TensorIndex>());
return;
case ValType::IntPair:
ptr(mutator)->mutate(val->as<kir::IntPair>());
return;
default:
break;
}
@@ -452,6 +479,9 @@ void Expr::mutatorDispatch(T mutator, Expr* expr) {
case ExprType::Merge:
ptr(mutator)->mutate(expr->as<Merge>());
return;
case ExprType::Swizzle2D:
ptr(mutator)->mutate(expr->as<Swizzle2D>());
return;
case ExprType::TransposeOp:
ptr(mutator)->mutate(expr->as<TransposeOp>());
return;
@@ -510,6 +540,12 @@ void Expr::mutatorDispatch(T mutator, Expr* expr) {
case ExprType::AllocateFusedReduction:
ptr(mutator)->mutate(expr->as<kir::AllocateFusedReduction>());
return;
case ExprType::Swizzle2DInt:
ptr(mutator)->mutate(expr->as<kir::Swizzle2DInt>());
return;
case ExprType::PairSelect:
ptr(mutator)->mutate(expr->as<kir::PairSelect>());
return;
default:
TORCH_INTERNAL_ASSERT(false, "Unknown exprtype in dispatch!");
}
@@ -648,6 +684,9 @@ void OptOutConstDispatch::handle(const kir::Predicate* stmt) {
void OptOutConstDispatch::handle(const kir::TensorIndex* stmt) {
unhandled(stmt);
}
void OptOutConstDispatch::handle(const kir::IntPair* stmt) {
unhandled(stmt);
}

// Exprs
void OptOutConstDispatch::handle(const UnaryOp* stmt) {
@@ -684,6 +723,9 @@ void OptOutConstDispatch::handle(const Split* stmt) {
void OptOutConstDispatch::handle(const Merge* stmt) {
unhandled(stmt);
}
void OptOutConstDispatch::handle(const Swizzle2D* stmt) {
unhandled(stmt);
}
void OptOutConstDispatch::handle(const TransposeOp* stmt) {
unhandled(stmt);
}
@@ -742,6 +784,12 @@ void OptOutConstDispatch::handle(const kir::GridWelford* stmt) {
void OptOutConstDispatch::handle(const kir::AllocateFusedReduction* stmt) {
unhandled(stmt);
}
void OptOutConstDispatch::handle(const kir::Swizzle2DInt* stmt) {
unhandled(stmt);
}
void OptOutConstDispatch::handle(const kir::PairSelect* stmt) {
unhandled(stmt);
}

void OptOutDispatch::unhandled(Statement*) {}

@@ -777,6 +825,9 @@ void OptOutDispatch::handle(kir::Predicate* stmt) {
void OptOutDispatch::handle(kir::TensorIndex* stmt) {
unhandled(stmt);
}
void OptOutDispatch::handle(kir::IntPair* stmt) {
unhandled(stmt);
}

// Exprs
void OptOutDispatch::handle(UnaryOp* stmt) {
@@ -813,6 +864,9 @@ void OptOutDispatch::handle(Split* stmt) {
void OptOutDispatch::handle(Merge* stmt) {
unhandled(stmt);
}
void OptOutDispatch::handle(Swizzle2D* stmt) {
unhandled(stmt);
}
void OptOutDispatch::handle(TransposeOp* stmt) {
unhandled(stmt);
}
@@ -871,6 +925,12 @@ void OptOutDispatch::handle(kir::GridWelford* stmt) {
void OptOutDispatch::handle(kir::AllocateFusedReduction* stmt) {
unhandled(stmt);
}
void OptOutDispatch::handle(kir::Swizzle2DInt* stmt) {
unhandled(stmt);
}
void OptOutDispatch::handle(kir::PairSelect* stmt) {
unhandled(stmt);
}

} // namespace cuda
} // namespace fuser