Skip to content

Commit

Permalink
Impl native command for cublas
Browse files Browse the repository at this point in the history
See SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND extension document for
details.

Generalize helpers funcs and use them for blas l1, l2, l3, batch

Signed-off-by: JackAKirk <jack.kirk@codeplay.com>
  • Loading branch information
JackAKirk committed Sep 19, 2024
1 parent e19072c commit 44867dc
Show file tree
Hide file tree
Showing 6 changed files with 212 additions and 115 deletions.
96 changes: 74 additions & 22 deletions src/blas/backends/cublas/cublas_batch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,12 +167,25 @@ inline void gemm_batch_impl(sycl::queue &queue, transpose transa, transpose tran
auto b_ = sc.get_mem<cuTypeB *>(b_acc);
auto c_ = sc.get_mem<cuTypeC *>(c_acc);
cublasStatus_t err;
CUBLAS_ERROR_FUNC_T_SYNC(
"cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx, err, handle,
get_cublas_operation(transa), get_cublas_operation(transb), m, n, k, &alpha, a_,
get_cublas_datatype<cuTypeA>(), lda, stride_a, b_, get_cublas_datatype<cuTypeB>(),
ldb, stride_b, &beta, c_, get_cublas_datatype<cuTypeC>(), ldc, stride_c, batch_size,
get_cublas_datatype<cuTypeS>(), cublas_gemm_algo);
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
CUBLAS_ERROR_FUNC_T("cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx,
err, handle, get_cublas_operation(transa),
get_cublas_operation(transb), m, n, k, &alpha, a_,
get_cublas_datatype<cuTypeA>(), lda, stride_a, b_,
get_cublas_datatype<cuTypeB>(), ldb, stride_b, &beta, c_,
get_cublas_datatype<cuTypeC>(), ldc, stride_c, batch_size,
get_cublas_datatype<cuTypeS>(), cublas_gemm_algo);
#else
CUBLAS_ERROR_FUNC_T_SYNC("cublasGemmStridedBatchedEx",
cublasGemmStridedBatchedEx, err, handle,
get_cublas_operation(transa),
get_cublas_operation(transb), m, n, k, &alpha, a_,
get_cublas_datatype<cuTypeA>(), lda, stride_a, b_,
get_cublas_datatype<cuTypeB>(), ldb, stride_b, &beta,
c_, get_cublas_datatype<cuTypeC>(), ldc, stride_c,
batch_size, get_cublas_datatype<cuTypeS>(),
cublas_gemm_algo);
#endif
});
});
}
Expand Down Expand Up @@ -608,12 +621,25 @@ inline sycl::event gemm_batch_strided_usm_impl(sycl::queue &queue, transpose tra
onemkl_cublas_host_task(cgh, queue, [=](CublasScopedContextHandler &sc) {
auto handle = sc.get_handle(queue);
cublasStatus_t err;
CUBLAS_ERROR_FUNC_T_SYNC(
"cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx, err, handle,
get_cublas_operation(transa), get_cublas_operation(transb), m, n, k, &alpha, a,
get_cublas_datatype<cuTypeA>(), lda, stride_a, b, get_cublas_datatype<cuTypeB>(),
ldb, stride_b, &beta, c, get_cublas_datatype<cuTypeC>(), ldc, stride_c, batch_size,
get_cublas_datatype<cuTypeS>(), cublas_gemm_algo);
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
CUBLAS_ERROR_FUNC_T("cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx,
err, handle, get_cublas_operation(transa),
get_cublas_operation(transb), m, n, k, &alpha, a,
get_cublas_datatype<cuTypeA>(), lda, stride_a, b,
get_cublas_datatype<cuTypeB>(), ldb, stride_b, &beta, c,
get_cublas_datatype<cuTypeC>(), ldc, stride_c, batch_size,
get_cublas_datatype<cuTypeS>(), cublas_gemm_algo);
#else
CUBLAS_ERROR_FUNC_T_SYNC("cublasGemmStridedBatchedEx",
cublasGemmStridedBatchedEx, err, handle,
get_cublas_operation(transa),
get_cublas_operation(transb), m, n, k, &alpha, a,
get_cublas_datatype<cuTypeA>(), lda, stride_a, b,
get_cublas_datatype<cuTypeB>(), ldb, stride_b, &beta,
c, get_cublas_datatype<cuTypeC>(), ldc, stride_c,
batch_size, get_cublas_datatype<cuTypeS>(),
cublas_gemm_algo);
#endif
});
});
return done;
Expand Down Expand Up @@ -687,14 +713,28 @@ inline sycl::event gemm_batch_usm_impl(sycl::queue &queue, transpose *transa, tr
int64_t offset = 0;
cublasStatus_t err;
for (int64_t i = 0; i < group_count; i++) {
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
CUBLAS_ERROR_FUNC_T("cublasGemmBatchedEx", cublasGemmBatchedEx, err, handle,
get_cublas_operation(transa[i]),
get_cublas_operation(transb[i]), (int)m[i], (int)n[i],
(int)k[i], &alpha[i], (const void *const *)(a + offset),
get_cublas_datatype<cuTypeA>(), (int)lda[i],
(const void *const *)(b + offset),
get_cublas_datatype<cuTypeB>(), (int)ldb[i], &beta[i],
(void *const *)(c + offset), get_cublas_datatype<cuTypeC>(),
(int)ldc[i], (int)group_size[i],
get_cublas_datatype<cuTypeS>(), cublas_gemm_algo);
#else
CUBLAS_ERROR_FUNC_T_SYNC(
"cublasGemmBatchedEx", cublasGemmBatchedEx, err, handle,
get_cublas_operation(transa[i]), get_cublas_operation(transb[i]), (int)m[i],
(int)n[i], (int)k[i], &alpha[i], (const void *const *)(a + offset),
get_cublas_datatype<cuTypeA>(), (int)lda[i], (const void *const *)(b + offset),
get_cublas_datatype<cuTypeB>(), (int)ldb[i], &beta[i],
(void *const *)(c + offset), get_cublas_datatype<cuTypeC>(), (int)ldc[i],
(int)group_size[i], get_cublas_datatype<cuTypeS>(), cublas_gemm_algo);
get_cublas_datatype<cuTypeA>(), (int)lda[i],
(const void *const *)(b + offset), get_cublas_datatype<cuTypeB>(),
(int)ldb[i], &beta[i], (void *const *)(c + offset),
get_cublas_datatype<cuTypeC>(), (int)ldc[i], (int)group_size[i],
get_cublas_datatype<cuTypeS>(), cublas_gemm_algo);
#endif
offset += group_size[i];
}
});
Expand Down Expand Up @@ -792,12 +832,24 @@ inline sycl::event trsm_batch(const char *func_name, Func func, sycl::queue &que
for (int64_t i = 0; i < group_count; i++) {
auto **a_ = reinterpret_cast<const cuDataType **>(a);
auto **b_ = reinterpret_cast<cuDataType **>(b);
CUBLAS_ERROR_FUNC_T_SYNC(
func_name, func, err, handle, get_cublas_side_mode(left_right[i]),
get_cublas_fill_mode(upper_lower[i]), get_cublas_operation(trans[i]),
get_cublas_diag_type(unit_diag[i]), (int)m[i], (int)n[i],
(cuDataType *)&alpha[i], a_ + offset, (int)lda[i], b_ + offset, (int)ldb[i],
(int)group_size[i]);
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
CUBLAS_ERROR_FUNC_T(func_name, func, err, handle,
get_cublas_side_mode(left_right[i]),
get_cublas_fill_mode(upper_lower[i]),
get_cublas_operation(trans[i]),
get_cublas_diag_type(unit_diag[i]), (int)m[i], (int)n[i],
(cuDataType *)&alpha[i], a_ + offset, (int)lda[i],
b_ + offset, (int)ldb[i], (int)group_size[i]);
#else
CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle,
get_cublas_side_mode(left_right[i]),
get_cublas_fill_mode(upper_lower[i]),
get_cublas_operation(trans[i]),
get_cublas_diag_type(unit_diag[i]), (int)m[i],
(int)n[i], (cuDataType *)&alpha[i], a_ + offset,
(int)lda[i], b_ + offset, (int)ldb[i],
(int)group_size[i]);
#endif
offset += group_size[i];
}
});
Expand Down
27 changes: 27 additions & 0 deletions src/blas/backends/cublas/cublas_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,12 @@ class cuda_error : virtual public std::runtime_error {
CUBLAS_ERROR_FUNC(cublasGetStream, err, handle, &currentStreamId); \
cuStreamSynchronize(currentStreamId);

#define CUBLAS_ERROR_FUNC_T(name, func, err, handle, ...) \
err = func(handle, __VA_ARGS__); \
if (err != CUBLAS_STATUS_SUCCESS) { \
throw cublas_error(std::string(name) + std::string(" : "), err); \
}

#define CUBLAS_ERROR_FUNC_T_SYNC(name, func, err, handle, ...) \
err = func(handle, __VA_ARGS__); \
if (err != CUBLAS_STATUS_SUCCESS) { \
Expand All @@ -199,6 +205,27 @@ class cuda_error : virtual public std::runtime_error {
CUBLAS_ERROR_FUNC(cublasGetStream, err, handle, &currentStreamId); \
cuStreamSynchronize(currentStreamId);

template <class Func, class... Types>
inline void cublas_native_func(Func func, cublasStatus_t err,
cublasHandle_t handle, Types... args) {
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
CUBLAS_ERROR_FUNC(func, err, handle, args...)
#else
CUBLAS_ERROR_FUNC_SYNC(func, err, handle, args...)
#endif
};

template <class Func, class... Types>
inline void cublas_native_named_func(const char *func_name, Func func,
cublasStatus_t err, cublasHandle_t handle,
Types... args) {
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
CUBLAS_ERROR_FUNC_T(func_name, func, err, handle, args...)
#else
CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, args...)
#endif
};

inline cublasOperation_t get_cublas_operation(oneapi::mkl::transpose trn) {
switch (trn) {
case oneapi::mkl::transpose::nontrans: return CUBLAS_OP_N;
Expand Down
Loading

0 comments on commit 44867dc

Please sign in to comment.