Skip to content

Commit 2edeaeb

Browse files
committed
fix: cpu build
1 parent 0e41c3f commit 2edeaeb

File tree

2 files changed

+27
-18
lines changed

2 files changed

+27
-18
lines changed

deps/ReactantExtra/WORKSPACE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ NSYNC_COMMIT = "82b118aa7ace3132e517e2c467f8732978cf4023"
44

55
NSYNC_SHA256 = ""
66

7-
ENZYMEXLA_COMMIT = "ead4414a40c594814a129adb54934720a0140c86"
7+
ENZYMEXLA_COMMIT = "245a66b57e9a3b7c23a0225d8eabfc2825761029"
88

99
ENZYMEXLA_SHA256 = ""
1010

deps/ReactantExtra/xla_ffi.cpp

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
#include "absl/strings/str_format.h"
22

3-
#include "jaxlib/ffi_helpers.h"
4-
#include "jaxlib/gpu/blas_handle_pool.h"
53
#include "xla/ffi/api/c_api.h"
64
#include "xla/ffi/api/ffi.h"
75
#include "xla/ffi/ffi_api.h"
@@ -10,14 +8,16 @@
108

119
#define REACTANT_ABI extern "C" MLIR_CAPI_EXPORTED
1210

13-
using namespace jax;
1411
using namespace xla;
1512

1613
namespace reactant {
1714
namespace cuda {
1815

1916
#if defined(REACTANT_CUDA)
2017

18+
#include "jaxlib/ffi_helpers.h"
19+
#include "jaxlib/gpu/blas_handle_pool.h"
20+
2121
#include "third_party/gpus/cuda/include/cuComplex.h"
2222
#include "third_party/gpus/cuda/include/cublas_v2.h"
2323
#include "third_party/gpus/cuda/include/cuda.h"
@@ -26,6 +26,8 @@ namespace cuda {
2626
#include "third_party/gpus/cuda/include/cusolverDn.h"
2727
#include "third_party/gpus/cuda/include/cusolver_common.h"
2828

29+
using namespace jax;
30+
2931
#define SOLVER_BLAS_DISPATCH_IMPL(impl, ...) \
3032
switch (dataType) { \
3133
case ffi::F32: \
@@ -41,26 +43,32 @@ namespace cuda {
4143
}
4244

4345
template <typename T>
44-
T GetHostScalar(CUstream stream, bool use_attribute, double value_real,
45-
double value_imag, ffi::AnyBuffer buffer) {
46-
T host_value;
46+
ffi::Error GetHostScalar(CUstream stream, bool use_attribute, double value_real,
47+
double value_imag, ffi::AnyBuffer buffer,
48+
T *host_value) {
4749
if (use_attribute) {
4850
if constexpr (std::is_same<T, float>::value) {
49-
host_value = static_cast<float>(value_real);
51+
*host_value = static_cast<float>(value_real);
5052
} else if constexpr (std::is_same<T, double>::value) {
51-
host_value = value_real;
53+
*host_value = value_real;
5254
} else if constexpr (std::is_same<T, cuComplex>::value) {
53-
host_value = cuComplex{static_cast<float>(value_real),
54-
static_cast<float>(value_imag)};
55+
*host_value = cuComplex{static_cast<float>(value_real),
56+
static_cast<float>(value_imag)};
5557
} else if constexpr (std::is_same<T, cuDoubleComplex>::value) {
56-
host_value = cuDoubleComplex{value_real, value_imag};
58+
*host_value = cuDoubleComplex{value_real, value_imag};
5759
}
5860
} else {
61+
// Ensure buffer has exactly 1 element
62+
if (buffer.element_count() != 1) {
63+
return ffi::Error::InvalidArgument(
64+
absl::StrFormat("Expected scalar buffer with 1 element, got %d",
65+
buffer.element_count()));
66+
}
5967
// memcpy to host
60-
cudaMemcpyAsync(&host_value, buffer.untyped_data(), sizeof(T),
68+
cudaMemcpyAsync(host_value, buffer.untyped_data(), sizeof(T),
6169
cudaMemcpyDeviceToHost, stream);
6270
}
63-
return host_value;
71+
return ffi::Error::Success();
6472
}
6573

6674
inline ffi::Error CublasStatusToError(cublasStatus_t status,
@@ -185,10 +193,11 @@ ffi::Error SyrkImpl(CUstream stream, bool transpose, bool uplo,
185193
double beta_real, double beta_imag, ffi::AnyBuffer a,
186194
ffi::AnyBuffer c_in, ffi::AnyBuffer alpha_,
187195
ffi::AnyBuffer beta_, ffi::Result<ffi::AnyBuffer> c_out) {
188-
T host_alpha = GetHostScalar<T>(stream, use_alpha_attribute, alpha_real,
189-
alpha_imag, alpha_);
190-
T host_beta =
191-
GetHostScalar<T>(stream, use_beta_attribute, beta_real, beta_imag, beta_);
196+
T host_alpha, host_beta;
197+
FFI_RETURN_IF_ERROR(GetHostScalar<T>(stream, use_alpha_attribute, alpha_real,
198+
alpha_imag, alpha_, &host_alpha));
199+
FFI_RETURN_IF_ERROR(GetHostScalar<T>(stream, use_beta_attribute, beta_real,
200+
beta_imag, beta_, &host_beta));
192201
return SyrkImpl<T>(stream, transpose, uplo, a, c_in, &host_alpha, &host_beta,
193202
c_out);
194203
}

0 commit comments

Comments
 (0)