Ampere async copy ep.2: circular buffering extension to support pipelined matmul operand load (#1827)

shmsong authored Jul 31, 2022
1 parent e0ae11a commit d863d69
Showing 26 changed files with 612 additions and 46 deletions.
16 changes: 14 additions & 2 deletions torch/csrc/jit/codegen/cuda/codegen.cpp
@@ -630,7 +630,7 @@ class CudaKernelGenerator : private OptOutConstDispatch {
// Double buffered local tensors need indexed initialization,
// so will need to use `arraySet` option.
if (out_tv->getMemoryType() == MemoryType::Local &&
!out_tv->isDoubleBuffered()) {
!(out_tv->isDoubleBuffered() || out_tv->isCircularBuffered())) {
// Vectorized initialization
indent() << varName(out_tv) << ".set(" << gen(uop->in()) << ");\n";
} else {
@@ -2344,7 +2344,19 @@
}

void handle(const kir::CpAsyncWait* cpasync_wait) final {
indent() << "Ampere::cpAsyncBarrier();\n";
if (cpasync_wait->keepStages() > 0) {
// Perform partial sync, see comment on kir::CpAsyncWait.
indent() << "Ampere::cpAsyncPartialBarrier<" << cpasync_wait->keepStages()
<< ">();\n";
} else {
// Perform sync all, see comment on kir::CpAsyncWait.
indent() << "Ampere::cpAsyncBarrier();\n";
}
}

void handle(const kir::CpAsyncCommit* cpasync_commit) final {
// Commit in-flight cp.async transfers. See comment on kir::CpAsyncCommit.
indent() << "Ampere::cpAsyncCommit();\n";
}

void handle(const kir::GridSync* sync) final {
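The Ampere::cpAsyncCommit(), Ampere::cpAsyncBarrier(), and Ampere::cpAsyncPartialBarrier<N>() calls emitted above name device-side runtime helpers whose definitions are not part of this diff. A hedged sketch of what such helpers could look like, assuming only the documented cp.async PTX instructions (the actual runtime-header implementation may differ):

// Sketch only: plausible helpers behind the generated calls above.
namespace Ampere {

// Close the current group of issued cp.async copies.
__device__ inline void cpAsyncCommit() {
  asm volatile("cp.async.commit_group;\n");
}

// Full sync: block until every committed cp.async group has completed.
__device__ inline void cpAsyncBarrier() {
  asm volatile("cp.async.wait_all;\n");
}

// Partial sync: block until at most `keep_stages` committed groups are
// still in flight, leaving the newest prefetches pending.
template <int keep_stages>
__device__ inline void cpAsyncPartialBarrier() {
  asm volatile("cp.async.wait_group %0;\n" ::"n"(keep_stages));
}

} // namespace Ampere
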
15 changes: 15 additions & 0 deletions torch/csrc/jit/codegen/cuda/dispatch.cpp
@@ -163,6 +163,9 @@ void Expr::dispatch(T handler, Expr* expr) {
case ExprType::CpAsyncWait:
ptr(handler)->handle(expr->as<kir::CpAsyncWait>());
return;
case ExprType::CpAsyncCommit:
ptr(handler)->handle(expr->as<kir::CpAsyncCommit>());
return;
case ExprType::InitMagicZero:
ptr(handler)->handle(expr->as<kir::InitMagicZero>());
return;
@@ -334,6 +337,9 @@ void Expr::constDispatch(T handler, const Expr* expr) {
case ExprType::CpAsyncWait:
ptr(handler)->handle(expr->as<kir::CpAsyncWait>());
return;
case ExprType::CpAsyncCommit:
ptr(handler)->handle(expr->as<kir::CpAsyncCommit>());
return;
case ExprType::InitMagicZero:
ptr(handler)->handle(expr->as<kir::InitMagicZero>());
return;
@@ -513,6 +519,9 @@ void Expr::mutatorDispatch(T mutator, Expr* expr) {
case ExprType::CpAsyncWait:
ptr(mutator)->mutate(expr->as<kir::CpAsyncWait>());
return;
case ExprType::CpAsyncCommit:
ptr(mutator)->mutate(expr->as<kir::CpAsyncCommit>());
return;
case ExprType::InitMagicZero:
ptr(mutator)->mutate(expr->as<kir::InitMagicZero>());
return;
@@ -757,6 +766,9 @@ void OptOutConstDispatch::handle(const kir::GridSync* stmt) {
void OptOutConstDispatch::handle(const kir::CpAsyncWait* stmt) {
unhandled(stmt);
}
void OptOutConstDispatch::handle(const kir::CpAsyncCommit* stmt) {
unhandled(stmt);
}
void OptOutConstDispatch::handle(const kir::InitMagicZero* stmt) {
unhandled(stmt);
}
@@ -898,6 +910,9 @@ void OptOutDispatch::handle(kir::GridSync* stmt) {
void OptOutDispatch::handle(kir::CpAsyncWait* stmt) {
unhandled(stmt);
}
void OptOutDispatch::handle(kir::CpAsyncCommit* stmt) {
unhandled(stmt);
}
void OptOutDispatch::handle(kir::InitMagicZero* stmt) {
unhandled(stmt);
}
4 changes: 4 additions & 0 deletions torch/csrc/jit/codegen/cuda/dispatch.h
@@ -98,6 +98,7 @@ class Allocate;
class BlockSync;
class GridSync;
class CpAsyncWait;
class CpAsyncCommit;
class ForLoop;
class IfThenElse;
class GridReduction;
@@ -163,6 +164,7 @@ class TORCH_CUDA_CU_API OptOutConstDispatch : public PolymorphicBase {
virtual void handle(const kir::BlockSync*);
virtual void handle(const kir::GridSync*);
virtual void handle(const kir::CpAsyncWait*);
virtual void handle(const kir::CpAsyncCommit*);
virtual void handle(const kir::InitMagicZero*);
virtual void handle(const kir::UpdateMagicZero*);
virtual void handle(const kir::ForLoop*);
@@ -225,6 +227,7 @@ class TORCH_CUDA_CU_API OptOutDispatch : public PolymorphicBase {
virtual void handle(kir::BlockSync* stmt);
virtual void handle(kir::GridSync* stmt);
virtual void handle(kir::CpAsyncWait* stmt);
virtual void handle(kir::CpAsyncCommit* stmt);
virtual void handle(kir::InitMagicZero* stmt);
virtual void handle(kir::UpdateMagicZero* stmt);
virtual void handle(kir::ForLoop* stmt);
@@ -328,6 +331,7 @@ class TORCH_CUDA_CU_API OptOutMutator : public PolymorphicBase {
virtual void mutate(kir::BlockSync*);
virtual void mutate(kir::GridSync*);
virtual void mutate(kir::CpAsyncWait*);
virtual void mutate(kir::CpAsyncCommit*);
virtual void mutate(kir::InitMagicZero*);
virtual void mutate(kir::UpdateMagicZero*);
virtual void mutate(kir::ForLoop*);
52 changes: 40 additions & 12 deletions torch/csrc/jit/codegen/cuda/index_compute.cpp
@@ -1148,8 +1148,11 @@ indexMapFromTV(
}

if (loop == double_buffer_loop) {
auto stage_depth =
GpuLower::current()->doubleBufferInfo().getStageDepthFor(
loop->iter_domain());
idx = SimplifyingIrBuilder::addExpr(
idx, GpuLower::current()->kernel()->oneVal());
idx, SimplifyingIrBuilder::create<Int>(stage_depth - 1));
}

loop_to_ind_map[loop] = idx;
@@ -1811,14 +1814,16 @@ std::vector<Val*> Index::getNonGlobalProducerStridedIndices(
}
}

if (producer_tv->isDoubleBuffered()) {
if (producer_tv->isDoubleBuffered() || producer_tv->isCircularBuffered()) {
auto db_loop = gpu_lower->doubleBufferInfo().getDoubleBufferLoop(
producer_tv, loops, true);
if (db_loop != nullptr) {
auto stage_depth = gpu_lower->doubleBufferInfo().getStageDepthFor(
db_loop->iter_domain());
auto loop_index =
db_loop->isTrivial() ? db_loop->start() : db_loop->index();
auto db_switch_index = SimplifyingIrBuilder::modExpr(
loop_index, SimplifyingIrBuilder::create<Int>(2));
loop_index, SimplifyingIrBuilder::create<Int>(stage_depth));
auto original_alloc_size =
gpu_lower->doubleBufferInfo().getOriginalAllocSize(producer_tv);
auto db_strided_index =
@@ -2077,14 +2082,36 @@ std::vector<Val*> Index::getNonGlobalConsumerStridedIndices(
TORCH_INTERNAL_ASSERT(
strided_inds.size() == consumer_tv->getMaybeRFactorDomain().size());

if (consumer_tv->isDoubleBuffered()) {
auto db_loop = gpu_lower->doubleBufferInfo().getDoubleBufferLoop(
consumer_tv, loops, true);
if (db_loop != nullptr) {
auto db_switch_index = SimplifyingIrBuilder::subExpr(
gpu_lower->kernel()->oneVal(),
SimplifyingIrBuilder::modExpr(
db_loop->index(), SimplifyingIrBuilder::create<Int>(2)));
if (consumer_tv->isDoubleBuffered() || consumer_tv->isCircularBuffered()) {
auto db_loop =
gpu_lower->doubleBufferInfo().getDoubleBufferLoop(consumer_tv, loops);
auto stage_depth =
gpu_lower->doubleBufferInfo().getStageDepthFor(db_loop->iter_domain());
bool is_circular_buffer_loop = stage_depth > 2;
bool is_prolog =
db_loop->doubleBufferLoopStage() == DoubleBufferLoopStage::Prolog;

Val* db_switch_index = nullptr;

// In the double buffered case we don't materialize the prolog loop, as
// there will be only one iteration. In the circular buffer case we do
// materialize the prolog loop as well, covering the first N-1 iterations,
// N being the stage depth.
if (!is_prolog || is_circular_buffer_loop) {
if (is_prolog && is_circular_buffer_loop) {
// The buffer switching logic is the same as original index
// in the case of circular buffer prolog.
db_switch_index = db_loop->index();
} else {
// Switching index generated for main loop or epilog component.
db_switch_index = SimplifyingIrBuilder::modExpr(
SimplifyingIrBuilder::addExpr(
db_loop->index(),
SimplifyingIrBuilder::create<Int>(stage_depth - 1)),
SimplifyingIrBuilder::create<Int>(stage_depth));
}

// Use the generated switching buffer index to access the buffer space.
auto original_alloc_size =
gpu_lower->doubleBufferInfo().getOriginalAllocSize(consumer_tv);
auto db_strided_index =
@@ -2119,7 +2146,8 @@ std::vector<Val*> Index::getProducerStridedIndices(
TORCH_INTERNAL_ASSERT(
strided_indices.size() ==
producer->getMaybeRFactorDomain().size() +
(producer->isDoubleBuffered() ? 1 : 0));
(producer->isDoubleBuffered() || producer->isCircularBuffered() ? 1
: 0));

return strided_indices;
}
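To make the switching arithmetic above concrete, here is a small self-contained toy (an illustration, not nvFuser code) of which buffer slot the main loop reads from and which slot the in-flight prefetch writes to for a circular buffer of depth 3:

#include <cstdio>

int main() {
  const int stage_depth = 3; // e.g. tv->circularBuffer(3)
  const int num_tiles = 6;
  for (int i = 0; i < num_tiles; ++i) {
    // Producer-side index (reading the buffered operand for the compute):
    //   slot = loop_index % stage_depth
    int read_slot = i % stage_depth;
    // Consumer-side index (the async copy writing the prefetched tile; its
    // loop index is advanced by stage_depth - 1 in indexMapFromTV):
    //   slot = (loop_index + stage_depth - 1) % stage_depth
    int write_slot = (i + stage_depth - 1) % stage_depth;
    printf("iter %d: compute reads slot %d, prefetch writes slot %d\n",
           i, read_slot, write_slot);
  }
  return 0;
}

The write slot leads the read slot by stage_depth - 1 iterations, so the tile consumed at iteration i is the one prefetched stage_depth - 1 iterations earlier.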
23 changes: 23 additions & 0 deletions torch/csrc/jit/codegen/cuda/ir_interface_nodes.h
@@ -450,10 +450,26 @@ class TORCH_CUDA_CU_API TensorView : public Val {
// Apply double buffering transformation
void doubleBuffer();

// Apply circular buffering transformation
void circularBuffer(unsigned int number_of_stage);

// Returns true if this tensor is double buffered.
bool isDoubleBuffered() const {
return is_double_buffered_;
}

// Returns true if this tensor is circular buffered.
bool isCircularBuffered() const {
return is_circular_buffered_;
}

// Returns the depth of circular buffering if applicable.
unsigned int circularBufferDepth() const {
TORCH_INTERNAL_ASSERT(
is_circular_buffered_, toString(), "not circular buffered");
return circular_buffer_stage_;
}

//! Transforms the innermost iterdomains according to the given mma swizzle,
//! this should be used on the tvs that are either inputs/outputs of an
//! MmaOp, or any tv's that are involved in prolog/epilog fusions and need to
@@ -509,6 +525,13 @@ class TORCH_CUDA_CU_API TensorView : public Val {
SwizzleType swizzle_type_ = SwizzleType::NoSwizzle;
std::vector<IterDomain*> axes_to_swizzle_;
bool is_double_buffered_ = false;

//! Indicates if the tensor is circular buffered.
bool is_circular_buffered_ = false;

//! Indicates the circular buffering stage depth if applicable.
unsigned int circular_buffer_stage_ = 0;

// special handling for CPU based zero-dim tensors (i.e. CPU Tensors that only
// have one value). This is only used if on an input value, otherwise ignored.
// This is important as special handling because these "scalars" should be
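A hypothetical scheduling fragment showing how the new interface might be used; the tensor names and the surrounding schedule are illustrative, and the full fusion setup is omitted. Only circularBuffer(), isCircularBuffered(), and circularBufferDepth() come from the header above:

// Sketch: stage a matmul operand through shared memory with a 4-deep
// circular buffer instead of the default 2-deep double buffer.
TensorView* a_smem = a_global->cacheAfter(); // hypothetical staging copy
a_smem->setMemoryType(MemoryType::Shared);
// ... split and parallelize the copy loop nest as usual ...
a_smem->circularBuffer(4); // keep 4 pipeline stages in flight
TORCH_INTERNAL_ASSERT(a_smem->isCircularBuffered());
TORCH_INTERNAL_ASSERT(a_smem->circularBufferDepth() == 4);
// doubleBuffer() remains the existing depth-2 special case.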
4 changes: 4 additions & 0 deletions torch/csrc/jit/codegen/cuda/ir_iostream.cpp
@@ -599,6 +599,10 @@ void IrPrinter::handle(const kir::BlockSync* node) {
}

void IrPrinter::handle(const kir::CpAsyncWait* node) {
indent() << "CPASYNC_WAIT(" << node->keepStages() << ")\n";
}

void IrPrinter::handle(const kir::CpAsyncCommit* node) {
indent() << "CPASYNC_WAIT()\n";
}

1 change: 1 addition & 0 deletions torch/csrc/jit/codegen/cuda/ir_iostream.h
@@ -112,6 +112,7 @@ class TORCH_CUDA_CU_API IrPrinter : public OptInConstDispatch {
void handle(const kir::BlockSync*) final;
void handle(const kir::GridSync*) final;
void handle(const kir::CpAsyncWait*) final;
void handle(const kir::CpAsyncCommit*) final;
void handle(const kir::InitMagicZero*) final;
void handle(const kir::UpdateMagicZero*) final;
void handle(const kir::AllocateFusedReduction*) final;
11 changes: 9 additions & 2 deletions torch/csrc/jit/codegen/cuda/kernel_ir.cpp
@@ -93,8 +93,15 @@ GridSync::GridSync(
sync_dims_(sync_dims),
sync_buffer_(sync_buffer) {}

CpAsyncWait::CpAsyncWait(IrBuilderPasskey passkey)
: Expr(passkey, ExprType::CpAsyncWait) {
CpAsyncWait::CpAsyncWait(IrBuilderPasskey passkey, unsigned int keep_stages)
: Expr(passkey, ExprType::CpAsyncWait), keep_stages_(keep_stages) {
TORCH_INTERNAL_ASSERT(
passkey.ir_container_->isA<kir::Kernel>(),
"IR type only valid for Kernel container.");
}

CpAsyncCommit::CpAsyncCommit(IrBuilderPasskey passkey)
: Expr(passkey, ExprType::CpAsyncCommit) {
TORCH_INTERNAL_ASSERT(
passkey.ir_container_->isA<kir::Kernel>(),
"IR type only valid for Kernel container.");
23 changes: 20 additions & 3 deletions torch/csrc/jit/codegen/cuda/kernel_ir.h
@@ -55,6 +55,7 @@ class Allocate;
class BlockSync;
class GridSync;
class CpAsyncWait;
class CpAsyncCommit;
class InitMagicZero;
class UpdateMagicZero;
class ForLoop;
@@ -258,11 +259,27 @@ class TORCH_CUDA_CU_API BlockSync final : public Expr {
};

// CpAsyncWait represents wait intrinsics for cp.async
// TODO: expand to support different wait modes of the intrinsic
// as the analysis passes build out.
class TORCH_CUDA_CU_API CpAsyncWait final : public Expr {
public:
explicit CpAsyncWait(IrBuilderPasskey passkey);
explicit CpAsyncWait(IrBuilderPasskey passkey, unsigned int keep_stages = 0);

//! Returns the remaining number of stages that are not synchronized
//! after this op.
unsigned int keepStages() const {
return keep_stages_;
}

private:
//! Number of stages to leave un-synced by this op.
unsigned int keep_stages_ = 0;
};

// CpAsyncCommit represents commit intrinsics for cp.async.
// A commit intrinsic marks the boundary of a transaction group
// for the async load hardware. For example usage see [Circular buffer].
class TORCH_CUDA_CU_API CpAsyncCommit final : public Expr {
public:
explicit CpAsyncCommit(IrBuilderPasskey passkey);
};

// Synchronize all blocks in device, implies cooperative group launch is
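Taken together, CpAsyncCommit and a CpAsyncWait with keep_stages > 0 let the lowering express a multi-stage load pipeline. Below is a hedged sketch of one plausible interleaving for a circular buffer of depth kStages, written against the Ampere:: helper sketch shown earlier; the placement actually chosen by the double-buffering pass (e.g. dedicated prolog/epilog loops) may differ:

// Illustration only; num_tiles is the number of main-loop iterations.
template <int kStages>
__device__ void pipelined_loop_sketch(int num_tiles) {
  // Prolog: issue the first kStages - 1 tile loads, one commit group each.
  for (int i = 0; i < kStages - 1; ++i) {
    // ... cp.async loads of tile i into shared-memory slot i ...
    Ampere::cpAsyncCommit(); // kir::CpAsyncCommit
  }
  // Main loop: prefetch the tile kStages - 1 iterations ahead, commit the
  // group (an empty group past the end keeps the group count uniform), then
  // wait until at most kStages - 1 groups remain in flight so the tile
  // needed by this iteration has landed, and compute on slot i % kStages.
  for (int i = 0; i < num_tiles; ++i) {
    if (i + kStages - 1 < num_tiles) {
      // ... cp.async loads of tile i + kStages - 1
      //     into slot (i + kStages - 1) % kStages ...
    }
    Ampere::cpAsyncCommit();                      // kir::CpAsyncCommit
    Ampere::cpAsyncPartialBarrier<kStages - 1>(); // kir::CpAsyncWait(keep_stages)
    __syncthreads();
    // ... mma compute reading shared-memory slot i % kStages ...
  }
}
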
8 changes: 6 additions & 2 deletions torch/csrc/jit/codegen/cuda/lower_allocation.cpp
@@ -408,7 +408,7 @@

// Multiply the allocation size by the buffering depth if double or
// circular buffered. Record the original size for indexing.
if (info.buffer->isDoubleBuffered()) {
if (info.buffer->isDoubleBuffered() || info.buffer->isCircularBuffered()) {
Val* original_alloc_size = nullptr;
for (auto alloc_dim : alloc_dims) {
if (original_alloc_size == nullptr) {
@@ -420,7 +420,11 @@
}
GpuLower::current()->doubleBufferInfo().setOriginalAllocSize(
info.buffer, original_alloc_size);
alloc_dims.push_back(IrBuilder::create<Int>(2));
int double_buffer_stage = 2;
if (info.buffer->isCircularBuffered()) {
double_buffer_stage = info.buffer->circularBufferDepth();
}
alloc_dims.push_back(IrBuilder::create<Int>(double_buffer_stage));
}

// Create the allocation node
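A small worked example of the allocation change (shapes are illustrative, not taken from this diff): a shared-memory tile of 32x128 elements scheduled with circularBuffer(4) gains a leading dimension equal to the stage depth, while the recorded original size keeps per-stage indexing unchanged:

#include <cstdio>

int main() {
  const int tile_rows = 32, tile_cols = 128;
  const int stage_depth = 4; // circularBuffer(4); doubleBuffer() would use 2
  const int original_alloc_size = tile_rows * tile_cols; // recorded for indexing
  const int total_alloc_size = stage_depth * original_alloc_size;
  // Element offset of (row, col) inside buffer stage `slot`, mirroring
  // db_switch_index * original_alloc_size + the intra-stage index.
  const int slot = 3, row = 5, col = 7;
  const int offset = slot * original_alloc_size + row * tile_cols + col;
  printf("per-stage = %d elems, total = %d elems, sample offset = %d\n",
         original_alloc_size, total_alloc_size, offset);
  return 0;
}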