From f70e5e6ab0638b615ffbce34ba2d5e5bd5abb016 Mon Sep 17 00:00:00 2001 From: Yaoyao Ding Date: Wed, 3 Jan 2024 12:06:17 -0500 Subject: [PATCH] [Library] Add cublas library (#404) This PR add cublas to hidet. Check the `tests/cuda/test_cublas.py` for the usage of cublas in hidet. --- CMakeLists.txt | 8 +- include/hidet/runtime/cuda/cublas.h | 38 +++ .../hidet/runtime/cuda/cuda.h | 14 +- .../{packedfunc.h => runtime/cuda/cudnn.h} | 24 -- include/hidet/runtime/logging.h | 9 +- python/hidet/cuda/__init__.py | 2 + python/hidet/cuda/cublas/__init__.py | 13 + python/hidet/cuda/cublas/ffi.py | 131 +++++++ python/hidet/cuda/cublas/kernels.py | 166 +++++++++ python/hidet/cuda/cublas/utils.py | 42 +++ python/hidet/ir/expr.py | 10 +- python/hidet/ir/library/__init__.py | 1 + python/hidet/ir/library/cuda/__init__.py | 1 + .../hidet/ir/library/cuda/cublas/__init__.py | 15 + .../hidet/ir/library/cuda/cublas/kernels.py | 53 +++ python/hidet/ir/library/cuda/cublas/regs.py | 62 ++++ python/hidet/lang/cuda.py | 2 + .../transforms/annotate_header_and_libs.py | 81 ++++- src/hidet/empty.cpp | 11 + src/hidet/runtime/callbacks.cpp | 4 +- .../{cpu_context.cpp => cpu/context.cpp} | 2 +- .../{cuda_context.cpp => cuda/context.cpp} | 4 +- src/hidet/runtime/cuda/cublas.cpp | 322 ++++++++++++++++++ src/hidet/runtime/cuda/cuda.cpp | 81 +++++ src/hidet/runtime/cuda/utils.h | 12 + src/hidet/runtime/symbols.cpp | 2 +- tests/cuda/test_cublas.py | 93 +++++ 27 files changed, 1139 insertions(+), 64 deletions(-) create mode 100644 include/hidet/runtime/cuda/cublas.h rename src/hidet/packedfunc.cpp => include/hidet/runtime/cuda/cuda.h (74%) rename include/hidet/{packedfunc.h => runtime/cuda/cudnn.h} (58%) create mode 100644 python/hidet/cuda/cublas/__init__.py create mode 100644 python/hidet/cuda/cublas/ffi.py create mode 100644 python/hidet/cuda/cublas/kernels.py create mode 100644 python/hidet/cuda/cublas/utils.py create mode 100644 python/hidet/ir/library/cuda/cublas/__init__.py create mode 100644 python/hidet/ir/library/cuda/cublas/kernels.py create mode 100644 python/hidet/ir/library/cuda/cublas/regs.py create mode 100644 src/hidet/empty.cpp rename src/hidet/runtime/{cpu_context.cpp => cpu/context.cpp} (96%) rename src/hidet/runtime/{cuda_context.cpp => cuda/context.cpp} (94%) create mode 100644 src/hidet/runtime/cuda/cublas.cpp create mode 100644 src/hidet/runtime/cuda/cuda.cpp create mode 100644 src/hidet/runtime/cuda/utils.h create mode 100644 tests/cuda/test_cublas.py diff --git a/CMakeLists.txt b/CMakeLists.txt index fa1b75c58..6dab979b7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -17,8 +17,10 @@ message(STATUS "Build type: ${CMAKE_BUILD_TYPE}") # add hidet_runtime target add_library(hidet_runtime SHARED - src/hidet/runtime/cuda_context.cpp - src/hidet/runtime/cpu_context.cpp + src/hidet/runtime/cuda/context.cpp + src/hidet/runtime/cuda/cublas.cpp + src/hidet/runtime/cuda/cuda.cpp + src/hidet/runtime/cpu/context.cpp src/hidet/runtime/callbacks.cpp src/hidet/runtime/logging.cpp src/hidet/runtime/symbols.cpp @@ -28,7 +30,7 @@ set_target_properties(hidet_runtime PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_ # add hidet target add_library(hidet SHARED - src/hidet/packedfunc.cpp + src/hidet/empty.cpp # empty source file ) target_include_directories(hidet PRIVATE ${CMAKE_SOURCE_DIR}/include) set_target_properties(hidet PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) diff --git a/include/hidet/runtime/cuda/cublas.h b/include/hidet/runtime/cuda/cublas.h new file mode 100644 index 000000000..ed1108957 --- /dev/null +++ b/include/hidet/runtime/cuda/cublas.h @@ -0,0 +1,38 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once +#include +//#include + +#define HIDET_CUBLAS_MAX_GPUS 32 + +typedef void* cublasHandle_t; + +struct CublasContext { + cublasHandle_t handles[HIDET_CUBLAS_MAX_GPUS]; // cublas handle for each gpu on this node + static CublasContext* global(); + static cublasHandle_t current_handle(); +}; + +DLL void hidet_cublas_set_library_path(const char* path); + +// kernel functions +DLL void hidet_cublas_gemm( + int m, int n, int k, int ta, int tb, int tc, void *ptr_a, void* ptr_b, void* ptr_c, bool trans_a, bool trans_b, + int compute_type +); + +DLL void hidet_cublas_strided_gemm( + int b, int m, int n, int k, int ta, int tb, int tc, void *ptr_a, void* ptr_b, void* ptr_c, + int64_t sa, int64_t sb, int64_t sc, + bool trans_a, bool trans_b, int compute_type +); diff --git a/src/hidet/packedfunc.cpp b/include/hidet/runtime/cuda/cuda.h similarity index 74% rename from src/hidet/packedfunc.cpp rename to include/hidet/runtime/cuda/cuda.h index b7daf2804..d386ae533 100644 --- a/src/hidet/packedfunc.cpp +++ b/include/hidet/runtime/cuda/cuda.h @@ -9,14 +9,10 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -#include +#pragma once +#include -extern "C" { - -DLL void CallPackedFunc(PackedFunc func, void** args) { - auto f = PackedFunc_t(func.func_pointer); - f(func.num_args, func.arg_types, args); -} - -} +DLL int hidet_cuda_device_count(); +DLL int hidet_cuda_get_device(); +DLL void hidet_cuda_set_device(int device); diff --git a/include/hidet/packedfunc.h b/include/hidet/runtime/cuda/cudnn.h similarity index 58% rename from include/hidet/packedfunc.h rename to include/hidet/runtime/cuda/cudnn.h index 999bd20ea..5653e0cb5 100644 --- a/include/hidet/packedfunc.h +++ b/include/hidet/runtime/cuda/cudnn.h @@ -9,27 +9,3 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -#pragma once - -#ifndef DLL -#define DLL extern "C" __attribute__((visibility("default"))) -#endif - -enum ArgType { - INT32 = 1, - FLOAT32 = 2, - POINTER = 3, -}; - -typedef void (*PackedFunc_t)(int num_args, int *arg_types, void** args); - -struct PackedFunc { - int num_args; - int* arg_types; - void** func_pointer; -}; - -#define INT_ARG(p) (*(int*)(p)) -#define FLOAT_ARG(p) (*(float*)(p)) - - diff --git a/include/hidet/runtime/logging.h b/include/hidet/runtime/logging.h index a90ea55bc..e3c526d86 100644 --- a/include/hidet/runtime/logging.h +++ b/include/hidet/runtime/logging.h @@ -47,8 +47,9 @@ class FATALMessage { return this->stream_; } - [[noreturn]] ~FATALMessage() noexcept(false) { - throw HidetException(this->stream_.str().c_str()); + [[noreturn]] ~FATALMessage() { + std::cerr << this->stream_.str() << std::endl; + std::abort(); } }; @@ -67,8 +68,8 @@ class ERRORMessage { return this->stream_; } - ~ERRORMessage() { - hidet_set_last_error(this->stream_.str().c_str()); + ~ERRORMessage() noexcept(false) { + throw HidetException(this->stream_.str().c_str()); } }; diff --git a/python/hidet/cuda/__init__.py b/python/hidet/cuda/__init__.py index 8003ea5b3..7e6efbfa5 100644 --- a/python/hidet/cuda/__init__.py +++ b/python/hidet/cuda/__init__.py @@ -16,3 +16,5 @@ from .memory import malloc, free, malloc_async, free_async, malloc_host, free_host, memcpy_peer, memcpy_peer_async from .memory import memcpy, memcpy_async, memset, memset_async, memory_info from .event import Event + +from . import cublas diff --git a/python/hidet/cuda/cublas/__init__.py b/python/hidet/cuda/cublas/__init__.py new file mode 100644 index 000000000..e08d49ee0 --- /dev/null +++ b/python/hidet/cuda/cublas/__init__.py @@ -0,0 +1,13 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .ffi import cublasComputeType, cudaDataType +from .kernels import gemm, strided_gemm diff --git a/python/hidet/cuda/cublas/ffi.py b/python/hidet/cuda/cublas/ffi.py new file mode 100644 index 000000000..96b948632 --- /dev/null +++ b/python/hidet/cuda/cublas/ffi.py @@ -0,0 +1,131 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import sys +import glob +from enum import IntEnum +from ctypes import c_int32, c_int64, c_void_p, c_bool, c_char_p +from hidet.ffi.ffi import get_func +from hidet.utils.py import initialize + + +class cudaDataType(IntEnum): + """ + See Also: https://docs.nvidia.com/cuda/cublas/index.html#cudadatatype-t + """ + + CUDA_R_16F = 2 + CUDA_C_16F = 6 + CUDA_R_16BF = 14 + CUDA_C_16BF = 15 + CUDA_R_32F = 0 + CUDA_C_32F = 4 + CUDA_R_64F = 1 + CUDA_C_64F = 5 + CUDA_R_4I = 16 + CUDA_C_4I = 17 + CUDA_R_4U = 18 + CUDA_C_4U = 19 + CUDA_R_8I = 3 + CUDA_C_8I = 7 + CUDA_R_8U = 8 + CUDA_C_8U = 9 + CUDA_R_16I = 20 + CUDA_C_16I = 21 + CUDA_R_16U = 22 + CUDA_C_16U = 23 + CUDA_R_32I = 10 + CUDA_C_32I = 11 + CUDA_R_32U = 12 + CUDA_C_32U = 13 + CUDA_R_64I = 24 + CUDA_C_64I = 25 + CUDA_R_64U = 26 + CUDA_C_64U = 27 + CUDA_R_8F_E4M3 = 28 # real as a nv_fp8_e4m3 + CUDA_R_8F_E5M2 = 29 # real as a nv_fp8_e5m2 + + +class cublasComputeType(IntEnum): + """ + See Also: https://docs.nvidia.com/cuda/cublas/index.html#cublascomputetype-t + """ + + CUBLAS_COMPUTE_16F = 64 # half - default + CUBLAS_COMPUTE_16F_PEDANTIC = 65 # half - pedantic + CUBLAS_COMPUTE_32F = 68 # float - default + CUBLAS_COMPUTE_32F_PEDANTIC = 69 # float - pedantic + CUBLAS_COMPUTE_32F_FAST_16F = 74 # float - fast allows down-converting inputs to half or TF32 + CUBLAS_COMPUTE_32F_FAST_16BF = 75 # float - fast allows down-converting inputs to bfloat16 or TF32 + CUBLAS_COMPUTE_32F_FAST_TF32 = 77 # float - fast allows down-converting inputs to TF32 + CUBLAS_COMPUTE_64F = 70 # double - default + CUBLAS_COMPUTE_64F_PEDANTIC = 71 # double - pedantic + CUBLAS_COMPUTE_32I = 72 # signed 32-bit int - default + CUBLAS_COMPUTE_32I_PEDANTIC = 73 # signed 32-bit int - pedantic + + +set_library_path = get_func(func_name='hidet_cublas_set_library_path', arg_types=[c_char_p], restype=None) # path + +gemm = get_func( + func_name='hidet_cublas_gemm', + arg_types=[ + c_int32, # m + c_int32, # n + c_int32, # k + c_int32, # type a + c_int32, # type b + c_int32, # type c + c_void_p, # ptr a + c_void_p, # ptr b + c_void_p, # ptr c + c_bool, # trans a + c_bool, # trans b + c_int32, # compute type + ], + restype=None, +) + +strided_gemm = get_func( + func_name='hidet_cublas_strided_gemm', + arg_types=[ + c_int32, # batch size + c_int32, # m + c_int32, # n + c_int32, # k + c_int32, # type a + c_int32, # type b + c_int32, # type c + c_void_p, # ptr a + c_void_p, # ptr b + c_void_p, # ptr c + c_int64, # stride a + c_int64, # stride b + c_int64, # stride c + c_bool, # trans a + c_bool, # trans b + c_int32, # compute type + ], + restype=None, +) + + +@initialize() +def set_cublas_library_path(): + # use nvidia-cuda-cublas + for path in sys.path: + nvidia_path = os.path.join(path, 'nvidia') + if not os.path.exists(nvidia_path): + continue + cublas_path = glob.glob(os.path.join(nvidia_path, 'cublas', 'lib', 'libcublas.so.[0-9]*')) + if cublas_path: + set_library_path(cublas_path[0].encode('utf-8')) + return diff --git a/python/hidet/cuda/cublas/kernels.py b/python/hidet/cuda/cublas/kernels.py new file mode 100644 index 000000000..fb079260d --- /dev/null +++ b/python/hidet/cuda/cublas/kernels.py @@ -0,0 +1,166 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Union +from hidet.ir.dtypes import DataType +from .utils import as_pointer, as_type_code +from .ffi import cublasComputeType, cudaDataType +from . import ffi + + +def gemm( + m: int, + n: int, + k: int, + type_a: Union[int, cudaDataType, DataType], + type_b: Union[int, cudaDataType, DataType], + type_c: Union[int, cudaDataType, DataType], + a, + b, + c, + compute_type: Union[int, cublasComputeType], + trans_a: bool = False, + trans_b: bool = False, +): + """ + Matrix multiplication of two matrices using cublas in row major by default. + + The matrix of A, B, and C are stored in row-major order (if not transposed). + + A: m x k + B: k x n + C: m x n + + + Parameters + ---------- + m: int + Number of rows of matrix op(A) and of matrix C. + n: int + Number of columns of matrix op(B) and of matrix C. + k: int + Number of columns of matrix op(A) and of rows of matrix op(B). + type_a: Union[int, cudaDataType, DataType] + Type of elements in matrix A. + type_b: Union[int, cudaDataType, DataType] + Type of elements in matrix B. + type_c: Union[int, cudaDataType, DataType] + Type of elements in matrix C. + a: Tensor or int + Matrix A, can be either a Tensor or an integer (the address of the matrix). + b: Tensor or int + Matrix B, can be either a Tensor or an integer (the address of the matrix). + c: Tensor or int + Matrix C, can be either a Tensor or an integer (the address of the matrix). + compute_type: Union[int, cublasComputeType] + The compute type of the operation. + trans_a: bool + Whether matrix A is transposed. + trans_b: bool + Whether matrix B is transposed. + """ + ffi.gemm( + m, + n, + k, + as_type_code(type_a), + as_type_code(type_b), + as_type_code(type_c), + as_pointer(a), + as_pointer(b), + as_pointer(c), + trans_a, + trans_b, + compute_type, + ) + + +def strided_gemm( + bs: int, + m: int, + n: int, + k: int, + type_a: Union[int, cudaDataType, DataType], + type_b: Union[int, cudaDataType, DataType], + type_c: Union[int, cudaDataType, DataType], + a, + b, + c, + stride_a: int, + stride_b: int, + stride_c: int, + compute_type: Union[int, cublasComputeType], + trans_a: bool = False, + trans_b: bool = False, +): + """ + Batch matrix multiplication of two matrices using cublas in row major order by default. + + The matrix of A, B, and C are stored in row-major order (if not transposed). + + A: bs x m x k + B: bs x k x n + C: bs x m x n + + + Parameters + ---------- + bs: int + Batch size. + m: int + Number of rows of matrix op(A) and of matrix C. + n: int + Number of columns of matrix op(B) and of matrix C. + k: int + Number of columns of matrix op(A) and of rows of matrix op(B). + type_a: Union[int, DataType] + Type of elements in matrix A. + type_b: Union[int, DataType] + Type of elements in matrix B. + type_c: Union[int, DataType] + Type of elements in matrix C. + a: Tensor or int + Matrix A, can be either a Tensor or an integer (the address of the matrix). + b: Tensor or int + Matrix B, can be either a Tensor or an integer (the address of the matrix). + c: Tensor or int + Matrix C, can be either a Tensor or an integer (the address of the matrix). + stride_a: int + Stride of matrix A on batch dimension. + stride_b: int + Stride of matrix B on batch dimension. + stride_c: int + Stride of matrix C on batch dimension. + trans_a: bool + Whether matrix A is transposed. + trans_b: bool + Whether matrix B is transposed. + compute_type: Union[int, cublasComputeType] + The compute type of the operation. + """ + ffi.strided_gemm( + bs, + m, + n, + k, + as_type_code(type_a), + as_type_code(type_b), + as_type_code(type_c), + as_pointer(a), + as_pointer(b), + as_pointer(c), + stride_a, + stride_b, + stride_c, + trans_a, + trans_b, + compute_type, + ) diff --git a/python/hidet/cuda/cublas/utils.py b/python/hidet/cuda/cublas/utils.py new file mode 100644 index 000000000..584bae2a7 --- /dev/null +++ b/python/hidet/cuda/cublas/utils.py @@ -0,0 +1,42 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from hidet.ir import dtypes +from hidet.ir.dtypes import DataType +from .ffi import cudaDataType + + +def as_pointer(obj) -> int: + from hidet.graph.tensor import Tensor + + if isinstance(obj, Tensor): + return obj.storage.addr + elif isinstance(obj, int): + return obj + else: + raise TypeError(f'Expected Tensor or int, but got {type(obj)}') + + +# see the definition of cudaDataType_t in to get the type code of each type +_type_dict = { + dtypes.float16: cudaDataType.CUDA_R_16F, + dtypes.float32: cudaDataType.CUDA_R_32F, + dtypes.float64: cudaDataType.CUDA_R_64F, +} + + +def as_type_code(obj) -> int: + if isinstance(obj, DataType): + return _type_dict[obj] + elif isinstance(obj, int): + return obj + else: + raise TypeError(f'Expected DataType or int, but got {type(obj)}') diff --git a/python/hidet/ir/expr.py b/python/hidet/ir/expr.py index c9d3c115f..5c2d8f396 100644 --- a/python/hidet/ir/expr.py +++ b/python/hidet/ir/expr.py @@ -549,14 +549,14 @@ def __init__(self, hint: Optional[str], type: BaseType, name: Optional[str] = No """ A variable may have a hint, name, and id. - Hint is used to determine the name in codegen. Different vars may have the + self.hint is used to determine the name in codegen. Different vars may have the same hint. If two vars have the same hint such as 'x', the final name would be like 'x1', 'x2'. - OUTDATED: - Name is the determined name in the final code. Used by primitive variables such as 'threadIdx.x'. No variable - should have a same name as primitive objects (including primitive variables and primitive functions). + self.name is used to store the name of the variables that will be used directly in codegen, such as + "threadIdx.x". The field self.name and self.hint are used exclusively. If self.name is not None, + self.hint will be ignored, otherwise, self.hint will be used to determine the name in codegen. - ID is used to track the allocation of Var object in python, which is only used to help us to distinguish + self.id is used to track the allocation of Var object in python, which is only used to help us to distinguish different Var in python debugger. """ self.hint: Optional[str] = hint diff --git a/python/hidet/ir/library/__init__.py b/python/hidet/ir/library/__init__.py index bf809d906..74f6a454d 100644 --- a/python/hidet/ir/library/__init__.py +++ b/python/hidet/ir/library/__init__.py @@ -10,3 +10,4 @@ # See the License for the specific language governing permissions and # limitations under the License. from . import cuda +from .cuda import cublas diff --git a/python/hidet/ir/library/cuda/__init__.py b/python/hidet/ir/library/cuda/__init__.py index 6ce7d5b0f..7e9830590 100644 --- a/python/hidet/ir/library/cuda/__init__.py +++ b/python/hidet/ir/library/cuda/__init__.py @@ -10,3 +10,4 @@ # See the License for the specific language governing permissions and # limitations under the License. from .matmul import matmul_simt +from . import cublas diff --git a/python/hidet/ir/library/cuda/cublas/__init__.py b/python/hidet/ir/library/cuda/cublas/__init__.py new file mode 100644 index 000000000..f7d9010e6 --- /dev/null +++ b/python/hidet/ir/library/cuda/cublas/__init__.py @@ -0,0 +1,15 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from hidet.cuda.cublas.utils import as_type_code +from hidet.cuda.cublas.kernels import cublasComputeType, cudaDataType +from .kernels import gemm, bgemm +from . import regs as _regs # register functions diff --git a/python/hidet/ir/library/cuda/cublas/kernels.py b/python/hidet/ir/library/cuda/cublas/kernels.py new file mode 100644 index 000000000..127704daf --- /dev/null +++ b/python/hidet/ir/library/cuda/cublas/kernels.py @@ -0,0 +1,53 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Union +from hidet.ir.expr import Expr +from hidet.ir.primitives.func import call_primitive_func + + +def gemm( + m: Union[Expr, int], + n: Union[Expr, int], + k: Union[Expr, int], + type_a: Union[Expr, int], + type_b: Union[Expr, int], + type_c: Union[Expr, int], + a: Expr, + b: Expr, + c: Expr, + trans_a: Union[Expr, bool], + trans_b: Union[Expr, bool], + compute_type: Union[Expr, int], +): + return call_primitive_func( + func_name='cublas.gemm', args=[m, n, k, type_a, type_b, type_c, a, b, c, trans_a, trans_b, compute_type] + ) + + +def bgemm( + bs: Union[Expr, int], + m: Union[Expr, int], + n: Union[Expr, int], + k: Union[Expr, int], + type_a: Union[Expr, int], + type_b: Union[Expr, int], + type_c: Union[Expr, int], + a: Expr, + b: Expr, + c: Expr, + trans_a: Union[Expr, bool], + trans_b: Union[Expr, bool], + compute_type: Union[Expr, int], +): + return call_primitive_func( + func_name='cublas.bgemm', args=[bs, m, n, k, type_a, type_b, type_c, a, b, c, trans_a, trans_b, compute_type] + ) diff --git a/python/hidet/ir/library/cuda/cublas/regs.py b/python/hidet/ir/library/cuda/cublas/regs.py new file mode 100644 index 000000000..05d85edb5 --- /dev/null +++ b/python/hidet/ir/library/cuda/cublas/regs.py @@ -0,0 +1,62 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from hidet.ir.dtypes import int32, boolean +from hidet.ir.type import FuncType, void_p, void +from hidet.ir.primitives.func import register_primitive_function +from hidet.utils import initialize + + +@initialize() +def register_cublas_kernels(): + register_primitive_function( + name='cublas.gemm', + func_or_type=FuncType( + param_types=[ + int32, # m + int32, # n + int32, # k + int32, # type_a (cudaDataType) + int32, # type_b (cudaDataType) + int32, # type_c (cudaDataType) + void_p, # a + void_p, # b + void_p, # c + boolean, # trans_a + boolean, # trans_b + int32, # compute_type (cublasComputeType) + ], + ret_type=void, + ), + codegen_name='hidet_cublas_gemm', + ) + register_primitive_function( + name='cublas.bgemm', + func_or_type=FuncType( + param_types=[ + int32, # bs + int32, # m + int32, # n + int32, # k + int32, # type_a (cudaDataType) + int32, # type_b (cudaDataType) + int32, # type_c (cudaDataType) + void_p, # a + void_p, # b + void_p, # c + boolean, # trans_a + boolean, # trans_b + int32, # compute_type (cublasComputeType) + ], + ret_type=void, + ), + codegen_name='hidet_cublas_bgemm', + ) diff --git a/python/hidet/lang/cuda.py b/python/hidet/lang/cuda.py index 7d59da24c..5f892e615 100644 --- a/python/hidet/lang/cuda.py +++ b/python/hidet/lang/cuda.py @@ -27,3 +27,5 @@ from hidet.ir.primitives.cuda.shfl import shfl_sync, shfl_up_sync, shfl_xor_sync, shfl_down_sync from hidet.ir.primitives.cuda.mutex import acquire_lock, release_lock, acquire_seq_semaphore, release_seq_semaphore from hidet.lang.constructs.declare import register_tensor, shared_tensor + +from hidet.ir.library.cuda import cublas diff --git a/python/hidet/transforms/annotate_header_and_libs.py b/python/hidet/transforms/annotate_header_and_libs.py index 53a974a8e..cb012aa8b 100644 --- a/python/hidet/transforms/annotate_header_and_libs.py +++ b/python/hidet/transforms/annotate_header_and_libs.py @@ -9,34 +9,89 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import hidet.ir +from typing import List +from hidet.ir.tools import collect from hidet.ir.module import IRModule +from hidet.ir.expr import Call from hidet.ir.stmt import BlackBoxStmt from hidet.transforms import Pass -def _use_distributed(func) -> bool: - black_stmts = hidet.ir.tools.collect(func.body, [BlackBoxStmt]) - return any(stmt.template_string.startswith('nccl') for stmt in black_stmts) +class Annotator: + def __init__(self): + self.include_dirs: List[str] = [] + self.linking_dirs: List[str] = [] + self.include_headers: List[str] = [] + self.linking_libs: List[str] = [] + def predicate(self, ir_module: IRModule) -> bool: + raise NotImplementedError() -class AnnotateHeaderAndLibsPass(Pass): - def process_module(self, ir_module: IRModule) -> IRModule: - use_dist = any(_use_distributed(func) for func in ir_module.functions.values()) - if not use_dist: - return ir_module + def apply(self): + raise NotImplementedError() + + +class AnnotateNCCL(Annotator): + """ + Annotate the header and libraries for NCCL. + Headers: nccl.h + Libraries: libnccl.so + """ + + def predicate(self, ir_module: IRModule) -> bool: + for func in ir_module.functions.values(): + black_stmts = collect(func.body, [BlackBoxStmt]) + if any(stmt.template_string.startswith('nccl') for stmt in black_stmts): + return True + return False + + def apply(self): from hidet.cuda.nccl.libinfo import get_nccl_include_dirs, get_nccl_library_search_dirs from hidet.cuda.nccl import nccl_available, nccl_library_filename if not nccl_available(): raise RuntimeError("NCCL is not available") + self.include_dirs.extend(get_nccl_include_dirs()) + self.linking_dirs.extend(get_nccl_library_search_dirs()) + self.include_headers.append("nccl.h") + self.linking_libs.append(":" + nccl_library_filename()) + + +class AnnotateCUBLAS(Annotator): + def predicate(self, ir_module: IRModule) -> bool: + for func in ir_module.functions.values(): + calls: List[Call] = collect(func.body, [Call]) + if any(call.func_var.name and call.func_var.name.startswith('cublas.') for call in calls): + return True + return False + + def apply(self): + self.include_headers.append("hidet/runtime/cuda/cublas.h") + + +class AnnotateHeaderAndLibsPass(Pass): + def process_module(self, ir_module: IRModule) -> IRModule: + annotators = [AnnotateNCCL(), AnnotateCUBLAS()] + include_dirs: List[str] = [] + linking_dirs: List[str] = [] + include_headers: List[str] = [] + linking_libs: List[str] = [] + + for annotator in annotators: + if annotator.predicate(ir_module): + annotator.apply() + include_dirs.extend(annotator.include_dirs) + linking_dirs.extend(annotator.linking_dirs) + include_headers.extend(annotator.include_headers) + linking_libs.extend(annotator.linking_libs) + new_module = ir_module.copy() - new_module.include_dirs.extend(get_nccl_include_dirs()) - new_module.linking_dirs.extend(get_nccl_library_search_dirs()) - new_module.include_headers.append(["nccl.h"]) - new_module.linking_libs.append(":" + nccl_library_filename()) + new_module.include_dirs.extend(include_dirs) + new_module.linking_dirs.extend(linking_dirs) + new_module.include_headers.extend(include_headers) + new_module.linking_libs.extend(linking_libs) return new_module diff --git a/src/hidet/empty.cpp b/src/hidet/empty.cpp new file mode 100644 index 000000000..5653e0cb5 --- /dev/null +++ b/src/hidet/empty.cpp @@ -0,0 +1,11 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. diff --git a/src/hidet/runtime/callbacks.cpp b/src/hidet/runtime/callbacks.cpp index 79de6b712..6038916a9 100644 --- a/src/hidet/runtime/callbacks.cpp +++ b/src/hidet/runtime/callbacks.cpp @@ -45,7 +45,7 @@ static FuncType* get_callback_ptr() { auto pool = CallbackRegistryPool::global(); assert(id < pool->id2name.size()); if(id >= pool->id2ptr.size() || pool->id2ptr[id] == nullptr) { - LOG(FATAL) << "Callback function " << pool->id2name[id] << " has not been registered."; + LOG(ERROR) << "Callback function " << pool->id2name[id] << " has not been registered."; } void* ptr = pool->id2ptr[id]; typedef FuncType* FuncPointerType; @@ -56,7 +56,7 @@ DLL void register_callback(const char* name, void *func_ptr) { try { auto pool = CallbackRegistryPool::global(); if (pool->name2id.count(name) == 0) { - LOG(FATAL) << "Function " << std::string(name) << " is not a callback function."; + LOG(ERROR) << "Function " << std::string(name) << " is not a callback function."; } int id = pool->name2id[name]; if(id >= pool->id2ptr.size()) { diff --git a/src/hidet/runtime/cpu_context.cpp b/src/hidet/runtime/cpu/context.cpp similarity index 96% rename from src/hidet/runtime/cpu_context.cpp rename to src/hidet/runtime/cpu/context.cpp index 7c0b2e8a9..7e173f983 100644 --- a/src/hidet/runtime/cpu_context.cpp +++ b/src/hidet/runtime/cpu/context.cpp @@ -26,7 +26,7 @@ static void reserve_cpu_workspace(Workspace &workspace, size_t nbytes) { } workspace.base = reinterpret_cast(allocate_cpu_storage(nbytes)); if(workspace.base == nullptr) { - LOG(FATAL) << "allocate workspace failed."; + LOG(ERROR) << "allocate workspace failed."; } memset(workspace.base, 0, nbytes); } diff --git a/src/hidet/runtime/cuda_context.cpp b/src/hidet/runtime/cuda/context.cpp similarity index 94% rename from src/hidet/runtime/cuda_context.cpp rename to src/hidet/runtime/cuda/context.cpp index 0a6273acf..80ba9010e 100644 --- a/src/hidet/runtime/cuda_context.cpp +++ b/src/hidet/runtime/cuda/context.cpp @@ -25,7 +25,7 @@ static void reserve_cuda_workspace(Workspace &workspace, size_t nbytes) { } workspace.base = reinterpret_cast(allocate_cuda_storage(nbytes)); if(workspace.base == nullptr) { - LOG(FATAL) << "allocate workspace failed."; + LOG(ERROR) << "allocate workspace failed."; } cuda_memset(reinterpret_cast(workspace.base), 0, nbytes); @@ -64,7 +64,7 @@ DLL void set_nccl_comms(int num_comms, void** comms) { DLL void* get_nccl_comm(int idx) { const int num_comms = CudaContext::global()->num_comms; if (idx >= num_comms) { - LOG(FATAL) << "Index of NCCL Communicator out of boundary. (" << idx << " vs " << num_comms << ")"; + LOG(ERROR) << "Index of NCCL Communicator out of boundary. (" << idx << " vs " << num_comms << ")"; } return CudaContext::global()->nccl_comms[idx]; } \ No newline at end of file diff --git a/src/hidet/runtime/cuda/cublas.cpp b/src/hidet/runtime/cuda/cublas.cpp new file mode 100644 index 000000000..e42bec40c --- /dev/null +++ b/src/hidet/runtime/cuda/cublas.cpp @@ -0,0 +1,322 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include +#include +#include +#include +#include +//#include +//#include +#include "./utils.h" + +// types defined in , +// copyright (c) +// NVIDIA Corporation. All rights reserved. +typedef enum { + CUBLAS_STATUS_SUCCESS = 0, + CUBLAS_STATUS_NOT_INITIALIZED = 1, + CUBLAS_STATUS_ALLOC_FAILED = 3, + CUBLAS_STATUS_INVALID_VALUE = 7, + CUBLAS_STATUS_ARCH_MISMATCH = 8, + CUBLAS_STATUS_MAPPING_ERROR = 11, + CUBLAS_STATUS_EXECUTION_FAILED = 13, + CUBLAS_STATUS_INTERNAL_ERROR = 14, + CUBLAS_STATUS_NOT_SUPPORTED = 15, + CUBLAS_STATUS_LICENSE_ERROR = 16 +} cublasStatus_t; + +typedef enum { + CUBLAS_OP_N = 0, + CUBLAS_OP_T = 1, + CUBLAS_OP_C = 2, +} cublasOperation_t; + +typedef enum cudaDataType_t +{ + CUDA_R_16F = 2, + CUDA_C_16F = 6, + CUDA_R_16BF = 14, + CUDA_C_16BF = 15, + CUDA_R_32F = 0, + CUDA_C_32F = 4, + CUDA_R_64F = 1, + CUDA_C_64F = 5, + CUDA_R_4I = 16, + CUDA_C_4I = 17, + CUDA_R_4U = 18, + CUDA_C_4U = 19, + CUDA_R_8I = 3, + CUDA_C_8I = 7, + CUDA_R_8U = 8, + CUDA_C_8U = 9, + CUDA_R_16I = 20, + CUDA_C_16I = 21, + CUDA_R_16U = 22, + CUDA_C_16U = 23, + CUDA_R_32I = 10, + CUDA_C_32I = 11, + CUDA_R_32U = 12, + CUDA_C_32U = 13, + CUDA_R_64I = 24, + CUDA_C_64I = 25, + CUDA_R_64U = 26, + CUDA_C_64U = 27, + CUDA_R_8F_E4M3 = 28, /* real as a nv_fp8_e4m3 */ + CUDA_R_8F_E5M2 = 29, /* real as a nv_fp8_e5m2 */ +} cudaDataType; + +/* Enum for compute type + * + * - default types provide best available performance using all available hardware features + * and guarantee internal storage precision with at least the same precision and range; + * - _PEDANTIC types ensure standard arithmetic and exact specified internal storage format; + * - _FAST types allow for some loss of precision to enable higher throughput arithmetic. + */ +typedef enum { + CUBLAS_COMPUTE_16F = 64, /* half - default */ + CUBLAS_COMPUTE_16F_PEDANTIC = 65, /* half - pedantic */ + CUBLAS_COMPUTE_32F = 68, /* float - default */ + CUBLAS_COMPUTE_32F_PEDANTIC = 69, /* float - pedantic */ + CUBLAS_COMPUTE_32F_FAST_16F = 74, /* float - fast, allows down-converting inputs to half or TF32 */ + CUBLAS_COMPUTE_32F_FAST_16BF = 75, /* float - fast, allows down-converting inputs to bfloat16 or TF32 */ + CUBLAS_COMPUTE_32F_FAST_TF32 = 77, /* float - fast, allows down-converting inputs to TF32 */ + CUBLAS_COMPUTE_64F = 70, /* double - default */ + CUBLAS_COMPUTE_64F_PEDANTIC = 71, /* double - pedantic */ + CUBLAS_COMPUTE_32I = 72, /* signed 32-bit int - default */ + CUBLAS_COMPUTE_32I_PEDANTIC = 73, /* signed 32-bit int - pedantic */ +} cublasComputeType_t; + +typedef enum { + CUBLAS_GEMM_DEFAULT = -1, +} cublasGemmAlgo_t; + + +// define cublas api functions +typedef const char* (*cublasGetStatusName_t)(cublasStatus_t status); +typedef const char* (*cublasGetStatusString_t)(cublasStatus_t status); +typedef cublasStatus_t (*cublasCreate_t)(cublasHandle_t *handle); +typedef cublasStatus_t (*cublasGemmEx_t)( + cublasHandle_t handle, + cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, + const void *alpha, + const void *A, cudaDataType_t Atype, int lda, + const void *B, cudaDataType_t Btype, int ldb, + const void *beta, + void *C, cudaDataType_t Ctype, int ldc, + cublasComputeType_t computeType, + cublasGemmAlgo_t algo +); +typedef cublasStatus_t (*cublasGemmStridedBatchedEx_t)( + cublasHandle_t handle, + cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, + const void *alpha, + const void *A, cudaDataType_t Atype, int lda, long long int strideA, + const void *B, cudaDataType_t Btype, int ldb, long long int strideB, + const void *beta, + void *C, cudaDataType_t Ctype, int ldc, long long int strideC, + int batchCount, + cublasComputeType_t computeType, + cublasGemmAlgo_t algo +); + + +// cublas api functions +static cublasCreate_t cublasCreate; +static cublasGetStatusName_t cublasGetStatusName; +static cublasGetStatusString_t cublasGetStatusString; +static cublasGemmEx_t cublasGemmEx; +static cublasGemmStridedBatchedEx_t cublasGemmStridedBatchedEx; + +static std::string library_path; +static void* libcublas = nullptr; + +// utility functions +#define CHECK_CUBLAS(status) do { \ + cublasStatus_t err = (status); \ + if(err != 0) { \ + LOG(FATAL) << "cuBLAS error: " << cublasGetStatusString(err) << " (" << cublasGetStatusName(err) << ")"; \ + } \ +} while(0) + + +static void set_alpha_beta(const void** p_alpha, const void** p_beta, cublasComputeType_t c, cudaDataType_t tc) { + if (tc == CUDA_C_32F || tc == CUDA_C_64F) { + LOG(FATAL) << "NotImplementedError: complex numbers are not supported yet" << std::endl; + } + + if(c == CUBLAS_COMPUTE_16F || c == CUBLAS_COMPUTE_16F_PEDANTIC) { + static const int16_t alpha = 0x3c00; // half(1.0) + static const int16_t beta = 0x0000; // half(0.0) + *p_alpha = α + *p_beta = β + } else if( + c == CUBLAS_COMPUTE_32F || c == CUBLAS_COMPUTE_32F_PEDANTIC || c == CUBLAS_COMPUTE_32F_FAST_16F || + c == CUBLAS_COMPUTE_32F_FAST_16BF || c == CUBLAS_COMPUTE_32F_FAST_TF32 + ) { + static const float alpha = 1.0f; + static const float beta = 0.0f; + *p_alpha = α + *p_beta = β + } else if(c == CUBLAS_COMPUTE_32I || c == CUBLAS_COMPUTE_32I_PEDANTIC) { + static const int32_t alpha = 1; + static const int32_t beta = 0; + *p_alpha = α + *p_beta = β + } else if(c == CUBLAS_COMPUTE_64F || c == CUBLAS_COMPUTE_64F_PEDANTIC) { + static const double alpha = 1.0; + static const double beta = 0.0; + *p_alpha = α + *p_beta = β + } else { + LOG(FATAL) << "Unsupported compute type: " << c; + } +} + +static void lazy_load_cublas() { + if(libcublas == nullptr) { + // load cublas shared library + const char* libpath; + if(library_path.empty()) { + libpath = "libcublas.so"; + } else { + libpath = library_path.c_str(); + } + libcublas = dlopen(libpath, RTLD_LAZY); + if(libcublas == nullptr) { + LOG(FATAL) << "Failed to load cublas library: " << libpath << dlerror(); + } + + // load api functions + cublasCreate = get_symbol(libcublas, "cublasCreate_v2"); + cublasGetStatusName = get_symbol(libcublas, "cublasGetStatusName"); + cublasGetStatusString = get_symbol(libcublas, "cublasGetStatusString"); + cublasGemmEx = get_symbol(libcublas, "cublasGemmEx"); + cublasGemmStridedBatchedEx = get_symbol( + libcublas, "cublasGemmStridedBatchedEx" + ); + } +} + + +CublasContext* CublasContext::global() { + static CublasContext instance; + static bool initialized = false; + + if(!initialized) { + // create cublas handle for each gpu + int count = hidet_cuda_device_count(); + assert(count <= HIDET_CUBLAS_MAX_GPUS); + + int current_device = hidet_cuda_get_device(); + for(int i = 0; i < count; i++) { + hidet_cuda_set_device(i); + CHECK_CUBLAS(cublasCreate(&instance.handles[i])); + } + hidet_cuda_set_device(current_device); + + initialized = true; + } + return &instance; +} + +cublasHandle_t CublasContext::current_handle() { + return CublasContext::global()->handles[hidet_cuda_get_device()]; +} + + +// hidet cublas api functions +DLL void hidet_cublas_set_library_path(const char* path) { + if(path) { + library_path = path; + } +} + +DLL void hidet_cublas_gemm( + int m, int n, int k, int ta, int tb, int tc, void *ptr_a, void* ptr_b, void* ptr_c, bool trans_a, bool trans_b, + int compute_type +) { + lazy_load_cublas(); + + const void *p_alpha = nullptr; + const void *p_beta = nullptr; + + set_alpha_beta(&p_alpha, &p_beta, cublasComputeType_t(compute_type), cudaDataType_t(tc)); + + // we apply c^T = b^T @ a^T (c = a @ b) here + CHECK_CUBLAS(cublasGemmEx( + CublasContext::current_handle(), + trans_a ? cublasOperation_t::CUBLAS_OP_T : cublasOperation_t::CUBLAS_OP_N, + trans_b ? cublasOperation_t::CUBLAS_OP_T : cublasOperation_t::CUBLAS_OP_N, + n, + m, + k, + p_alpha, + ptr_b, + cudaDataType(tb), + n, // ldb + ptr_a, + cudaDataType(ta), + k, // lda + p_beta, + ptr_c, + cudaDataType(tc), + n, // ldc + cublasComputeType_t(compute_type), + cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT + )); +} + +DLL void hidet_cublas_strided_gemm( + int b, int m, int n, int k, int ta, int tb, int tc, void *ptr_a, void* ptr_b, void* ptr_c, + int64_t sa, int64_t sb, int64_t sc, + bool trans_a, bool trans_b, int compute_type +) { + lazy_load_cublas(); + + const void *p_alpha = nullptr; + const void *p_beta = nullptr; + + set_alpha_beta(&p_alpha, &p_beta, cublasComputeType_t(compute_type), cudaDataType_t(tc)); + + CHECK_CUBLAS(cublasGemmStridedBatchedEx( + CublasContext::current_handle(), + trans_a ? cublasOperation_t::CUBLAS_OP_T : cublasOperation_t::CUBLAS_OP_N, + trans_b ? cublasOperation_t::CUBLAS_OP_T : cublasOperation_t::CUBLAS_OP_N, + n, + m, + k, + p_alpha, + // b^t + ptr_b, + cudaDataType(tb), + n, // ldb + sb, // strideB + // a^t + ptr_a, + cudaDataType(ta), + k, // lda + sa, // strideA + p_beta, + // c^t + ptr_c, + cudaDataType(tc), + n, // ldc + sc, // strideC + b, // batchCount + cublasComputeType_t(compute_type), + cublasGemmAlgo_t::CUBLAS_GEMM_DEFAULT + )); +} + + diff --git a/src/hidet/runtime/cuda/cuda.cpp b/src/hidet/runtime/cuda/cuda.cpp new file mode 100644 index 000000000..47e4c6576 --- /dev/null +++ b/src/hidet/runtime/cuda/cuda.cpp @@ -0,0 +1,81 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include +#include +#include "./utils.h" + +// CUDA runtime APIs +typedef int cudaError_t; +typedef cudaError_t (*cudaGetDeviceCount_t)(int* count); +typedef cudaError_t (*cudaGetDevice_t)(int* device); +typedef cudaError_t (*cudaSetDevice_t)(int device); +typedef const char* (*cudaGetErrorString_t)(cudaError_t error); + +static std::string library_path; +static void* libcudart = nullptr; +static cudaGetDeviceCount_t cudaGetDeviceCount = nullptr; +static cudaGetDevice_t cudaGetDevice = nullptr; +static cudaSetDevice_t cudaSetDevice = nullptr; +static cudaGetErrorString_t cudaGetErrorString = nullptr; + +// load cuda runtime APIs +static inline void lazy_load_cuda_runtime() { + if(libcudart == nullptr) { + const char* libpath; + if(library_path.empty()) { + libpath = "libcudart.so"; + } else { + libpath = library_path.c_str(); + } + libcudart = dlopen(libpath, RTLD_LAZY); + + if(libcudart == nullptr) { + LOG(FATAL) << "Failed to load libcudart.so: " << dlerror(); + } + + cudaGetDeviceCount = get_symbol(libcudart, "cudaGetDeviceCount"); + cudaGetDevice = get_symbol(libcudart, "cudaGetDevice"); + cudaSetDevice = get_symbol(libcudart, "cudaSetDevice"); + cudaGetErrorString = get_symbol(libcudart, "cudaGetErrorString"); + } +} + +#define CHECK_CUDA(status) do{ \ + cudaError_t err = (status); \ + if (err != 0) { \ + LOG(FATAL) << "CUDA error: " << cudaGetErrorString(err); \ + } \ +} while(0) + +// Hidet exported APIs +DLL void hidet_cuda_set_library_path(const char* path) { + library_path = path; +} + +DLL int hidet_cuda_device_count() { + lazy_load_cuda_runtime(); + int count = 0; + CHECK_CUDA(cudaGetDeviceCount(&count)); + return count; +} + +DLL int hidet_cuda_get_device() { + lazy_load_cuda_runtime(); + int current_device = -1; + CHECK_CUDA(cudaGetDevice(¤t_device)); + return current_device; +} + +DLL void hidet_cuda_set_device(int device) { + lazy_load_cuda_runtime(); + CHECK_CUDA(cudaSetDevice(device)); +} diff --git a/src/hidet/runtime/cuda/utils.h b/src/hidet/runtime/cuda/utils.h new file mode 100644 index 000000000..89879b541 --- /dev/null +++ b/src/hidet/runtime/cuda/utils.h @@ -0,0 +1,12 @@ +#include +#include + +template +inline T get_symbol(void* lib, const char* name) { + T ret = (T)dlsym(lib, name); + if(ret == nullptr) { + LOG(FATAL) << "Failed to load symbol: " << std::endl << " " << dlerror(); + } + return ret; +} + diff --git a/src/hidet/runtime/symbols.cpp b/src/hidet/runtime/symbols.cpp index af3ba8556..ba5d8737d 100644 --- a/src/hidet/runtime/symbols.cpp +++ b/src/hidet/runtime/symbols.cpp @@ -23,7 +23,7 @@ DLL int32_t get_symbol_value(const char* symbol_name) { try { auto it = symbol_mapping.find(symbol_name); if (it == symbol_mapping.end()) { - LOG(FATAL) << "Symbol " << symbol_name << " not found"; + LOG(ERROR) << "Symbol " << symbol_name << " not found"; } return it->second; } catch (HidetException &e) { diff --git a/tests/cuda/test_cublas.py b/tests/cuda/test_cublas.py new file mode 100644 index 000000000..5c04e4b20 --- /dev/null +++ b/tests/cuda/test_cublas.py @@ -0,0 +1,93 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +import math +import hidet +from hidet.cuda.cublas import cublasComputeType + + +@pytest.mark.parametrize('m, n, k', [[4, 4, 4], [128, 128, 128], [123, 234, 345]]) +@pytest.mark.parametrize( + 'dtype, compute_type, tol', + [ + (hidet.float16, cublasComputeType.CUBLAS_COMPUTE_16F, 1e-2), + (hidet.float32, cublasComputeType.CUBLAS_COMPUTE_32F, 1e-5), + (hidet.float64, cublasComputeType.CUBLAS_COMPUTE_64F, 1e-8), + ], +) +def test_cublas_gemm(m, n, k, dtype, compute_type, tol): + a = hidet.randn((m, k), device='cuda', dtype=dtype) / math.sqrt(k) + b = hidet.randn((k, n), device='cuda', dtype=dtype) / math.sqrt(k) + c = hidet.empty((m, n), device='cuda', dtype=dtype) + hidet.cuda.cublas.gemm(m, n, k, a.dtype, b.dtype, c.dtype, a, b, c, compute_type) + hidet.utils.assert_close(actual=c, expected=a @ b, rtol=tol, atol=tol) + + +@pytest.mark.parametrize('bs, m, n, k', [[3, 4, 4, 4], [4, 128, 128, 128], [5, 123, 234, 345]]) +@pytest.mark.parametrize( + 'dtype, compute_type, tol', + [ + (hidet.float16, cublasComputeType.CUBLAS_COMPUTE_16F, 1e-2), + (hidet.float32, cublasComputeType.CUBLAS_COMPUTE_32F, 1e-5), + (hidet.float64, cublasComputeType.CUBLAS_COMPUTE_64F, 1e-8), + ], +) +def test_cublas_strided_gemm(bs, m, n, k, dtype, compute_type, tol): + a = hidet.randn((bs, m, k), device='cuda', dtype=dtype) / math.sqrt(k) + b = hidet.randn((bs, k, n), device='cuda', dtype=dtype) / math.sqrt(k) + c = hidet.empty((bs, m, n), device='cuda', dtype=dtype) + hidet.cuda.cublas.strided_gemm(bs, m, n, k, a.dtype, b.dtype, c.dtype, a, b, c, m * k, k * n, m * n, compute_type) + hidet.utils.assert_close(actual=c, expected=a @ b, rtol=tol, atol=tol) + + +def test_cublas_library_func(): + from hidet.lang import attrs + from hidet.lang.cuda import cublas + from hidet.lang.types import f32, i32 + + with hidet.script_module() as script_module: + + @hidet.script + def launch(m_size: i32, n_size: i32, k_size: i32, a: ~f32, b: ~f32, c: ~f32): + attrs.func_kind = 'public' + + cublas.gemm( + m_size, + n_size, + k_size, + cublas.as_type_code(f32), + cublas.as_type_code(f32), + cublas.as_type_code(f32), + a, + b, + c, + False, + False, + cublas.cublasComputeType.CUBLAS_COMPUTE_32F, + ) + + func = script_module.build() + + m = 234 + n = 345 + k = 456 + + a = hidet.randn((m, k), device='cuda', dtype=hidet.float32) / math.sqrt(k) + b = hidet.randn((k, n), device='cuda', dtype=hidet.float32) / math.sqrt(k) + c = hidet.empty((m, n), device='cuda', dtype=hidet.float32) + + func(m, n, k, a, b, c) + hidet.utils.assert_close(actual=c, expected=a @ b, rtol=1e-5, atol=1e-5) + + +if __name__ == '__main__': + pytest.main([__file__])