diff --git a/.github/workflows/CI-localjll.yml b/.github/workflows/CI-localjll.yml
index 74e64565f7..68d6d47180 100644
--- a/.github/workflows/CI-localjll.yml
+++ b/.github/workflows/CI-localjll.yml
@@ -31,7 +31,10 @@ jobs:
         os:
           - linux-x86-n2-32
           - macOS-latest
+          - linux-x86-a2-48-a100-4gpu
         exclude:
+          - os: linux-x86-a2-48-a100-4gpu
+            version: "1.10"
           - os: macOS-latest
             version: "1.10"
     uses: ./.github/workflows/CommonCI.yml
diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp
index 2b18ac06f5..4e5e7d5a5b 100644
--- a/deps/ReactantExtra/API.cpp
+++ b/deps/ReactantExtra/API.cpp
@@ -1005,9 +1005,9 @@ REACTANT_ABI PjRtBuffer *ArrayFromHostBuffer(PjRtClient *client, void *data,
   return bres;
 }
-
 REACTANT_ABI void CopyToBuffer(PjRtClient *client, PjRtBuffer *buffer,
-                               void *data, size_t offset, size_t size, PjRtBuffer **bufferP) {
+                               void *data, size_t offset, size_t size,
+                               PjRtBuffer **bufferP) {
   if (buffer->IsOnCpu()) {
     auto unsafe =
         (char *)MyValueOrThrow(buffer->client()->UnsafeBufferPointer(buffer));
@@ -1016,12 +1016,13 @@ REACTANT_ABI void CopyToBuffer(PjRtClient *client, PjRtBuffer *buffer,
     //             data, size);
     return;
   }
-
+
   auto pid = client->platform_id();
   if (pid == xla::TpuId()) {
     auto dims = buffer->on_device_shape().dimensions();
     // TODO: note this assume that we want to copy the entire buffer size.
-    auto buf2 = ArrayFromHostBuffer(client, data, buffer->element_type(), dims.size(), dims.data(), buffer->device());
+    auto buf2 = ArrayFromHostBuffer(client, data, buffer->element_type(),
+                                    dims.size(), dims.data(), buffer->device());
     *bufferP = buf2;
     PjRtBufferFree((PjRtBuffer *)buffer);
     return;
@@ -1075,9 +1076,9 @@ REACTANT_ABI void BufferToHost(PjRtBuffer *buffer, void *data) {
   }
 }
-
 REACTANT_ABI void CopyFromBuffer(PjRtClient *client, PjRtBuffer *buffer,
-                                 void *data, size_t offset, size_t size, PjRtBuffer **bufferP) {
+                                 void *data, size_t offset, size_t size,
+                                 PjRtBuffer **bufferP) {
 
   auto pid = client->platform_id();
   if (pid == xla::TpuId()) {
@@ -3069,7 +3070,7 @@ struct LinkableRuntime {
       executables;
 
   // Set of allocated pointers to size
-  std::set<std::pair<void*, size_t>> allocations;
+  std::set<std::pair<void *, size_t>> allocations;
 
   LinkableRuntime(const std::string &backend) : registry() {
     InitializeRegistry(wrap(&registry));
@@ -3217,7 +3218,8 @@ REACTANT_ABI void reactantXLAExec(LinkableRuntime **__restrict__ lrtP,
   for (int64_t i = 0; i < argcnt; i++) {
     auto &&[argB, argO, argP] = bufferAndOffset(lrt, args[i]);
     if (argO != 0) {
-      llvm::errs() << "only zero-offset execution supported, argument " << i << " had byte offset of " << argO << "\n";
+      llvm::errs() << "only zero-offset execution supported, argument " << i
+                   << " had byte offset of " << argO << "\n";
       exit(1);
     }
     baseArrays[i] = argB;
@@ -3443,8 +3445,7 @@ class GPUPerformanceModel {
         fusion_analysis_cache_(device_description_),
         gpu_hlo_cost_analysis_(hlo_cost_analysis_options_, device_description_),
         gpu_performance_model_(device_description_, fusion_analysis_cache_,
-                               gpu_performance_model_cache_,
-                               mlir_context_) {}
+                               gpu_performance_model_cache_, mlir_context_) {}
 
   void RunAnalysisOnHloModule(std::shared_ptr<xla::HloModule> hlo_module) {
     hlo_module->entry_computation()->Accept(&gpu_hlo_cost_analysis_);
diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD
index 7be2938210..9c26a1c06e 100644
--- a/deps/ReactantExtra/BUILD
+++ b/deps/ReactantExtra/BUILD
@@ -1059,6 +1059,7 @@ cc_library(
             "-Wl,-exported_symbol,_CreateGPUPerformanceModel",
             "-Wl,-exported_symbol,_RunAnalysisOnHloModule",
             "-Wl,-exported_symbol,_EstimateRunTimeForInstruction",
+            "-Wl,-exported_symbol,_registerReactantXLAFFI",
"-Wl,-exported_symbol,_registerReactantXLAFFI", ], }), linkstatic = True, diff --git a/deps/ReactantExtra/xla_ffi.cpp b/deps/ReactantExtra/xla_ffi.cpp new file mode 100644 index 0000000000..755af0e129 --- /dev/null +++ b/deps/ReactantExtra/xla_ffi.cpp @@ -0,0 +1,261 @@ +#include "absl/strings/str_format.h" + +#include "xla/ffi/api/c_api.h" +#include "xla/ffi/api/ffi.h" +#include "xla/ffi/ffi_api.h" + +#include "mlir/CAPI/IR.h" + +#if defined(REACTANT_CUDA) +#include "jaxlib/ffi_helpers.h" +#include "jaxlib/gpu/blas_handle_pool.h" +#endif + +#define REACTANT_ABI extern "C" MLIR_CAPI_EXPORTED + +using namespace xla; + +namespace reactant { +namespace cuda { + +#if defined(REACTANT_CUDA) + +#include "third_party/gpus/cuda/include/cuComplex.h" +#include "third_party/gpus/cuda/include/cublas_v2.h" +#include "third_party/gpus/cuda/include/cuda.h" +#include "third_party/gpus/cuda/include/cuda_fp8.h" +#include "third_party/gpus/cuda/include/cufft.h" +#include "third_party/gpus/cuda/include/cusolverDn.h" +#include "third_party/gpus/cuda/include/cusolver_common.h" + +using namespace jax; + +#define SOLVER_BLAS_DISPATCH_IMPL(impl, ...) \ + switch (dataType) { \ + case ffi::F32: \ + return impl(__VA_ARGS__); \ + case ffi::F64: \ + return impl(__VA_ARGS__); \ + case ffi::C64: \ + return impl(__VA_ARGS__); \ + case ffi::C128: \ + return impl(__VA_ARGS__); \ + default: \ + break; \ + } + +template +ffi::Error GetHostScalar(CUstream stream, bool use_attribute, double value_real, + double value_imag, ffi::AnyBuffer buffer, + T *host_value) { + if (use_attribute) { + if constexpr (std::is_same::value) { + *host_value = static_cast(value_real); + } else if constexpr (std::is_same::value) { + *host_value = value_real; + } else if constexpr (std::is_same::value) { + *host_value = cuComplex{static_cast(value_real), + static_cast(value_imag)}; + } else if constexpr (std::is_same::value) { + *host_value = cuDoubleComplex{value_real, value_imag}; + } + } else { + // Ensure buffer has exactly 1 element + if (buffer.element_count() != 1) { + return ffi::Error::InvalidArgument( + absl::StrFormat("Expected scalar buffer with 1 element, got %d", + buffer.element_count())); + } + // memcpy to host + cudaMemcpyAsync(host_value, buffer.untyped_data(), sizeof(T), + cudaMemcpyDeviceToHost, stream); + } + return ffi::Error::Success(); +} + +inline ffi::Error CublasStatusToError(cublasStatus_t status, + const char *op_name) { + if (status == CUBLAS_STATUS_SUCCESS) { + return ffi::Error::Success(); + } + const char *error_name; + switch (status) { + case CUBLAS_STATUS_NOT_INITIALIZED: + error_name = "CUBLAS_STATUS_NOT_INITIALIZED"; + break; + case CUBLAS_STATUS_ALLOC_FAILED: + error_name = "CUBLAS_STATUS_ALLOC_FAILED"; + break; + case CUBLAS_STATUS_INVALID_VALUE: + error_name = "CUBLAS_STATUS_INVALID_VALUE"; + break; + case CUBLAS_STATUS_ARCH_MISMATCH: + error_name = "CUBLAS_STATUS_ARCH_MISMATCH"; + break; + case CUBLAS_STATUS_MAPPING_ERROR: + error_name = "CUBLAS_STATUS_MAPPING_ERROR"; + break; + case CUBLAS_STATUS_EXECUTION_FAILED: + error_name = "CUBLAS_STATUS_EXECUTION_FAILED"; + break; + case CUBLAS_STATUS_INTERNAL_ERROR: + error_name = "CUBLAS_STATUS_INTERNAL_ERROR"; + break; + case CUBLAS_STATUS_NOT_SUPPORTED: + error_name = "CUBLAS_STATUS_NOT_SUPPORTED"; + break; + default: + error_name = "UNKNOWN"; + break; + } + return ffi::Error::InvalidArgument( + absl::StrFormat("%s failed with status %s", op_name, error_name)); +} + +namespace blas { + +template +ffi::Error Syrk(cublasHandle_t handle, cublasFillMode_t uplo, + 
+                cublasOperation_t trans, int n, int k, const T *alpha,
+                const T *a, int lda, const T *beta, T *c, int ldc) {
+  return ffi::Error::InvalidArgument("Unsupported type for syrk");
+}
+
+#define SYRK_SPECIALIZATION(T, cublas_func)                                   \
+  template <>                                                                 \
+  ffi::Error Syrk(cublasHandle_t handle, cublasFillMode_t uplo,               \
+                  cublasOperation_t trans, int n, int k, const T *alpha,      \
+                  const T *a, int lda, const T *beta, T *c, int ldc) {        \
+    cublasStatus_t status =                                                   \
+        cublas_func(handle, uplo, trans, n, k, alpha, a, lda, beta, c, ldc);  \
+    return CublasStatusToError(status, #cublas_func);                         \
+  }
+
+SYRK_SPECIALIZATION(float, cublasSsyrk)
+SYRK_SPECIALIZATION(double, cublasDsyrk)
+SYRK_SPECIALIZATION(cuComplex, cublasCsyrk)
+SYRK_SPECIALIZATION(cuDoubleComplex, cublasZsyrk)
+
+#undef SYRK_SPECIALIZATION
+
+} // namespace blas
+
+// Symmetric rank-k update: syrk
+
+template <typename T>
+ffi::Error SyrkImpl(CUstream stream, bool transpose, bool uplo_,
+                    ffi::AnyBuffer a, ffi::AnyBuffer c_in, const T *alpha,
+                    const T *beta, ffi::Result<ffi::AnyBuffer> c_out) {
+  FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]),
+                       SplitBatch2D(a.dimensions()));
+  auto size = transpose ? cols : rows;
+  FFI_RETURN_IF_ERROR(
+      CheckShape(c_in.dimensions(), {batch, size, size}, "c_in", "syrk"));
+  FFI_RETURN_IF_ERROR(
+      CheckShape(c_out->dimensions(), {batch, size, size}, "c_out", "syrk"));
+
+  FFI_ASSIGN_OR_RETURN(auto n,
+                       MaybeCastNoOverflow<int>(transpose ? cols : rows));
+  FFI_ASSIGN_OR_RETURN(auto k,
+                       MaybeCastNoOverflow<int>(transpose ? rows : cols));
+  cublasFillMode_t uplo =
+      uplo_ ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER;
+  cublasOperation_t trans = transpose ? CUBLAS_OP_T : CUBLAS_OP_N;
+
+  const T *a_data = static_cast<const T *>(a.untyped_data());
+  T *c_data = static_cast<T *>(c_in.untyped_data());
+  T *c_out_data = static_cast<T *>(c_out->untyped_data());
+
+  if (c_data != c_out_data) {
+    cudaError_t err = cudaMemcpyAsync(c_out_data, c_data, c_in.size_bytes(),
+                                      cudaMemcpyDeviceToDevice, stream);
+    if (err != cudaSuccess) {
+      return ffi::Error::InvalidArgument(absl::StrFormat(
+          "cudaMemcpyAsync failed: %s", cudaGetErrorString(err)));
+    }
+  }
+  FFI_ASSIGN_OR_RETURN(auto handle, BlasHandlePool::Borrow(stream));
+  // lda is the leading dimension of a, ldc is the leading dimension of c
+  // For column-major (which cuBLAS uses), lda = number of rows of a, ldc = n
+  int lda = transpose ? k : n;
+  int ldc = n;
+  for (int i = 0; i < batch; ++i) {
+    FFI_RETURN_IF_ERROR(blas::Syrk(handle.get(), uplo, trans, n, k, alpha,
+                                   a_data, lda, beta, c_out_data, ldc));
+    a_data += k * n;
+    c_out_data += n * n;
+  }
+  return ffi::Error::Success();
+}
+
+template <typename T>
+ffi::Error SyrkImpl(CUstream stream, bool transpose, bool uplo,
+                    bool use_alpha_attribute, double alpha_real,
+                    double alpha_imag, bool use_beta_attribute,
+                    double beta_real, double beta_imag, ffi::AnyBuffer a,
+                    ffi::AnyBuffer c_in, ffi::AnyBuffer alpha_,
+                    ffi::AnyBuffer beta_, ffi::Result<ffi::AnyBuffer> c_out) {
+  T host_alpha, host_beta;
+  FFI_RETURN_IF_ERROR(GetHostScalar(stream, use_alpha_attribute, alpha_real,
+                                    alpha_imag, alpha_, &host_alpha));
+  FFI_RETURN_IF_ERROR(GetHostScalar(stream, use_beta_attribute, beta_real,
+                                    beta_imag, beta_, &host_beta));
+  return SyrkImpl(stream, transpose, uplo, a, c_in, &host_alpha, &host_beta,
+                  c_out);
+}
+
+ffi::Error SyrkDispatch(CUstream stream, bool transpose, bool uplo,
+                        bool use_alpha_attribute, double alpha_real,
+                        double alpha_imag, bool use_beta_attribute,
+                        double beta_real, double beta_imag, ffi::AnyBuffer a,
+                        ffi::AnyBuffer c_in, ffi::AnyBuffer alpha_,
+                        ffi::AnyBuffer beta_,
+                        ffi::Result<ffi::AnyBuffer> c_out) {
+  auto dataType = c_in.element_type();
+  SOLVER_BLAS_DISPATCH_IMPL(SyrkImpl, stream, transpose, uplo,
+                            use_alpha_attribute, alpha_real, alpha_imag,
+                            use_beta_attribute, beta_real, beta_imag, a, c_in,
+                            alpha_, beta_, c_out);
+  return ffi::Error::InvalidArgument(absl::StrFormat(
+      "Unsupported dtype %s in syrk", absl::FormatStreamed(dataType)));
+}
+
+XLA_FFI_DEFINE_HANDLER(
+    SyrkFfi, SyrkDispatch,
+    xla::ffi::Ffi::Bind()
+        .Ctx<ffi::PlatformStream<CUstream>>()
+        .Attr<bool>("transpose")           // transpose
+        .Attr<bool>("uplo")                // uplo
+        .Attr<bool>("use_alpha_attribute") // use_alpha_attribute
+        .Attr<double>("alpha_real")        // alpha_real
+        .Attr<double>("alpha_imag")        // alpha_imag
+        .Attr<bool>("use_beta_attribute")  // use_beta_attribute
+        .Attr<double>("beta_real")         // beta_real
+        .Attr<double>("beta_imag")         // beta_imag
+        .Arg<ffi::AnyBuffer>()             // a
+        .Arg<ffi::AnyBuffer>()             // c_in
+        .Arg<ffi::AnyBuffer>()             // alpha
+        .Arg<ffi::AnyBuffer>()             // beta
+        .Ret<ffi::AnyBuffer>()             // c_out
+);
+
+void registerReactantXLACUDAFFI() {
+  XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "reactant_cublas_syrk_ffi",
+                           "CUDA", SyrkFfi);
+}
+
+#undef SOLVER_BLAS_DISPATCH_IMPL
+
+#else
+
+void registerReactantXLACUDAFFI() {}
+
+#endif
+
+} // namespace cuda
+} // namespace reactant
+
+REACTANT_ABI void registerReactantXLAFFI() {
+  reactant::cuda::registerReactantXLACUDAFFI();
+  return;
+}
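Reviewer note: the new handler only becomes visible to XLA once `registerReactantXLAFFI` (the symbol exported by the BUILD change above) has run in the process, populating the process-wide FFI registry for the "CUDA" platform. Below is a minimal host-side sketch of that flow. It assumes the ReactantExtra shared library is already linked/loaded; the `main` harness is hypothetical and not part of this PR, and `xla::ffi::FindHandler` from `xla/ffi/ffi_api.h` is used here only to illustrate the lookup that XLA performs when compiling a custom call.

```cpp
#include <cstdio>

#include "xla/ffi/ffi_api.h"

// Exported from ReactantExtra via "-Wl,-exported_symbol,_registerReactantXLAFFI".
extern "C" void registerReactantXLAFFI();

int main() {
  // Populates the process-wide XLA FFI registry. On non-CUDA builds this is
  // a no-op, since registerReactantXLACUDAFFI() is stubbed out in xla_ffi.cpp.
  registerReactantXLAFFI();

  // After registration, a stablehlo.custom_call with
  // call_target_name = "reactant_cublas_syrk_ffi" (typed-FFI api_version = 4),
  // the eight bound attributes (transpose, uplo, use_alpha_attribute,
  // alpha_real, alpha_imag, use_beta_attribute, beta_real, beta_imag), four
  // buffer operands (a, c_in, alpha, beta), and one result (c_out) resolves
  // on the CUDA platform. FindHandler performs the same registry lookup.
  auto reg = xla::ffi::FindHandler("reactant_cublas_syrk_ffi", "CUDA");
  std::printf("handler registered: %s\n", reg.ok() ? "yes" : "no");
  return reg.ok() ? 0 : 1;
}
```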