Simplify matmul scheduling with the new transform propagator. (#1817)
Co-authored-by: Gao, Xiang <qasdfgtyuiop@gmail.com>
shmsong and zasdfgbnm committed Jul 23, 2022
1 parent bbc1fb9 commit 1cd9451
Showing 13 changed files with 1,136 additions and 1,135 deletions.
1 change: 1 addition & 0 deletions build_variables.bzl
@@ -718,6 +718,7 @@ libtorch_cuda_core_sources = [
"torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp",
"torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp",
"torch/csrc/jit/codegen/cuda/scheduler/reduction.cpp",
"torch/csrc/jit/codegen/cuda/scheduler/matmul.cpp",
"torch/csrc/jit/codegen/cuda/scheduler/reduction_utils.cpp",
"torch/csrc/jit/codegen/cuda/scheduler/registry.cpp",
"torch/csrc/jit/codegen/cuda/scheduler/utils.cpp",
10 changes: 5 additions & 5 deletions torch/csrc/jit/codegen/cuda/codegen.cpp
@@ -971,13 +971,13 @@ class CudaKernelGenerator : private OptOutConstDispatch {
}
}

std::string genArchString(MmaOptions options) {
std::string genArchString(MmaOptions::MacroType macro) {
std::stringstream ss;
if (isVolta(options.macro)) {
if (isVolta(macro)) {
ss << "Volta";
} else if (isTuring(options.macro)) {
} else if (isTuring(macro)) {
ss << "Turing";
} else if (isAmpere(options.macro)) {
} else if (isAmpere(macro)) {
ss << "Ampere";
} else {
TORCH_INTERNAL_ASSERT(false, "mma macro unknown arch");
@@ -988,7 +988,7 @@ class CudaKernelGenerator : private OptOutConstDispatch {
std::string genMmaOp(const MmaOp* mma, bool init = false) {
std::stringstream ss;
auto options = mma->options();
ss << genArchString(options) << "::";
ss << genArchString(options.macro) << "::";
if (init) {
ss << "init";
}
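The change above narrows genArchString from taking a full MmaOptions to taking only the MmaOptions::MacroType, since the architecture string depends on nothing else. A minimal standalone sketch of the same pattern, using simplified stand-in names rather than the real nvfuser types:

#include <cassert>
#include <string>

// Hypothetical stand-in for MmaOptions::MacroType.
enum class MacroTypeSketch { NoMMA, Volta_16_16_4, Turing_16_8_16, Ampere_16_8_16 };

// The arch string depends only on the macro, so the helper takes the enum
// directly instead of the whole options struct.
std::string genArchStringSketch(MacroTypeSketch macro) {
  switch (macro) {
    case MacroTypeSketch::Volta_16_16_4:
      return "Volta";
    case MacroTypeSketch::Turing_16_8_16:
      return "Turing";
    case MacroTypeSketch::Ampere_16_8_16:
      return "Ampere";
    default:
      assert(false && "mma macro unknown arch");
      return "";
  }
}

int main() {
  // A caller holding a full options struct would just pass options.macro.
  assert(genArchStringSketch(MacroTypeSketch::Ampere_16_8_16) == "Ampere");
  return 0;
}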
30 changes: 27 additions & 3 deletions torch/csrc/jit/codegen/cuda/ir_internal_nodes.h
@@ -338,6 +338,22 @@ class TORCH_CUDA_CU_API WelfordOp : public Expr {
//! Fused Matmul operation
class TORCH_CUDA_CU_API MmaOp : public Expr {
public:
// This is a temporary data structure for the
// scheduling-specific parameters that we still need
// to store on an mma node. Eventually only the mma
// macro type will stay on the IR node after
// additional cleanups.
struct OptionsInMma {
MmaOptions::MacroType macro = MmaOptions::MacroType::NoMMA;
MmaOptions::MmaInputLayout operand_layout = MmaOptions::MmaInputLayout::TT;
int accumulator_stride = 0;

bool operator==(const OptionsInMma& other) const {
return macro == other.macro && operand_layout == other.operand_layout &&
accumulator_stride == other.accumulator_stride;
}
};

MmaOp(IrBuilderPasskey, Val* out, Val* in_a, Val* in_b, Val* init);

MmaOp(
@@ -346,7 +362,7 @@ class TORCH_CUDA_CU_API MmaOp : public Expr {
Val* in_a,
Val* in_b,
Val* init,
MmaOptions options);
OptionsInMma options);

MmaOp(const MmaOp* src, IrCloner* ir_cloner);

@@ -379,15 +395,23 @@ class TORCH_CUDA_CU_API MmaOp : public Expr {
}

void configureOptions(MmaOptions options) {
options_ = options;
options_ = OptionsInMma();
TORCH_INTERNAL_ASSERT(
options.macro != MmaOptions::MacroType::NoMMA,
"Un-configured mma type from options.");
TORCH_INTERNAL_ASSERT(
options.accumulator_stride > 0, "Un-configured accumulator stride.");
options_->accumulator_stride = options.accumulator_stride;
options_->macro = options.macro;
options_->operand_layout = options.operand_layout;
}

private:
Val* const out_ = nullptr;
Val* const in_a_ = nullptr;
Val* const in_b_ = nullptr;
Val* const init_ = nullptr;
c10::optional<MmaOptions> options_ = c10::nullopt;
c10::optional<OptionsInMma> options_ = c10::nullopt;
};

class TORCH_CUDA_CU_API TransposeOp : public Expr {
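The new OptionsInMma struct keeps only the fields that codegen still needs on the IR node, while the full MmaOptions remains a scheduler-side object; configureOptions validates the scheduler-side options and copies just those fields across. A minimal standalone sketch of that split, with simplified stand-in types rather than the real nvfuser classes:

#include <cassert>
#include <optional>

enum class MacroType { NoMMA, Ampere_16_8_16 };
enum class MmaInputLayout { TT, TN, NT };

// What the scheduler builds up (the real MmaOptions also carries the
// accumulator TensorView and swizzle configuration).
struct MmaOptionsSketch {
  MacroType macro = MacroType::NoMMA;
  MmaInputLayout operand_layout = MmaInputLayout::TT;
  int accumulator_stride = 0;
};

// What actually stays on the MmaOp IR node: only the fields codegen needs.
struct OptionsInMmaSketch {
  MacroType macro = MacroType::NoMMA;
  MmaInputLayout operand_layout = MmaInputLayout::TT;
  int accumulator_stride = 0;
};

struct MmaOpSketch {
  std::optional<OptionsInMmaSketch> options;

  // Mirrors MmaOp::configureOptions: validate, then copy only the
  // persisted fields out of the scheduler-facing options.
  void configureOptions(const MmaOptionsSketch& o) {
    assert(o.macro != MacroType::NoMMA && "Un-configured mma type");
    assert(o.accumulator_stride > 0 && "Un-configured accumulator stride");
    options = OptionsInMmaSketch{o.macro, o.operand_layout, o.accumulator_stride};
  }
};

int main() {
  MmaOptionsSketch scheduler_side;
  scheduler_side.macro = MacroType::Ampere_16_8_16;
  scheduler_side.accumulator_stride = 8;  // made-up value for illustration

  MmaOpSketch node;
  node.configureOptions(scheduler_side);
  assert(node.options.has_value());
  return 0;
}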
2 changes: 1 addition & 1 deletion torch/csrc/jit/codegen/cuda/ir_nodes.cpp
@@ -630,7 +630,7 @@ MmaOp::MmaOp(
Val* in_a,
Val* in_b,
Val* init,
MmaOptions options)
OptionsInMma options)
: MmaOp(passkey, out, in_a, in_b, init) {
options_ = options;
}
19 changes: 15 additions & 4 deletions torch/csrc/jit/codegen/cuda/mma_type.cpp
@@ -7,6 +7,16 @@ namespace jit {
namespace fuser {
namespace cuda {

MmaOp* MmaOptions::mmaOp() const {
TORCH_INTERNAL_ASSERT(
accumulator_tv != nullptr && accumulator_tv->definition() != nullptr,
"Invalid accumulator_tv.");
auto mma_op = dynamic_cast<MmaOp*>(accumulator_tv->definition());
TORCH_INTERNAL_ASSERT(
mma_op != nullptr, "accumulator tv not an output of mma op");
return mma_op;
}

MmaBuilder::MmaBuilder(
MmaOptions::MacroType macro,
MatMulTileOptions gemm_tile) {
@@ -41,7 +51,7 @@ MmaBuilder& MmaBuilder::operand(MmaOptions::Operand a_or_b) {
// TODO: validate op config
MmaOptions MmaBuilder::build() const {
TORCH_CHECK(
option_.mma_op != nullptr,
option_.accumulator_tv != nullptr,
"Please configure accumulator tv before using swizzle options.")
return option_;
}
@@ -60,9 +70,10 @@ void MmaBuilder::accumulatorTv(TensorView* tv) {
TORCH_CHECK(
tv->getMemoryType() == MemoryType::Local, "Mma only outputs to register");
TORCH_CHECK(tv->definition(), "Input cannot be accumulator tv");
auto mma = dynamic_cast<MmaOp*>(tv->definition());
TORCH_CHECK(mma, "Requires mma op output for reduction tv");
option_.mma_op = mma;
TORCH_CHECK(
tv->definition()->isA<MmaOp>(),
"Requires mma op output for reduction tv");
option_.accumulator_tv = tv;
}

namespace {
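The builder now records the accumulator TensorView instead of a raw MmaOp*, and MmaOptions::mmaOp() re-derives the op from the tensor's definition on demand; this stays valid even when mutation passes replace the expression. A minimal standalone sketch of why re-deriving is more robust than caching a pointer, with simplified stand-in types rather than the real nvfuser classes:

#include <cassert>

struct ExprSketch {
  virtual ~ExprSketch() = default;
};

struct TensorViewSketch {
  ExprSketch* definition_ = nullptr;  // the expression producing this tensor
  ExprSketch* definition() const { return definition_; }
};

struct MmaOpSketch : ExprSketch {};

// Mirrors MmaOptions::mmaOp(): resolve the op from the accumulator tensor
// on demand instead of holding a pointer that can go stale.
MmaOpSketch* mmaOpOf(const TensorViewSketch* accumulator_tv) {
  assert(accumulator_tv != nullptr && accumulator_tv->definition() != nullptr);
  auto* mma = dynamic_cast<MmaOpSketch*>(accumulator_tv->definition());
  assert(mma != nullptr && "accumulator tv is not an output of an mma op");
  return mma;
}

int main() {
  TensorViewSketch acc;
  MmaOpSketch first_op;
  acc.definition_ = &first_op;
  assert(mmaOpOf(&acc) == &first_op);

  // A "mutation pass" replaces the producing expression; a cached MmaOp*
  // would now dangle, but re-deriving from the tensor still works.
  MmaOpSketch rebuilt_op;
  acc.definition_ = &rebuilt_op;
  assert(mmaOpOf(&acc) == &rebuilt_op);
  return 0;
}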
18 changes: 16 additions & 2 deletions torch/csrc/jit/codegen/cuda/mma_type.h
@@ -19,6 +19,10 @@ struct GemmTile {
GemmTile operator/(const GemmTile& other) {
return GemmTile(m / other.m, n / other.n, k / other.k);
}

std::vector<int> toVector() {
return {m, n, k};
}
};

//! Utility data structure for recording gemm tiles
@@ -95,8 +99,18 @@ struct MmaOptions {
accumulator_stride == other.accumulator_stride;
}

// To be inferred by mma builder interface.
MmaOp* mma_op = nullptr;
// The accumulator TensorView (in register memory) supplied by the
// scheduler interface. Each mma builder is responsible for the
// parameters of one mma op, so the options struct needs a pointer to
// keep track of which mma op it is describing.
// Tracking the mma expression directly would not be stable, as the
// expression can get deleted by mutation passes.
TensorView* accumulator_tv = nullptr;

//! Returns the mma op that this options parameter list
//! is describing. See comment on accumulator_tv.
MmaOp* mmaOp() const;
};

//! User interface for configuring the mma and mma related
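The small GemmTile::toVector() helper added above exposes the tile as {m, n, k}, which lets tiling code walk the three dimensions uniformly instead of naming each member. A minimal sketch of how such a helper can be used, with a simplified stand-in struct and made-up tile sizes rather than the real nvfuser scheduler code:

#include <iostream>
#include <vector>

struct GemmTileSketch {
  int m, n, k;
  std::vector<int> toVector() const { return {m, n, k}; }
};

int main() {
  GemmTileSketch cta_tile{128, 128, 32};  // hypothetical CTA tile
  GemmTileSketch warp_tile{64, 64, 32};   // hypothetical warp tile

  // Per-dimension ratio between the CTA tile and the warp tile, computed
  // uniformly by walking the vectors instead of handling m/n/k separately.
  auto cta = cta_tile.toVector();
  auto warp = warp_tile.toVector();
  for (size_t i = 0; i < cta.size(); ++i) {
    std::cout << "dim " << i << " ratio: " << cta[i] / warp[i] << "\n";
  }
  return 0;
}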