11#include " absl/strings/str_format.h"
22
3- #include " jaxlib/ffi_helpers.h"
4- #include " jaxlib/gpu/blas_handle_pool.h"
53#include " xla/ffi/api/c_api.h"
64#include " xla/ffi/api/ffi.h"
75#include " xla/ffi/ffi_api.h"
108
119#define REACTANT_ABI extern " C" MLIR_CAPI_EXPORTED
1210
13- using namespace jax ;
1411using namespace xla ;
1512
1613namespace reactant {
1714namespace cuda {
1815
1916#if defined(REACTANT_CUDA)
2017
18+ #include " jaxlib/ffi_helpers.h"
19+ #include " jaxlib/gpu/blas_handle_pool.h"
20+
2121#include " third_party/gpus/cuda/include/cuComplex.h"
2222#include " third_party/gpus/cuda/include/cublas_v2.h"
2323#include " third_party/gpus/cuda/include/cuda.h"
@@ -26,6 +26,8 @@ namespace cuda {
2626#include " third_party/gpus/cuda/include/cusolverDn.h"
2727#include " third_party/gpus/cuda/include/cusolver_common.h"
2828
29+ using namespace jax ;
30+
2931#define SOLVER_BLAS_DISPATCH_IMPL (impl, ...) \
3032 switch (dataType) { \
3133 case ffi::F32: \
@@ -41,26 +43,32 @@ namespace cuda {
4143 }
4244
4345template <typename T>
44- T GetHostScalar (CUstream stream, bool use_attribute, double value_real,
45- double value_imag, ffi::AnyBuffer buffer) {
46- T host_value;
46+ ffi::Error GetHostScalar (CUstream stream, bool use_attribute, double value_real,
47+ double value_imag, ffi::AnyBuffer buffer,
48+ T * host_value) {
4749 if (use_attribute) {
4850 if constexpr (std::is_same<T, float >::value) {
49- host_value = static_cast <float >(value_real);
51+ * host_value = static_cast <float >(value_real);
5052 } else if constexpr (std::is_same<T, double >::value) {
51- host_value = value_real;
53+ * host_value = value_real;
5254 } else if constexpr (std::is_same<T, cuComplex>::value) {
53- host_value = cuComplex{static_cast <float >(value_real),
54- static_cast <float >(value_imag)};
55+ * host_value = cuComplex{static_cast <float >(value_real),
56+ static_cast <float >(value_imag)};
5557 } else if constexpr (std::is_same<T, cuDoubleComplex>::value) {
56- host_value = cuDoubleComplex{value_real, value_imag};
58+ * host_value = cuDoubleComplex{value_real, value_imag};
5759 }
5860 } else {
61+ // Ensure buffer has exactly 1 element
62+ if (buffer.element_count () != 1 ) {
63+ return ffi::Error::InvalidArgument (
64+ absl::StrFormat (" Expected scalar buffer with 1 element, got %d" ,
65+ buffer.element_count ()));
66+ }
5967 // memcpy to host
60- cudaMemcpyAsync (& host_value, buffer.untyped_data (), sizeof (T),
68+ cudaMemcpyAsync (host_value, buffer.untyped_data (), sizeof (T),
6169 cudaMemcpyDeviceToHost, stream);
6270 }
63- return host_value ;
71+ return ffi::Error::Success () ;
6472}
6573
6674inline ffi::Error CublasStatusToError (cublasStatus_t status,
@@ -185,10 +193,11 @@ ffi::Error SyrkImpl(CUstream stream, bool transpose, bool uplo,
185193 double beta_real, double beta_imag, ffi::AnyBuffer a,
186194 ffi::AnyBuffer c_in, ffi::AnyBuffer alpha_,
187195 ffi::AnyBuffer beta_, ffi::Result<ffi::AnyBuffer> c_out) {
188- T host_alpha = GetHostScalar<T>(stream, use_alpha_attribute, alpha_real,
189- alpha_imag, alpha_);
190- T host_beta =
191- GetHostScalar<T>(stream, use_beta_attribute, beta_real, beta_imag, beta_);
196+ T host_alpha, host_beta;
197+ FFI_RETURN_IF_ERROR (GetHostScalar<T>(stream, use_alpha_attribute, alpha_real,
198+ alpha_imag, alpha_, &host_alpha));
199+ FFI_RETURN_IF_ERROR (GetHostScalar<T>(stream, use_beta_attribute, beta_real,
200+ beta_imag, beta_, &host_beta));
192201 return SyrkImpl<T>(stream, transpose, uplo, a, c_in, &host_alpha, &host_beta,
193202 c_out);
194203}
0 commit comments