Add matmul benchmark #2007

Merged · 1 commit · Sep 29, 2022

1 change: 1 addition & 0 deletions benchmarks/cpp/nvfuser/CMakeLists.txt
@@ -20,6 +20,7 @@ if(USE_CUDA)
softmax_backward.cpp
scale_bias_relu.cpp
transpose.cpp
matmul.cpp
timm.cpp
utils.cpp
main.cpp)
357 changes: 357 additions & 0 deletions benchmarks/cpp/nvfuser/matmul.cpp
@@ -0,0 +1,357 @@
#include <torch/csrc/jit/codegen/cuda/arith.h>
#include <torch/csrc/jit/codegen/cuda/executor.h>
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
#include <torch/csrc/jit/codegen/cuda/ops/all_ops.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/matmul.h>

#include <benchmark/benchmark.h>

#include <cuda_runtime.h>

#include <benchmarks/cpp/nvfuser/utils.h>

using namespace torch::jit::fuser::cuda;

bool cudaArchGuardShouldSkip(int required_major, int required_minor) {
int capability_major = at::cuda::getCurrentDeviceProperties()->major;
int capability_minor = at::cuda::getCurrentDeviceProperties()->minor;

if (capability_major < required_major ||
(capability_major == required_major &&
capability_minor < required_minor)) {
return true;
}
return false;
}

bool hasRequiredSmemSize(size_t required_size) {
// Only checking device 0
return at::cuda::getDeviceProperties(0)->sharedMemPerBlockOptin >=
required_size;
}

#define NVFUSER_BENCHMARK_ARCH_SMEM_GUARD( \
REQUIRED_MAJOR, REQUIRED_MINOR, SMEM_SIZE, STATE) \
if (cudaArchGuardShouldSkip(REQUIRED_MAJOR, REQUIRED_MINOR) || \
!hasRequiredSmemSize(SMEM_SIZE)) { \
STATE.SkipWithError("Unsupported arch or not enough smem!"); \
return; \
}

// Utility to track the supported matmul operand layouts.
using MatmulLayout = MmaOptions::MmaInputLayout;

static constexpr std::array<MatmulLayout, 3> kAllSupportedLayout = {
MatmulLayout::TT,
MatmulLayout::NT,
MatmulLayout::TN};

// Generic interface to get matmul op with the given layout.
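// Layout conventions for the raw 2D operands (see fp16MatmulAtInput below):
//   TT: a is [M, K], b is [K, N]
//   TN: a is [M, K], b is [N, K]
//   NT: a is [K, M], b is [K, N]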
TensorView* matmul(TensorView* a, TensorView* b, MatmulLayout layout) {
TORCH_CHECK(
a->nDims() == 2 && b->nDims() == 2, "only pure matmuls for these tests");
TensorView *tv2 = nullptr, *tv0b = nullptr, *tv1b = nullptr;
switch (layout) {
case MatmulLayout::TT:
tv0b = broadcast(a, {false, false, true});
tv1b = broadcast(b, {true, false, false});
tv2 = fusedMultiplySum(tv0b, tv1b, {1});
break;
case MatmulLayout::TN:
tv0b = broadcast(a, {false, true, false});
tv1b = broadcast(b, {true, false, false});
tv2 = fusedMultiplySum(tv0b, tv1b, {2});
break;
case MatmulLayout::NT:
tv0b = broadcast(a, {false, false, true});
tv1b = broadcast(b, {false, true, false});
tv2 = fusedMultiplySum(tv0b, tv1b, {0});
break;
default:
TORCH_CHECK(false, "unsupported data layout.");
}
return tv2;
}

// Utility to compute the reference result with ATen for the given layout
at::Tensor atMatmul(at::Tensor a, at::Tensor b, MatmulLayout layout) {
switch (layout) {
case MatmulLayout::TT:
return a.matmul(b);
case MatmulLayout::TN:
return a.matmul(b.t());
case MatmulLayout::NT:
return a.t().matmul(b);
default:
TORCH_CHECK(false, "unsupported data layout.");
}
return at::Tensor();
}

// Utility to generate fp16 matmul input tensors for the given layout
std::pair<at::Tensor, at::Tensor> fp16MatmulAtInput(
int M,
int N,
int K,
MatmulLayout layout) {
auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);

switch (layout) {
case MatmulLayout::TT:
return std::make_pair(
at::randn({M, K}, options), at::randn({K, N}, options));
case MatmulLayout::TN:
return std::make_pair(
at::randn({M, K}, options), at::randn({N, K}, options));
case MatmulLayout::NT:
return std::make_pair(
at::randn({K, M}, options), at::randn({K, N}, options));
default:
TORCH_CHECK(false, "unsupported data layout.");
}
return std::make_pair(at::Tensor(), at::Tensor());
}

// TODO: separate the compute and schedule definitions once the can-schedule
// logic and pattern matching are ready.
void setupMatmul(Fusion* fusion, MatmulLayout layout, MatmulParam params) {
// Only hgemm in the initial setup
auto a = makeContigTensor(2, DataType::Half);
auto b = makeContigTensor(2, DataType::Half);

auto c = matmul(a, b, layout);

fusion->addInput(a);
fusion->addInput(b);
fusion->addOutput(c);

scheduleMatmul(c, a, b, params);
}

static void SingleMatmulBase(
benchmark::State& benchmark_state,
MatmulLayout layout,
MatmulParam params) {
std::vector<int64_t> input_mnk{
benchmark_state.range(0),
benchmark_state.range(1),
benchmark_state.range(2)};

auto fusion_ptr = std::make_unique<Fusion>();
auto fusion = fusion_ptr.get();
FusionGuard fg(fusion);

// Define fusion graph
setupMatmul(fusion, layout, params);

// inputs
at::manual_seed(0);

// Tensor inputs
auto inputs = fp16MatmulAtInput(
input_mnk.at(0), input_mnk.at(1), input_mnk.at(2), layout);

KernelArgumentHolder args = KernelArgumentHolder::createKernelArgumentHolder(
{inputs.first, inputs.second});

// Always use 32b indexing mode for now.
TORCH_INTERNAL_ASSERT(args.getIndexMode() == KernelIndexMode::INT32);

// Compile kernel
FusionExecutor fe;
fe.compileFusion(fusion, args, LaunchParams());

// Warm up run
auto outputs = fe.runFusion({inputs.first, inputs.second});
fe.setMeasureKernelTimeFlag(true);

// Benchmark loop: clear L2 each iteration and report only the measured
// kernel time (in seconds), matching UseManualTime() on the registrations.
for (auto _ : benchmark_state) {
clearL2Cache();
auto outputs = fe.runFusion({inputs.first, inputs.second});
benchmark_state.SetIterationTime(fe.kernelTimeMs() / 1000.0);
}
// Sync everything up before we're finished; we don't want to run ahead on
// the CPU while benchmarking.
cudaDeviceSynchronize();

// TODO: FLOPS calculation
}

static void EagerModeMatmul(
benchmark::State& benchmark_state,
MatmulLayout layout) {
std::vector<int64_t> input_mnk{
benchmark_state.range(0),
benchmark_state.range(1),
benchmark_state.range(2)};

at::manual_seed(0);

auto inputs = fp16MatmulAtInput(
input_mnk.at(0), input_mnk.at(1), input_mnk.at(2), layout);

// warm up run
auto outputs = atMatmul(inputs.first, inputs.second, layout);

for (auto _ : benchmark_state) {
clearL2Cache();
CudaKernelTimer timer;
outputs = atMatmul(inputs.first, inputs.second, layout);
benchmark_state.SetIterationTime(timer.elapsed() / 1000.0);
}
// Sync everything up before we're finished; we don't want to run ahead on
// the CPU while benchmarking.
cudaDeviceSynchronize();
}

// Actual benchmarking
// -----------------------------------------------------------------

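// Shared memory needed to stage the A (m x k) and B (n x k) operand tiles in
// half precision across `stage_number` double-buffer stages.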
size_t getSmemSize(GemmTile cta_tile, int stage_number) {
return ((cta_tile.m * cta_tile.k) + (cta_tile.n * cta_tile.k)) *
dataTypeSize(DataType::Half) * stage_number;
}

// TODO: this part eventually will be automated by heuristics
MatmulParam getMatmulParams(
GemmTile cta_tile,
int stage_number,
MatmulLayout layout) {
MatMulTileOptions gemm_tile;
gemm_tile.cta_tile = cta_tile;
// TODO: pipe through split K
gemm_tile.warp_tile = GemmTile(64, 64, cta_tile.k);
gemm_tile.instruction_tile = GemmTile(16, 16, 16);

// Collect mma swizzle info
auto mma_builder =
MmaBuilder(MmaOptions::MacroType::Ampere_16_16_16, gemm_tile)
.layout(layout);

MatmulParam params(mma_builder);
params.tile_sizes = gemm_tile;
params.async_gmem_load_operands = true;
params.double_buffer_options.double_buffer_smem_write = true;
params.double_buffer_options.double_buffer_smem_read = true;
params.double_buffer_options.smem_double_buffer_stage = stage_number;

return params;
}

static void Nvfuser_Matmul_4warp3stage(
benchmark::State& benchmark_state,
MatmulLayout layout) {
auto cta_tile = GemmTile(128, 128, 32);
int number_of_stage = 3;

auto params = getMatmulParams(cta_tile, number_of_stage, layout);

NVFUSER_BENCHMARK_ARCH_SMEM_GUARD(
8, 0, getSmemSize(cta_tile, number_of_stage), benchmark_state);

// Run benchmark:
SingleMatmulBase(benchmark_state, layout, params);
}

static void Nvfuser_Matmul_8warp3stage(
benchmark::State& benchmark_state,
MatmulLayout layout) {
auto cta_tile = GemmTile(256, 128, 32);
int number_of_stage = 3;

auto params = getMatmulParams(cta_tile, number_of_stage, layout);

NVFUSER_BENCHMARK_ARCH_SMEM_GUARD(
8, 0, getSmemSize(cta_tile, number_of_stage), benchmark_state);

// Run benchmark:
SingleMatmulBase(benchmark_state, layout, params);
}

static void Nvfuser_Matmul_4warp4stage(
benchmark::State& benchmark_state,
MatmulLayout layout) {
auto cta_tile = GemmTile(128, 128, 32);
int number_of_stage = 4;

auto params = getMatmulParams(cta_tile, number_of_stage, layout);

NVFUSER_BENCHMARK_ARCH_SMEM_GUARD(
8, 0, getSmemSize(cta_tile, number_of_stage), benchmark_state);

// Run benchmark:
SingleMatmulBase(benchmark_state, layout, params);
}

static void Nvfuser_Matmul_8warp4stage(
benchmark::State& benchmark_state,
MatmulLayout layout) {
auto cta_tile = GemmTile(256, 128, 32);
int number_of_stage = 4;

auto params = getMatmulParams(cta_tile, number_of_stage, layout);

NVFUSER_BENCHMARK_ARCH_SMEM_GUARD(
8, 0, getSmemSize(cta_tile, number_of_stage), benchmark_state);

// Run benchmark:
SingleMatmulBase(benchmark_state, layout, params);
}

// ----------------------------- Benchmark Instantiation -----------------------------

// Common utils:
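// Problem sizes: M = 2048, N = 3456, K swept from 512 to 4096 in steps of 512.
// M and N divide evenly by the CTA tiles used above, so there is no tile
// quantization; iterations report the manually measured time in microseconds.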
#define NO_TILE_QUANTIZATION_ARGS \
ArgsProduct( \
{{2048}, {3456}, benchmark::CreateDenseRange(512, 4096, /*step=*/512)}) \
->Unit(benchmark::kMicrosecond) \
->UseManualTime();

#define ForAllLayouts(run) \
run(TT, MatmulLayout::TT); \
run(TN, MatmulLayout::TN); \
run(NT, MatmulLayout::NT)

// Instantiations:
#define Nvfuser_4warp3stage_test(layout_label, layout) \
BENCHMARK_CAPTURE( \
Nvfuser_Matmul_4warp3stage, \
no_quant_nvfuser_4warp_##layout_label, \
layout) \
->NO_TILE_QUANTIZATION_ARGS

#define Nvfuser_8warp3stage_test(layout_label, layout) \
BENCHMARK_CAPTURE( \
Nvfuser_Matmul_8warp3stage, \
no_quant_nvfuser_8warp_##layout_label, \
layout) \
->NO_TILE_QUANTIZATION_ARGS

#define Nvfuser_4warp4stage_test(layout_label, layout) \
BENCHMARK_CAPTURE( \
Nvfuser_Matmul_4warp4stage, \
no_quant_nvfuser_4warp_##layout_label, \
layout) \
->NO_TILE_QUANTIZATION_ARGS

#define Nvfuser_8warp4stage_test(layout_label, layout) \
BENCHMARK_CAPTURE( \
Nvfuser_Matmul_8warp4stage, \
no_quant_nvfuser_8warp_##layout_label, \
layout) \
->NO_TILE_QUANTIZATION_ARGS

#define Eagermode_test(layout_label, layout) \
BENCHMARK_CAPTURE( \
EagerModeMatmul, no_quant_eagermode_##layout_label, layout) \
->NO_TILE_QUANTIZATION_ARGS

ForAllLayouts(Nvfuser_4warp3stage_test);
ForAllLayouts(Nvfuser_4warp4stage_test);
ForAllLayouts(Nvfuser_8warp3stage_test);
ForAllLayouts(Nvfuser_8warp4stage_test);
ForAllLayouts(Eagermode_test);