Skip to content

Commit

Permalink
[Library] Add cublas library (#404)
Browse files Browse the repository at this point in the history
This PR adds cublas to hidet.

Check the `tests/cuda/test_cublas.py` for the usage of cublas in hidet.
  • Loading branch information
yaoyaoding authored Jan 3, 2024
1 parent 7c71965 commit f70e5e6
Show file tree
Hide file tree
Showing 27 changed files with 1,139 additions and 64 deletions.
8 changes: 5 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@ message(STATUS "Build type: ${CMAKE_BUILD_TYPE}")

# add hidet_runtime target
add_library(hidet_runtime SHARED
src/hidet/runtime/cuda_context.cpp
src/hidet/runtime/cpu_context.cpp
src/hidet/runtime/cuda/context.cpp
src/hidet/runtime/cuda/cublas.cpp
src/hidet/runtime/cuda/cuda.cpp
src/hidet/runtime/cpu/context.cpp
src/hidet/runtime/callbacks.cpp
src/hidet/runtime/logging.cpp
src/hidet/runtime/symbols.cpp
Expand All @@ -28,7 +30,7 @@ set_target_properties(hidet_runtime PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_

# add hidet target
add_library(hidet SHARED
src/hidet/packedfunc.cpp
src/hidet/empty.cpp # empty source file
)
target_include_directories(hidet PRIVATE ${CMAKE_SOURCE_DIR}/include)
set_target_properties(hidet PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
Expand Down
38 changes: 38 additions & 0 deletions include/hidet/runtime/cuda/cublas.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <hidet/runtime/common.h>
// cublas_v2.h is deliberately NOT included: the cublas library is loaded
// dynamically at runtime (see hidet_cublas_set_library_path below), so this
// header only needs an opaque handle type, not the real cublas declarations.
//#include <cublas_v2.h>

// Maximum number of GPUs per node for which a cublas handle is cached.
#define HIDET_CUBLAS_MAX_GPUS 32

// Opaque stand-in for the real cublasHandle_t (which is itself a pointer type).
typedef void* cublasHandle_t;

struct CublasContext {
    cublasHandle_t handles[HIDET_CUBLAS_MAX_GPUS]; // cublas handle for each gpu on this node
    static CublasContext* global();                // process-wide singleton accessor
    static cublasHandle_t current_handle();        // presumably the handle of the currently active gpu — confirm in cublas.cpp
};

// Registers the file-system path of the cublas shared library to load.
DLL void hidet_cublas_set_library_path(const char* path);

// kernel functions

// Single matrix multiplication through cublas.
//   m, n, k:      problem sizes (standard GEMM convention)
//   ta, tb, tc:   cudaDataType codes for A, B and C (see python ffi bindings)
//   trans_a/b:    whether A / B is transposed
//   compute_type: cublasComputeType code for the accumulation
DLL void hidet_cublas_gemm(
    int m, int n, int k, int ta, int tb, int tc, void *ptr_a, void* ptr_b, void* ptr_c, bool trans_a, bool trans_b,
    int compute_type
);

// Strided-batched matrix multiplication: b GEMM problems whose A/B/C matrices
// are separated by strides sa/sb/sc (units assumed to be elements — TODO
// confirm against the implementation in cublas.cpp).
DLL void hidet_cublas_strided_gemm(
    int b, int m, int n, int k, int ta, int tb, int tc, void *ptr_a, void* ptr_b, void* ptr_c,
    int64_t sa, int64_t sb, int64_t sc,
    bool trans_a, bool trans_b, int compute_type
);
14 changes: 5 additions & 9 deletions src/hidet/packedfunc.cpp → include/hidet/runtime/cuda/cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,10 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <hidet/packedfunc.h>
#pragma once
#include <hidet/runtime/common.h>

extern "C" {

DLL void CallPackedFunc(PackedFunc func, void** args) {
auto f = PackedFunc_t(func.func_pointer);
f(func.num_args, func.arg_types, args);
}

}
// Device-management entry points exported by the hidet runtime; thin wrappers
// implemented in src/hidet/runtime/cuda/cuda.cpp (names suggest they mirror
// cudaGetDeviceCount / cudaGetDevice / cudaSetDevice — confirm there).
DLL int hidet_cuda_device_count();         // number of visible CUDA devices
DLL int hidet_cuda_get_device();           // index of the current device
DLL void hidet_cuda_set_device(int device); // make `device` the current device

24 changes: 0 additions & 24 deletions include/hidet/packedfunc.h → include/hidet/runtime/cuda/cudnn.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,3 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#ifndef DLL
#define DLL extern "C" __attribute__((visibility("default")))
#endif

enum ArgType {
INT32 = 1,
FLOAT32 = 2,
POINTER = 3,
};

typedef void (*PackedFunc_t)(int num_args, int *arg_types, void** args);

struct PackedFunc {
int num_args;
int* arg_types;
void** func_pointer;
};

#define INT_ARG(p) (*(int*)(p))
#define FLOAT_ARG(p) (*(float*)(p))


9 changes: 5 additions & 4 deletions include/hidet/runtime/logging.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,9 @@ class FATALMessage {
return this->stream_;
}

[[noreturn]] ~FATALMessage() noexcept(false) {
throw HidetException(this->stream_.str().c_str());
// Prints the accumulated message to stderr and aborts the process.
// NOTE(review): this replaces the previous throwing behavior; aborting in the
// destructor avoids the std::terminate risk of throwing during unwinding,
// and std::abort() satisfies the [[noreturn]] contract.
[[noreturn]] ~FATALMessage() {
    std::cerr << this->stream_.str() << std::endl;
    std::abort();
}
};

Expand All @@ -67,8 +68,8 @@ class ERRORMessage {
return this->stream_;
}

~ERRORMessage() {
hidet_set_last_error(this->stream_.str().c_str());
// Raises the accumulated message as a HidetException when the temporary
// message object goes out of scope; noexcept(false) is required because
// destructors default to noexcept since C++11.
// NOTE(review): if an ERRORMessage is destroyed while another exception is
// already propagating, this throw calls std::terminate — verify call sites.
~ERRORMessage() noexcept(false) {
    throw HidetException(this->stream_.str().c_str());
}
};

2 changes: 2 additions & 0 deletions python/hidet/cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,5 @@
from .memory import malloc, free, malloc_async, free_async, malloc_host, free_host, memcpy_peer, memcpy_peer_async
from .memory import memcpy, memcpy_async, memset, memset_async, memory_info
from .event import Event

from . import cublas
13 changes: 13 additions & 0 deletions python/hidet/cuda/cublas/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .ffi import cublasComputeType, cudaDataType
from .kernels import gemm, strided_gemm
131 changes: 131 additions & 0 deletions python/hidet/cuda/cublas/ffi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import glob
from enum import IntEnum
from ctypes import c_int32, c_int64, c_void_p, c_bool, c_char_p
from hidet.ffi.ffi import get_func
from hidet.utils.py import initialize


class cudaDataType(IntEnum):
    """
    Numeric codes of CUDA's ``cudaDataType_t`` enum, used to describe the
    element type of the A/B/C matrices passed to the cublas kernels.

    Naming convention (per the linked NVIDIA docs): ``R``/``C`` = real or
    complex, the number is the bit width, and the suffix gives the element
    kind (``F`` float, ``BF`` bfloat, ``I`` signed int, ``U`` unsigned int).
    Values must match the CUDA toolkit headers exactly; do not renumber.

    See Also: https://docs.nvidia.com/cuda/cublas/index.html#cudadatatype-t
    """

    CUDA_R_16F = 2
    CUDA_C_16F = 6
    CUDA_R_16BF = 14
    CUDA_C_16BF = 15
    CUDA_R_32F = 0
    CUDA_C_32F = 4
    CUDA_R_64F = 1
    CUDA_C_64F = 5
    CUDA_R_4I = 16
    CUDA_C_4I = 17
    CUDA_R_4U = 18
    CUDA_C_4U = 19
    CUDA_R_8I = 3
    CUDA_C_8I = 7
    CUDA_R_8U = 8
    CUDA_C_8U = 9
    CUDA_R_16I = 20
    CUDA_C_16I = 21
    CUDA_R_16U = 22
    CUDA_C_16U = 23
    CUDA_R_32I = 10
    CUDA_C_32I = 11
    CUDA_R_32U = 12
    CUDA_C_32U = 13
    CUDA_R_64I = 24
    CUDA_C_64I = 25
    CUDA_R_64U = 26
    CUDA_C_64U = 27
    CUDA_R_8F_E4M3 = 28  # real as a nv_fp8_e4m3
    CUDA_R_8F_E5M2 = 29  # real as a nv_fp8_e5m2


class cublasComputeType(IntEnum):
    """
    Numeric codes of cublas's ``cublasComputeType_t`` enum, selecting the
    precision used for the internal accumulation of a GEMM (independently of
    the storage types of A/B/C given by :class:`cudaDataType`).

    Values must match the CUDA toolkit headers exactly; do not renumber.

    See Also: https://docs.nvidia.com/cuda/cublas/index.html#cublascomputetype-t
    """

    CUBLAS_COMPUTE_16F = 64  # half - default
    CUBLAS_COMPUTE_16F_PEDANTIC = 65  # half - pedantic
    CUBLAS_COMPUTE_32F = 68  # float - default
    CUBLAS_COMPUTE_32F_PEDANTIC = 69  # float - pedantic
    CUBLAS_COMPUTE_32F_FAST_16F = 74  # float - fast allows down-converting inputs to half or TF32
    CUBLAS_COMPUTE_32F_FAST_16BF = 75  # float - fast allows down-converting inputs to bfloat16 or TF32
    CUBLAS_COMPUTE_32F_FAST_TF32 = 77  # float - fast allows down-converting inputs to TF32
    CUBLAS_COMPUTE_64F = 70  # double - default
    CUBLAS_COMPUTE_64F_PEDANTIC = 71  # double - pedantic
    CUBLAS_COMPUTE_32I = 72  # signed 32-bit int - default
    CUBLAS_COMPUTE_32I_PEDANTIC = 73  # signed 32-bit int - pedantic


# ctypes handle to hidet_cublas_set_library_path: registers the filesystem
# path of the libcublas shared library that the runtime loads lazily.
set_library_path = get_func(func_name='hidet_cublas_set_library_path', arg_types=[c_char_p], restype=None)  # path

# ctypes handle to hidet_cublas_gemm: a single (non-batched) matrix multiply.
# Type codes are cudaDataType values; compute type is a cublasComputeType value.
gemm = get_func(
    func_name='hidet_cublas_gemm',
    arg_types=[
        c_int32,  # m
        c_int32,  # n
        c_int32,  # k
        c_int32,  # type a
        c_int32,  # type b
        c_int32,  # type c
        c_void_p,  # ptr a
        c_void_p,  # ptr b
        c_void_p,  # ptr c
        c_bool,  # trans a
        c_bool,  # trans b
        c_int32,  # compute type
    ],
    restype=None,
)

# ctypes handle to hidet_cublas_strided_gemm: a strided-batched matrix multiply
# (same parameters as gemm plus a batch size and per-matrix strides).
strided_gemm = get_func(
    func_name='hidet_cublas_strided_gemm',
    arg_types=[
        c_int32,  # batch size
        c_int32,  # m
        c_int32,  # n
        c_int32,  # k
        c_int32,  # type a
        c_int32,  # type b
        c_int32,  # type c
        c_void_p,  # ptr a
        c_void_p,  # ptr b
        c_void_p,  # ptr c
        c_int64,  # stride a
        c_int64,  # stride b
        c_int64,  # stride c
        c_bool,  # trans a
        c_bool,  # trans b
        c_int32,  # compute type
    ],
    restype=None,
)


@initialize()
def set_cublas_library_path():
    """
    Locate the cublas shared library installed by the ``nvidia-*-cublas`` pip
    package and register its path with the hidet runtime.

    Scans every entry of ``sys.path`` for an ``nvidia/cublas/lib`` directory
    and forwards the versioned ``libcublas.so.N`` path (UTF-8 encoded) to
    ``set_library_path``. Does nothing when no candidate is found.

    Fix over the original: ``glob.glob`` returns paths in arbitrary filesystem
    order, so with several ``libcublas.so.N`` files present the chosen library
    was non-deterministic. Candidates are now sorted by their numeric version
    suffix and the newest one is selected.
    """

    def _version_key(path: str) -> tuple:
        # 'libcublas.so.12.1.3' -> (12, 1, 3); numeric so '12' sorts above '9'.
        suffix = path.rsplit('libcublas.so.', 1)[-1]
        return tuple(int(part) for part in suffix.split('.') if part.isdigit())

    # search the 'nvidia' namespace package shipped on sys.path by pip wheels
    for path in sys.path:
        nvidia_path = os.path.join(path, 'nvidia')
        if not os.path.exists(nvidia_path):
            continue
        candidates = glob.glob(os.path.join(nvidia_path, 'cublas', 'lib', 'libcublas.so.[0-9]*'))
        if candidates:
            best = max(candidates, key=_version_key)
            set_library_path(best.encode('utf-8'))
            return
Loading

0 comments on commit f70e5e6

Please sign in to comment.