Skip to content

Commit

Permalink
Extend mma dimension and layout checking to support strided batched m…
Browse files Browse the repository at this point in the history
…atmul and tensor contractions (csarofeen#1761)

Co-authored-by: Christian Sarofeen <csarofeen@nvidia.com>
  • Loading branch information
shmsong and csarofeen committed Jun 27, 2022
1 parent a054b3e commit ecc7a87
Show file tree
Hide file tree
Showing 4 changed files with 406 additions and 136 deletions.
12 changes: 12 additions & 0 deletions torch/csrc/jit/codegen/cuda/mma_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ MmaBuilder& MmaBuilder::operand(MmaOptions::Operand a_or_b) {

// TODO: validate op config
MmaOptions MmaBuilder::build() const {
TORCH_CHECK(
option_.mma_op != nullptr,
"Please configure accumulator tv before using swizzle options.")
return option_;
}

Expand All @@ -53,6 +56,15 @@ void MmaBuilder::configureMma(TensorView* mma_output) const {
mma->configureOptions(option_);
}

void MmaBuilder::accumulatorTv(TensorView* tv) {
TORCH_CHECK(
tv->getMemoryType() == MemoryType::Local, "Mma only outputs to register");
TORCH_CHECK(tv->definition(), "Input cannot be accumulator tv");
auto mma = dynamic_cast<MmaOp*>(tv->definition());
TORCH_CHECK(mma, "Requires mma op output for reduction tv");
option_.mma_op = mma;
}

namespace {

// Utility to get ldmatrix direction a mma layout and operand
Expand Down
7 changes: 7 additions & 0 deletions torch/csrc/jit/codegen/cuda/mma_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ struct MmaOptions {
operand == other.operand &&
accumulator_stride == other.accumulator_stride;
}

// To be inferred by mma builder interface.
MmaOp* mma_op = nullptr;
};

//! User interface for configuring the mma and mma related
Expand Down Expand Up @@ -127,6 +130,10 @@ class TORCH_CUDA_CU_API MmaBuilder {
//! specified mma option.
LoadStoreOpType ldMatrix() const;

//! Store the accumulator tv register reference in mma builder
//! to avoid automatic matching of which mma ops.
void accumulatorTv(TensorView* tv);

//! Fill in mma options in scheduling time.
//! Each mma op in Fusion IR must be configured once before lowering.
//! Mma options are configuration parameters used in lowering to mma
Expand Down
Loading

0 comments on commit ecc7a87

Please sign in to comment.