From eeed6dfc53179374f5a2a13fd82849f977861e60 Mon Sep 17 00:00:00 2001 From: ardfork <134447697+ardfork@users.noreply.github.com> Date: Fri, 25 Aug 2023 23:55:41 +0000 Subject: [PATCH] Initial hipBLAS support --- CMakeLists.txt | 32 +++++++++++++++++++++++++++++++ Makefile | 15 +++++++++++++++ ggml-cuda.cu | 51 ++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 98 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 91385cb3f81..407d9800ca7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -65,6 +65,7 @@ else() option(WHISPER_BLAS_VENDOR "whisper: BLAS library vendor" Generic) option(WHISPER_OPENBLAS "whisper: prefer OpenBLAS" OFF) option(WHISPER_CUBLAS "whisper: support for cuBLAS" OFF) + option(WHISPER_HIPBLAS "whisper: support for hipBLAS" OFF) option(WHISPER_CLBLAST "whisper: use CLBlast" OFF) endif() @@ -191,6 +192,37 @@ if (WHISPER_CUBLAS) endif() endif() + +if (WHISPER_HIPBLAS) + list(APPEND CMAKE_PREFIX_PATH /opt/rocm) + if (NOT ${CMAKE_C_COMPILER_ID} MATCHES "Clang") + message(WARNING "Only LLVM is supported for HIP, hint: CC=/opt/rocm/llvm/bin/clang") + endif() + if (NOT ${CMAKE_CXX_COMPILER_ID} MATCHES "Clang") + message(WARNING "Only LLVM is supported for HIP, hint: CXX=/opt/rocm/llvm/bin/clang++") + endif() + + find_package(hip) + find_package(hipblas) + find_package(rocblas) + + if (${hipblas_FOUND} AND ${hip_FOUND}) + message(STATUS "HIP and hipBLAS found") + add_compile_definitions(GGML_USE_HIPBLAS GGML_USE_CUBLAS) + add_library(ggml-rocm OBJECT ggml-cuda.cu ggml-cuda.h) + set_property(TARGET ggml-rocm PROPERTY POSITION_INDEPENDENT_CODE ON) + set_source_files_properties(ggml-cuda.cu PROPERTIES LANGUAGE CXX) + target_link_libraries(ggml-rocm PRIVATE hip::device PUBLIC hip::host roc::rocblas roc::hipblas) + + if (WHISPER_STATIC) + message(FATAL_ERROR "Static linking not supported for HIP/ROCm") + endif() + set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} ggml-rocm) + else() + message(WARNING "hipBLAS or HIP not found. Try setting CMAKE_PREFIX_PATH=/opt/rocm") + endif() +endif() + if (WHISPER_CLBLAST) find_package(CLBlast) if (CLBlast_FOUND) diff --git a/Makefile b/Makefile index 49530031ddb..ee8017b0b16 100644 --- a/Makefile +++ b/Makefile @@ -161,6 +161,21 @@ ggml-cuda.o: ggml-cuda.cu ggml-cuda.h $(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@ endif +ifdef WHISPER_HIPBLAS + ROCM_PATH ?= /opt/rocm + HIPCC ?= $(ROCM_PATH)/bin/hipcc + GPU_TARGETS ?= $(shell $(ROCM_PATH)/llvm/bin/amdgpu-arch) + CFLAGS += -DGGML_USE_HIPBLAS -DGGML_USE_CUBLAS + CXXFLAGS += -DGGML_USE_HIPBLAS -DGGML_USE_CUBLAS + LDFLAGS += -L$(ROCM_PATH)/lib -Wl,-rpath=$(ROCM_PATH)/lib + LDFLAGS += -lhipblas -lamdhip64 -lrocblas + HIPFLAGS += $(addprefix --offload-arch=,$(GPU_TARGETS)) + WHISPER_OBJ += ggml-cuda.o + +ggml-cuda.o: ggml-cuda.cu ggml-cuda.h + $(HIPCC) $(CXXFLAGS) $(HIPFLAGS) -x hip -c -o $@ $< +endif + ifdef WHISPER_CLBLAST CFLAGS += -DGGML_USE_CLBLAST CXXFLAGS += -DGGML_USE_CLBLAST diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 50df20edd7a..694e0bcb837 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6,9 +6,60 @@ #include #include +#if defined(GGML_USE_HIPBLAS) +#include +#include +#include +#include +#define CUBLAS_OP_N HIPBLAS_OP_N +#define CUBLAS_OP_T HIPBLAS_OP_T +#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS +#define CUBLAS_TF32_TENSOR_OP_MATH 0 +#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width) +#define cublasCreate hipblasCreate +#define cublasGetStatusString rocblas_status_to_string +#define cublasHandle_t hipblasHandle_t +#define cublasLoggerConfigure(logIsOn, logToStdOut, logToStdErr, logFileName) CUBLAS_STATUS_SUCCESS +#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS +#define cublasSetStream hipblasSetStream +#define cublasSgemm hipblasSgemm +#define cublasStatus_t hipblasStatus_t +#define cudaDeviceProp hipDeviceProp_t +#define cudaDeviceSynchronize hipDeviceSynchronize +#define cudaError_t hipError_t +#define cudaEventCreateWithFlags hipEventCreateWithFlags +#define cudaEventDestroy hipEventDestroy +#define cudaEventDisableTiming hipEventDisableTiming +#define cudaEventRecord hipEventRecord +#define cudaEvent_t hipEvent_t +#define cudaFree hipFree +#define cudaFreeHost hipHostFree +#define cudaGetDevice hipGetDevice +#define cudaGetDeviceCount hipGetDeviceCount +#define cudaGetDeviceProperties hipGetDeviceProperties +#define cudaGetErrorString hipGetErrorString +#define cudaGetLastError hipGetLastError +#define cudaMalloc hipMalloc +#define cudaMallocHost hipHostMalloc +#define cudaMemcpy hipMemcpy +#define cudaMemcpy2DAsync hipMemcpy2DAsync +#define cudaMemcpyAsync hipMemcpyAsync +#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice +#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost +#define cudaMemcpyHostToDevice hipMemcpyHostToDevice +#define cudaMemcpyKind hipMemcpyKind +#define cudaMemset hipMemset +#define cudaSetDevice hipSetDevice +#define cudaStreamCreateWithFlags hipStreamCreateWithFlags +#define cudaStreamNonBlocking hipStreamNonBlocking +#define cudaStreamWaitEvent(stream, event) hipStreamWaitEvent(stream, event, 0) +#define cudaStream_t hipStream_t +#define cudaSuccess hipSuccess +#else #include #include #include +#endif #include "ggml-cuda.h" #include "ggml.h"