Skip to content

Commit

Permalink
HipBLASLt GEMM benchmarks
Browse files Browse the repository at this point in the history
  • Loading branch information
suryajasper committed Aug 5, 2024
1 parent c04c59d commit e1cf1ad
Show file tree
Hide file tree
Showing 8 changed files with 257 additions and 175 deletions.
27 changes: 16 additions & 11 deletions .github/workflows/run_bench.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ jobs:
cd $GITHUB_WORKSPACE/build
sudo ninja
cd $GITHUB_WORKSPACE
mkdir -p results
- name: Set up Python environment
run: |
Expand All @@ -61,27 +62,31 @@ jobs:
cd $GITHUB_WORKSPACE
source $GITHUB_WORKSPACE/venv/bin/activate
for device in $(seq 3 7); do (sudo $GITHUB_WORKSPACE/build/gemm-bench --device=$device &); done
./gb run --backends=rocblas --repeat=1 --output=rocblas.hdf
./gb run --backends=rocblas --repeat=1 --output=results/rocblas.hdf
sudo pkill -f gemm-bench
deactivate
- name: Run IREE Benchmarks
- name: Run HipBLASLt Benchmarks
run: |
sudo pkill -f gemm-bench
cd $GITHUB_WORKSPACE
source $GITHUB_WORKSPACE/venv/bin/activate
for device in $(seq 3 7); do (sudo $GITHUB_WORKSPACE/build/gemm-bench --device=$device &); done
./gb run --backends=iree --repeat=1 --output=iree.hdf
./gb run --backends=hipblaslt --repeat=1 --output=results/hipblaslt.hdf
sudo pkill -f gemm-bench
deactivate
- name: Upload IREE benchmark results
uses: actions/upload-artifact@v4
with:
name: iree-benchmark-results
path: iree.hdf
- name: Run IREE Benchmarks
run: |
cd $GITHUB_WORKSPACE
source $GITHUB_WORKSPACE/venv/bin/activate
for device in $(seq 3 7); do (sudo $GITHUB_WORKSPACE/build/gemm-bench --device=$device &); done
./gb run --backends=iree --repeat=1 --output=results/iree.hdf
sudo pkill -f gemm-bench
deactivate
- name: Upload RocBLAS benchmark results
- name: Upload benchmark results
uses: actions/upload-artifact@v4
with:
name: rocblas-benchmark-results
path: rocblas.hdf
name: benchmark-results
path: ./results/
4 changes: 3 additions & 1 deletion .vscode/c_cpp_properties.json
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@
"${workspaceFolder}/third_party/iree/runtime/src",
"${workspaceFolder}/src/benchmark",
"/opt/rocm/include",
"/opt/rocm/include/hip",
"/opt/rocm/include/hip/amd_detail",
"/usr/include",
"/usr/include/clang/15/include",
"/usr/include/c++/12",
"/usr/include/x86_64-linux-gnu/c++/12/",
"/usr/include/python3.10"
],
"defines": [],
"defines": ["__HIP__=1"],
"cStandard": "c17",
"cppStandard": "c++14",
"intelliSenseMode": "linux-clang-x64",
Expand Down
2 changes: 1 addition & 1 deletion gemmbench/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def roofline(results=None, **kwargs):
flops = 0
bytes = 1

if 'sharkfa' in result_file:
if 'sharkfa' in result_file or 'torch' in result_file:
B, H, S_Q, S_KV, DH = item['A'], item['B'], item['M'], item['N'], item['K']
if result_file.split('.')[-1] == 'hdf':
item['A'], item['B'] = B, H = ord(item['A'][0]), ord(item['B'][0])
Expand Down
4 changes: 1 addition & 3 deletions gemmbench/gbm_py.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
from ctypes import c_int, c_char, c_double, c_float, POINTER
import numpy as np

# Load the shared library
lib = ctypes.CDLL('./libgemm_bench.so') # Adjust the path as needed
lib = ctypes.CDLL('./libgemm_bench.so')

class Problem(ctypes.Structure):
_fields_ = [
Expand Down Expand Up @@ -43,7 +42,6 @@ class Result(ctypes.Structure):
("device", c_int)
]

# Define function prototypes
lib.initialize_gemm_pipeline.argtypes = []
lib.initialize_gemm_pipeline.restype = None

