Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 14 additions & 14 deletions src/blas/backends/cublas/cublas_batch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,10 +162,10 @@ inline void gemm_batch(const char *func_name, Func func, sycl::queue &queue, tra
auto b_ = sc.get_mem<cuDataType *>(b_acc);
auto c_ = sc.get_mem<cuDataType *>(c_acc);
cublasStatus_t err;
CUBLAS_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_operation(transa),
get_cublas_operation(transb), m, n, k, (cuDataType *)&alpha, a_,
lda, stride_a, b_, ldb, stride_b, (cuDataType *)&beta, c_, ldc,
stride_c, batch_size);
CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_operation(transa),
get_cublas_operation(transb), m, n, k, (cuDataType *)&alpha,
a_, lda, stride_a, b_, ldb, stride_b, (cuDataType *)&beta, c_,
ldc, stride_c, batch_size);
});
});
}
Expand Down Expand Up @@ -495,10 +495,10 @@ inline sycl::event gemm_batch(const char *func_name, Func func, sycl::queue &que
auto b_ = reinterpret_cast<const cuDataType *>(b);
auto c_ = reinterpret_cast<cuDataType *>(c);
cublasStatus_t err;
CUBLAS_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_operation(transa),
get_cublas_operation(transb), m, n, k, (cuDataType *)&alpha, a_,
lda, stride_a, b_, ldb, stride_b, (cuDataType *)&beta, c_, ldc,
stride_c, batch_size);
CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_operation(transa),
get_cublas_operation(transb), m, n, k, (cuDataType *)&alpha,
a_, lda, stride_a, b_, ldb, stride_b, (cuDataType *)&beta, c_,
ldc, stride_c, batch_size);
});
});
return done;
Expand Down Expand Up @@ -550,11 +550,11 @@ inline sycl::event gemm_batch(const char *func_name, Func func, sycl::queue &que
auto **a_ = reinterpret_cast<const cuDataType **>(a);
auto **b_ = reinterpret_cast<const cuDataType **>(b);
auto **c_ = reinterpret_cast<cuDataType **>(c);
CUBLAS_ERROR_FUNC_T(func_name, func, err, handle, get_cublas_operation(transa[i]),
get_cublas_operation(transb[i]), (int)m[i], (int)n[i],
(int)k[i], (cuDataType *)&alpha[i], a_ + offset, (int)lda[i],
b_ + offset, (int)ldb[i], (cuDataType *)&beta[i], c_ + offset,
(int)ldc[i], (int)group_size[i]);
CUBLAS_ERROR_FUNC_T_SYNC(
func_name, func, err, handle, get_cublas_operation(transa[i]),
get_cublas_operation(transb[i]), (int)m[i], (int)n[i], (int)k[i],
(cuDataType *)&alpha[i], a_ + offset, (int)lda[i], b_ + offset, (int)ldb[i],
(cuDataType *)&beta[i], c_ + offset, (int)ldc[i], (int)group_size[i]);
offset += group_size[i];
}
});
Expand Down Expand Up @@ -632,7 +632,7 @@ inline sycl::event trsm_batch(const char *func_name, Func func, sycl::queue &que
for (int64_t i = 0; i < group_count; i++) {
auto **a_ = reinterpret_cast<const cuDataType **>(a);
auto **b_ = reinterpret_cast<cuDataType **>(b);
CUBLAS_ERROR_FUNC_T(
CUBLAS_ERROR_FUNC_T_SYNC(
func_name, func, err, handle, get_cublas_side_mode(left_right[i]),
get_cublas_fill_mode(upper_lower[i]), get_cublas_operation(trans[i]),
get_cublas_diag_type(unit_diag[i]), (int)m[i], (int)n[i],
Expand Down
18 changes: 15 additions & 3 deletions src/blas/backends/cublas/cublas_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,11 +180,23 @@ class cuda_error : virtual public std::runtime_error {
throw cublas_error(std::string(#name) + std::string(" : "), err); \
}

#define CUBLAS_ERROR_FUNC_T(name, func, err, ...) \
err = func(__VA_ARGS__); \
#define CUBLAS_ERROR_FUNC_SYNC(name, err, handle, ...) \
err = name(handle, __VA_ARGS__); \
if (err != CUBLAS_STATUS_SUCCESS) { \
throw cublas_error(std::string(#name) + std::string(" : "), err); \
} \
cudaStream_t currentStreamId; \
CUBLAS_ERROR_FUNC(cublasGetStream, err, handle, &currentStreamId); \
cuStreamSynchronize(currentStreamId);

#define CUBLAS_ERROR_FUNC_T_SYNC(name, func, err, handle, ...) \
err = func(handle, __VA_ARGS__); \
if (err != CUBLAS_STATUS_SUCCESS) { \
throw cublas_error(std::string(name) + std::string(" : "), err); \
}
} \
cudaStream_t currentStreamId; \
CUBLAS_ERROR_FUNC(cublasGetStream, err, handle, &currentStreamId); \
cuStreamSynchronize(currentStreamId);

inline cublasOperation_t get_cublas_operation(oneapi::mkl::transpose trn) {
switch (trn) {
Expand Down
Loading