Expand Down
2 changes: 1 addition & 1 deletion meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ iree_deps = [

gemm_bench_src += 'src/benchmark/DataInitialization.cpp'
gemm_bench_src += 'src/benchmark/run-rocblas.cpp'
# gemm_bench_src += 'src/benchmark/run-hipblaslt.cpp'
gemm_bench_src += 'src/benchmark/run-hipblaslt.cpp'
gemm_bench_src += 'src/benchmark/run_flashattention.cpp'
gemm_bench_src += 'src/benchmark/run_iree.cpp'
gemm_bench_deps += [rocblas_dep, rocsmi_dep, hipblas_dep, hipblaslt_dep]
Expand Down
4 changes: 2 additions & 2 deletions src/benchmark/gemm-bench.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ namespace GEMMBench

// Table of benchmark backends, looked up by backend name
// ("rocblas", "hipblaslt", "iree", "sharkfa").
// NOTE(review): every entry is a raw pointer created with `new`; nothing in
// this excerpt deletes them -- presumably they are meant to live for the
// whole process. Confirm before adding or removing entries at runtime.
std::unordered_map<std::string, GEMMPipeline*> benches{
{"rocblas", new RocBLASGEMMBench()},
{"hipblaslt", new HipBLASLtGEMMBench()},
{"iree", new IREEGEMMBench()},
{"sharkfa", new SHARKFABench()},
// {"hipblaslt", new HipBLASLtGEMMBench()},
};

/**
Expand Down Expand Up @@ -88,7 +88,7 @@ namespace GEMMBench
int run(int device)
{
std::cout << "Initializing tensors with trig..." << std::endl;
GEMMTrigInitializer initializer;
GEMMNullInitializer initializer;
GEMMData data("fp32", 1e9, &initializer);

std::cout << "Running on " << benches.size() << " benches" << std::endl;
Expand Down
73 changes: 42 additions & 31 deletions src/benchmark/gemm-bench.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@
#include <string>
#include <vector>

// #include <hip/hip_runtime.h>
#include <hip/hip_runtime.h>
#include <hipblaslt/hipblaslt.h>

#include "FrequencyMonitor.hpp"
#include "Timer.hpp"

class GEMMData;

Expand Down Expand Up @@ -141,36 +145,43 @@ namespace GEMMBench
Result run(Problem problem) override;
};

// class HipBLASLtGEMMBench : public GEMMPipeline
// {
// public:
// HipBLASLtGEMMBench();
// ~HipBLASLtGEMMBench();

// void initialize() override;
// void destroy() override;
// void setDevice(int device_id) override;
// Result run(Problem problem) override;

// private:
// hipblasLtHandle_t handle;
// hipblasStatus_t hipblaslt_status;

// void executeGEMM(hipblasOperation_t transA,
// hipblasOperation_t transB,
// int64_t m,
// int64_t n,
// int64_t k,
// const float& alpha,
// const float& beta,
// void* d_A,
// void* d_B,
// void* d_C,
// void* d_D,
// void* workspace,
// size_t workspace_size,
// hipStream_t stream);
// };
/**
 * GEMM benchmark pipeline for the hipBLASLt backend.
 *
 * Overrides the four GEMMPipeline entry points (initialize, destroy,
 * setDevice, run). Only the declaration lives here; the member functions
 * are defined elsewhere.
 *
 * NOTE(review): the class keeps a raw workspace pointer, a HIP stream and a
 * hipBLASLt handle and declares a destructor, while copy/move operations are
 * left implicit. Presumably instances are never copied -- confirm, and
 * consider deleting the copy operations if so.
 */
class HipBLASLtGEMMBench : public GEMMPipeline
{
public:
    HipBLASLtGEMMBench();
    ~HipBLASLtGEMMBench();

    void initialize() override;
    void destroy() override;
    void setDevice(int device_id) override;
    Result run(Problem problem) override;

private:
    // All members are value-initialised (null / zero) so that code running
    // before initialize() -- or after a failed initialize() -- never reads
    // an indeterminate pointer, handle or status.
    void* workspace{};
    hipStream_t stream{};
    hipblasLtHandle_t handle{};
    hipblasStatus_t hipblaslt_status{};

    /**
     * Internal helper taking the full GEMM description plus the timing and
     * frequency-monitoring objects.
     *
     * NOTE(review): the parameter meanings below are inferred from their
     * names only; confirm against the definition.
     *   - num_iterations: presumably how many GEMM launches are performed.
     *   - timer, monitor: presumably used to time the launches and sample
     *     device frequency while they run.
     *   - trans_a, trans_b, m, n, k, batch_count, alpha, beta: presumably
     *     the usual GEMM operation flags, sizes and scalars.
     *   - d_a, d_b, d_c, d_d, d_workspace: presumably device buffers;
     *     max_workspace_size presumably bounds d_workspace in bytes.
     *
     * The `handle` and `stream` parameters have the same names as the data
     * members above, so inside the definition the unqualified names refer
     * to the arguments, not to the members.
     */
    void executeGemm(int num_iterations,
                     Timer* timer,
                     Frequency::Monitor* monitor,
                     hipblasLtHandle_t handle,
                     hipblasOperation_t trans_a,
                     hipblasOperation_t trans_b,
                     int64_t m,
                     int64_t n,
                     int64_t k,
                     int64_t batch_count,
                     float& alpha,
                     float& beta,
                     void* d_a,
                     void* d_b,
                     void* d_c,
                     void* d_d,
                     void* d_workspace,
                     int64_t max_workspace_size,
                     hipStream_t stream);
};

int testDPM(int device);
int run(int device);
Expand Down
Loading

0 comments on commit e1cf1ad

Please sign in to comment.