diff --git a/jax/_src/config.py b/jax/_src/config.py
index 215ef443c799..c90138df1fe9 100644
--- a/jax/_src/config.py
+++ b/jax/_src/config.py
@@ -1953,3 +1953,20 @@ def _update_garbage_collection_guard(state, key, val):
),
include_in_jit_key=True,
)
+
+gpu_use_magma = enum_state(
+ name='jax_gpu_use_magma',
+ enum_values=[
+ 'off', 'on',
+ # TODO(danfm): After doing more complete benchmarks, add an 'auto'
+ # option which will automatically enable MAGMA for problem sizes where
+ # it typically gets better performance.
+ # 'auto',
+ ],
+ default='off',
+ help=(
+ 'Enable experimental support for MAGMA-backed lax.linalg.eig on GPU. '
+ 'See the documentation for lax.linalg.eig for more details about how '
+ 'to use this feature.'
+ ),
+)
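
For context, a minimal usage sketch of the new option (assuming a CUDA or ROCm backend and a MAGMA installation that ``dlopen`` can find):

.. code-block:: python

    import jax
    import jax.numpy as jnp

    jax.config.update('jax_gpu_use_magma', 'on')  # default stays 'off'

    a = jnp.array([[0.0, -1.0], [1.0, 0.0]], dtype=jnp.float32)
    w, vl, vr = jax.lax.linalg.eig(a)  # dispatches to the hybrid GPU kernel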
diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py
index 0e0390abc78f..0240416ca29c 100644
--- a/jax/_src/lax/linalg.py
+++ b/jax/_src/lax/linalg.py
@@ -124,7 +124,25 @@ def eig(x: ArrayLike, *, compute_left_eigenvectors: bool = True,
compute_right_eigenvectors: bool = True) -> list[Array]:
"""Eigendecomposition of a general matrix.
- Nonsymmetric eigendecomposition is at present only implemented on CPU.
+  Nonsymmetric eigendecomposition is only implemented on CPU and GPU. On GPU,
+  the default implementation calls LAPACK directly on the host CPU, but an
+  experimental GPU implementation using `MAGMA <https://icl.utk.edu/magma/>`_
+  is also available. The MAGMA implementation is typically slower than the
+  equivalent LAPACK implementation for small matrices (dimension less than
+  about 2048), but it may perform better for larger matrices.
+
+ To enable the MAGMA implementation, you must install MAGMA yourself (there
+ are Debian and conda-forge packages, or you can build from source). Then set
+  the ``jax_gpu_use_magma`` configuration variable to ``"on"``:
+
+ .. code-block:: python
+
+ jax.config.update('jax_gpu_use_magma', 'on')
+
+ JAX will try to ``dlopen`` the installed MAGMA shared library, raising an
+ error if it is not found. To explicitly specify the path to the MAGMA
+  library, set the environment variable ``JAX_GPU_MAGMA_PATH`` to the full
+ installation path.
Args:
x: A batch of square matrices with shape ``[..., n, n]``.
@@ -763,6 +781,78 @@ def _eig_cpu_lowering(ctx, operand, *, compute_left_eigenvectors,
return output
+def _eig_gpu_impl(target_name_prefix, x, *, compute_left_eigenvectors,
+ compute_right_eigenvectors):
+ gpu_solver.initialize_hybrid_kernels()
+ dtype = x.dtype
+ is_real = dtype == np.float32 or dtype == np.float64
+ if is_real:
+ target_name = f"{target_name_prefix}hybrid_eig_real"
+ complex_dtype = np.complex64 if dtype == np.float32 else np.complex128
+ else:
+ target_name = f"{target_name_prefix}hybrid_eig_comp"
+ assert dtype == np.complex64 or dtype == np.complex128
+ complex_dtype = dtype
+
+ batch_dims = x.shape[:-2]
+ n, m = x.shape[-2:]
+ assert n == m
+ num_batch_dims = len(batch_dims)
+
+ layout = tuple(range(num_batch_dims)) + (num_batch_dims + 1, num_batch_dims)
+ out_types = [
+ api.ShapeDtypeStruct(batch_dims + (n,), dtype),
+ api.ShapeDtypeStruct(batch_dims + (n, n), complex_dtype),
+ api.ShapeDtypeStruct(batch_dims + (n, n), complex_dtype),
+ api.ShapeDtypeStruct(batch_dims, np.int32),
+ ]
+ out_layouts = [None, layout, layout, None]
+ if is_real:
+ out_types = [api.ShapeDtypeStruct(batch_dims + (n,), dtype)] + out_types
+ out_layouts = [None] + out_layouts
+
+ magma = False
+ if config.gpu_use_magma.value == 'on':
+ magma = dtype != np.complex128
+
+ fun = ffi.ffi_call(target_name, out_types, input_layouts=[layout],
+ output_layouts=out_layouts)
+ *w, vl, vr, info = fun(x, magma=magma, left=compute_left_eigenvectors,
+ right=compute_right_eigenvectors)
+ if is_real:
+ assert len(w) == 2
+ w = lax.complex(*w)
+ else:
+ assert len(w) == 1
+ w = w[0]
+ ok = lax.eq(info, lax.zeros_like_array(info))
+ ok = _broadcast_to(ok[..., None], w.shape)
+ w = lax.select(ok, w, lax.full_like(w, np.nan + np.nan * 1j))
+ ok = _broadcast_to(ok[..., None], x.shape)
+ output = [w]
+ if compute_left_eigenvectors:
+ vl = lax.select(ok, vl, lax.full_like(vl, np.nan + np.nan * 1j))
+ output.append(vl)
+ if compute_right_eigenvectors:
+ vr = lax.select(ok, vr, lax.full_like(vr, np.nan + np.nan * 1j))
+ output.append(vr)
+ return output
+
+
+def _eig_gpu_lowering(target_name_prefix, ctx, operand, *,
+ compute_left_eigenvectors, compute_right_eigenvectors):
+ if ctx.is_forward_compat():
+ raise NotImplementedError(
+ "Export of nonsymmetric eigendecomposition on GPU is not supported "
+ "because of forward compatibility. The "
+ "'jax_export_ignore_forward_compatibility' configuration option can be "
+ "used to disable this check.")
+ rule = mlir.lower_fun(partial(_eig_gpu_impl, target_name_prefix),
+ multiple_results=True)
+ return rule(ctx, operand, compute_left_eigenvectors=compute_left_eigenvectors,
+ compute_right_eigenvectors=compute_right_eigenvectors)
+
+
def eig_batching_rule(batched_args, batch_dims, *, compute_left_eigenvectors,
compute_right_eigenvectors):
x, = batched_args
@@ -793,6 +883,10 @@ def eig_jvp_rule(primals, tangents, *, compute_left_eigenvectors,
eig_p.def_abstract_eval(eig_abstract_eval)
mlir.register_lowering(eig_p, eig_lower)
mlir.register_lowering(eig_p, _eig_cpu_lowering, platform='cpu')
+mlir.register_lowering(eig_p, partial(_eig_gpu_lowering, 'cu'),
+ platform='cuda')
+mlir.register_lowering(eig_p, partial(_eig_gpu_lowering, 'hip'),
+ platform='rocm')
batching.primitive_batchers[eig_p] = eig_batching_rule
ad.primitive_jvps[eig_p] = eig_jvp_rule
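
The lowering above mirrors the CPU behaviour for failed or non-finite batch elements: when ``info`` is nonzero, the corresponding eigenvalues and eigenvectors are filled with NaNs rather than raising. A small sketch of the resulting semantics (assuming a GPU backend):

.. code-block:: python

    import jax.numpy as jnp
    from jax import lax

    a = jnp.full((2, 3, 3), jnp.nan, dtype=jnp.float32)
    w, vr = lax.linalg.eig(a, compute_left_eigenvectors=False)
    assert jnp.all(jnp.isnan(w)) and jnp.all(jnp.isnan(vr))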
diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py
index 03f864919887..76a4abff48ad 100644
--- a/jax/_src/numpy/linalg.py
+++ b/jax/_src/numpy/linalg.py
@@ -731,7 +731,9 @@ def eig(a: ArrayLike) -> tuple[Array, Array]:
- This differs from :func:`numpy.linalg.eig` in that the return type of
:func:`jax.numpy.linalg.eig` is always complex64 for 32-bit input, and complex128
for 64-bit input.
- - At present, non-symmetric eigendecomposition is only implemented on the CPU backend.
+ - At present, non-symmetric eigendecomposition is only implemented on the CPU and
+ GPU backends. For more details about the GPU implementation, see the
+ documentation for :func:`jax.lax.linalg.eig`.
See also:
- :func:`jax.numpy.linalg.eigh`: eigenvectors and eigenvalues of a Hermitian matrix.
diff --git a/jaxlib/cpu/lapack_kernels.cc b/jaxlib/cpu/lapack_kernels.cc
index 19b82a5ce149..ed815e1b1bd2 100644
--- a/jaxlib/cpu/lapack_kernels.cc
+++ b/jaxlib/cpu/lapack_kernels.cc
@@ -1094,34 +1094,6 @@ template struct EigenvalueDecompositionSymmetric;
 template struct EigenvalueDecompositionHermitian<ffi::DataType::C64>;
 template struct EigenvalueDecompositionHermitian<ffi::DataType::C128>;
-// LAPACK uses a packed representation to represent a mixture of real
-// eigenvectors and complex conjugate pairs. This helper unpacks the
-// representation into regular complex matrices.
-template <typename T>
-static void UnpackEigenvectors(lapack_int n, const T* eigenvals_imag,
-                               const T* packed, std::complex<T>* unpacked) {
- for (int j = 0; j < n;) {
- if (eigenvals_imag[j] == 0. || std::isnan(eigenvals_imag[j])) {
- // Real values in each row without imaginary part
- // Second row of the imaginary part is not provided
- for (int i = 0; i < n; ++i) {
- unpacked[j * n + i] = {packed[j * n + i], 0.};
- }
- ++j;
- } else {
- // Complex values where the real part is in the jth row
- // and the imaginary part is in the next row (j + 1)
- for (int i = 0; i < n; ++i) {
- const T real_part = packed[j * n + i];
- const T imag_part = packed[(j + 1) * n + i];
- unpacked[j * n + i] = {real_part, imag_part};
- unpacked[(j + 1) * n + i] = {real_part, -imag_part};
- }
- j += 2;
- }
- }
-}
-
// lapack geev
template
diff --git a/jaxlib/cpu/lapack_kernels.h b/jaxlib/cpu/lapack_kernels.h
index 7d15e494fffc..cddcb1162120 100644
--- a/jaxlib/cpu/lapack_kernels.h
+++ b/jaxlib/cpu/lapack_kernels.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef JAXLIB_CPU_LAPACK_KERNELS_H_
#define JAXLIB_CPU_LAPACK_KERNELS_H_
+#include <complex>
 #include <cstdint>
 #include <optional>
 #include <type_traits>
@@ -462,6 +463,34 @@ struct EigenvalueDecompositionHermitian {
// lapack geev
+// LAPACK uses a packed representation to represent a mixture of real
+// eigenvectors and complex conjugate pairs. This helper unpacks the
+// representation into regular complex matrices.
+template <typename T, typename Int>
+static void UnpackEigenvectors(Int n, const T* eigenvals_imag,
+                               const T* packed, std::complex<T>* unpacked) {
+ for (int j = 0; j < n;) {
+ if (eigenvals_imag[j] == 0. || std::isnan(eigenvals_imag[j])) {
+ // Real values in each row without imaginary part
+ // Second row of the imaginary part is not provided
+ for (int i = 0; i < n; ++i) {
+ unpacked[j * n + i] = {packed[j * n + i], 0.};
+ }
+ ++j;
+ } else {
+ // Complex values where the real part is in the jth row
+ // and the imaginary part is in the next row (j + 1)
+ for (int i = 0; i < n; ++i) {
+ const T real_part = packed[j * n + i];
+ const T imag_part = packed[(j + 1) * n + i];
+ unpacked[j * n + i] = {real_part, imag_part};
+ unpacked[(j + 1) * n + i] = {real_part, -imag_part};
+ }
+ j += 2;
+ }
+ }
+}
+
 template <typename T>
struct RealGeev {
using FnType = void(char* jobvl, char* jobvr, lapack_int* n, T* a,
diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD
index 34e40d12d5be..afce2c000ecc 100644
--- a/jaxlib/cuda/BUILD
+++ b/jaxlib/cuda/BUILD
@@ -476,6 +476,55 @@ pybind_extension(
],
)
+cc_library(
+ name = "cuda_hybrid_kernels",
+ srcs = ["//jaxlib/gpu:hybrid_kernels.cc"],
+ hdrs = ["//jaxlib/gpu:hybrid_kernels.h"],
+ deps = [
+ ":cuda_gpu_kernel_helpers",
+ ":cuda_vendor",
+ "//jaxlib:ffi_helpers",
+ "//jaxlib/cpu:lapack_kernels",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_absl//absl/types:span",
+ "@xla//xla/ffi/api:ffi",
+ ],
+)
+
+pybind_extension(
+ name = "_hybrid",
+ srcs = ["//jaxlib/gpu:hybrid.cc"],
+ copts = [
+ "-fexceptions",
+ "-fno-strict-aliasing",
+ ],
+ features = ["-use_header_modules"],
+ linkopts = select({
+ "@xla//xla/python:use_jax_cuda_pip_rpaths": [
+ "-Wl,-rpath,$$ORIGIN/../../nvidia/cuda_runtime/lib",
+ ],
+ "//conditions:default": [],
+ }),
+ module_name = "_hybrid",
+ deps = [
+ ":cuda_gpu_kernel_helpers",
+ ":cuda_hybrid_kernels",
+ ":cuda_vendor",
+ "//jaxlib:kernel_nanobind_helpers",
+ "//jaxlib/cpu:lapack_kernels",
+ "@local_config_cuda//cuda:cuda_headers",
+ "@nanobind",
+ "@xla//xla/ffi/api:ffi",
+ "@xla//xla/tsl/cuda:cudart",
+ ],
+)
+
cc_library(
name = "cuda_gpu_kernels",
srcs = ["//jaxlib/gpu:gpu_kernels.cc"],
@@ -633,6 +682,7 @@ py_library(
name = "cuda_gpu_support",
deps = [
":_blas",
+ ":_hybrid",
":_linalg",
":_prng",
":_rnn",
diff --git a/jaxlib/gpu/BUILD b/jaxlib/gpu/BUILD
index 7d50a91cfcda..e888f6a42a9b 100644
--- a/jaxlib/gpu/BUILD
+++ b/jaxlib/gpu/BUILD
@@ -37,6 +37,9 @@ exports_files(srcs = [
"gpu_kernel_helpers.cc",
"gpu_kernel_helpers.h",
"gpu_kernels.cc",
+ "hybrid.cc",
+ "hybrid_kernels.cc",
+ "hybrid_kernels.h",
"linalg.cc",
"linalg_kernels.cc",
"linalg_kernels.cu.cc",
diff --git a/jaxlib/gpu/hybrid.cc b/jaxlib/gpu/hybrid.cc
new file mode 100644
index 000000000000..b7e0becdcc5b
--- /dev/null
+++ b/jaxlib/gpu/hybrid.cc
@@ -0,0 +1,59 @@
+/* Copyright 2021 The JAX Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "nanobind/nanobind.h"
+#include "jaxlib/cpu/lapack_kernels.h"
+#include "jaxlib/gpu/hybrid_kernels.h"
+#include "jaxlib/gpu/vendor.h"
+#include "jaxlib/kernel_nanobind_helpers.h"
+#include "xla/ffi/api/ffi.h"
+
+namespace jax {
+namespace JAX_GPU_NAMESPACE {
+namespace {
+namespace ffi = xla::ffi;
+namespace nb = nanobind;
+
+void GetLapackKernelsFromScipy() {
+ static bool initialized = false; // Protected by GIL
+ if (initialized) return;
+ nb::module_ cython_blas = nb::module_::import_("scipy.linalg.cython_blas");
+ nb::module_ cython_lapack =
+ nb::module_::import_("scipy.linalg.cython_lapack");
+ nb::dict lapack_capi = cython_lapack.attr("__pyx_capi__");
+ auto lapack_ptr = [&](const char* name) {
+    return nb::cast<nb::capsule>(lapack_capi[name]).data();
+ };
+
+  AssignKernelFn<EigenvalueDecomposition<ffi::F32>>(lapack_ptr("sgeev"));
+  AssignKernelFn<EigenvalueDecomposition<ffi::F64>>(lapack_ptr("dgeev"));
+  AssignKernelFn<EigenvalueDecompositionComplex<ffi::C64>>(lapack_ptr("cgeev"));
+  AssignKernelFn<EigenvalueDecompositionComplex<ffi::C128>>(
+      lapack_ptr("zgeev"));
+}
+
+NB_MODULE(_hybrid, m) {
+ m.def("initialize", GetLapackKernelsFromScipy);
+ m.def("registrations", []() {
+ nb::dict dict;
+ dict[JAX_GPU_PREFIX "hybrid_eig_real"] = EncapsulateFfiHandler(kEigReal);
+ dict[JAX_GPU_PREFIX "hybrid_eig_comp"] = EncapsulateFfiHandler(kEigComp);
+ return dict;
+ });
+}
+
+} // namespace
+} // namespace JAX_GPU_NAMESPACE
+} // namespace jax
diff --git a/jaxlib/gpu/hybrid_kernels.cc b/jaxlib/gpu/hybrid_kernels.cc
new file mode 100644
index 000000000000..c93d7d350d50
--- /dev/null
+++ b/jaxlib/gpu/hybrid_kernels.cc
@@ -0,0 +1,626 @@
+/* Copyright 2024 The JAX Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "jaxlib/gpu/hybrid_kernels.h"
+
+#include <dlfcn.h>
+
+#include <algorithm>
+#include <cmath>
+#include <complex>
+#include <cstdint>
+#include <cstdlib>
+#include <memory>
+#include <optional>
+#include <string>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/base/thread_annotations.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_format.h"
+#include "absl/synchronization/mutex.h"
+#include "absl/types/span.h"
+#include "jaxlib/cpu/lapack_kernels.h"
+#include "jaxlib/ffi_helpers.h"
+#include "jaxlib/gpu/gpu_kernel_helpers.h"
+#include "jaxlib/gpu/vendor.h"
+#include "xla/ffi/api/ffi.h"
+
+namespace jax {
+namespace JAX_GPU_NAMESPACE {
+
+namespace ffi = ::xla::ffi;
+
+// This helper class is used to define a host buffer that can be copied to and
+// from a device buffer.
+template <typename T>
+class HostBuffer {
+ public:
+ HostBuffer(std::size_t size) : size_(size) {
+    data_ = std::unique_ptr<T[]>(new T[size]);
+ }
+
+ absl::Status CopyFromDevice(gpuStream_t stream, const T* buffer) {
+ return JAX_AS_STATUS(gpuMemcpyAsync(data_.get(), buffer, size_ * sizeof(T),
+ gpuMemcpyDeviceToHost, stream));
+ }
+
+ absl::Status CopyToDevice(gpuStream_t stream, T* buffer) {
+ return JAX_AS_STATUS(gpuMemcpyAsync(buffer, data_.get(), size_ * sizeof(T),
+ gpuMemcpyHostToDevice, stream));
+ }
+
+ T* get() const { return data_.get(); }
+
+ private:
+  std::unique_ptr<T[]> data_;
+ size_t size_;
+};
+
+// Forwarded from MAGMA for use as an input parameter.
+typedef enum {
+ MagmaNoVec = 301,
+ MagmaVec = 302,
+} magma_vec_t;
+
+// Compile time lookup of MAGMA function names depending on the data type.
+template <ffi::DataType DataType>
+struct always_false : std::false_type {};
+
+template <ffi::DataType DataType>
+struct MagmaGeev {
+  static_assert(always_false<DataType>::value, "unsupported data type");
+};
+template <>
+struct MagmaGeev<ffi::F32> {
+  static constexpr char name[] = "magma_sgeev";
+};
+template <>
+struct MagmaGeev<ffi::F64> {
+  static constexpr char name[] = "magma_dgeev";
+};
+template <>
+struct MagmaGeev<ffi::C64> {
+  static constexpr char name[] = "magma_cgeev";
+};
+template <>
+struct MagmaGeev<ffi::C128> {
+  static constexpr char name[] = "magma_zgeev";
+};
+
+// This class is used for dlopening the MAGMA shared library, initializing it,
+// and looking up MAGMA symbols. When used via the `FindMagmaSymbol` function
+// (defined below), the library will only be loaded and initialized once per
+// process.
+class MagmaLookup {
+ public:
+ explicit MagmaLookup() = default;
+ ~MagmaLookup() {
+ if (initialized_) {
+ void* magma_finalize = dlsym(handle_, "magma_finalize");
+ if (magma_finalize != nullptr) {
+        reinterpret_cast<void (*)()>(magma_finalize)();
+ }
+ }
+ if (handle_ != nullptr) {
+ dlclose(handle_);
+ }
+ }
+
+ absl::Status Initialize() {
+ if (!initialized_) {
+      std::vector<const char*> paths;
+ const char* magma_lib_path = std::getenv("JAX_GPU_MAGMA_PATH");
+ if (magma_lib_path != nullptr) {
+ paths.push_back(magma_lib_path);
+ } else {
+ paths.push_back("libmagma.so.2");
+ paths.push_back("libmagma.so");
+ paths.push_back(nullptr);
+ }
+ void* magma_init = nullptr;
+ for (const auto& path : paths) {
+ handle_ = dlopen(path, RTLD_LAZY);
+ if (handle_ != nullptr) {
+ magma_init = dlsym(handle_, "magma_init");
+ if (magma_init != nullptr) {
+ if (path != nullptr) {
+ lib_path_ = std::string(path);
+ }
+ break;
+ }
+ }
+ }
+ if (handle_ == nullptr || magma_init == nullptr) {
+ return absl::InternalError(
+ "Unable to dlopen a MAGMA shared library that defines a magma_init "
+ "symbol. Use the JAX_GPU_MAGMA_PATH environment variable to "
+ "specify an explicit path to the library.");
+ }
+
+      reinterpret_cast<void (*)()>(magma_init)();
+ initialized_ = true;
+ }
+ return absl::OkStatus();
+ }
+
+  absl::StatusOr<void*> Find(const char name[]) {
+ if (!initialized_) {
+ return absl::InternalError("MAGMA support has not been initialized.");
+ }
+
+ auto it = symbols_.find(name);
+ if (it != symbols_.end()) return it->second;
+
+ void* symbol = dlsym(handle_, name);
+ if (symbol == nullptr) {
+ if (lib_path_.has_value()) {
+ return absl::InternalError(absl::StrFormat(
+ "Unable to load the symbol '%s' from the MAGMA library at '%s'.",
+ name, lib_path_.value()));
+
+ } else {
+ return absl::InternalError(absl::StrFormat(
+ "Unable to load a globally defined symbol called '%s'. Use the "
+ "JAX_GPU_MAGMA_PATH environment variable to specify an explicit "
+ "path to the library.",
+ name));
+ }
+ }
+
+ symbols_.insert({name, symbol});
+ return symbol;
+ }
+
+ private:
+ bool initialized_ = false;
+ void* handle_ = nullptr;
+  std::optional<std::string> lib_path_ = std::nullopt;
+  absl::flat_hash_map<std::string, void*> symbols_;
+};
+
+// Look up the MAGMA symbol for the given function name. This function only
+// dlopens the MAGMA library once per process.
+absl::StatusOr<void*> FindMagmaSymbol(const char name[]) {
+ static absl::Mutex mu;
+ static MagmaLookup& lookup = *new MagmaLookup ABSL_GUARDED_BY(mu);
+ absl::MutexLock lock(&mu);
+ auto status = lookup.Initialize();
+ if (!status.ok()) {
+ return status;
+ }
+ return lookup.Find(name);
+}
+
+// Real-valued eigendecomposition
+
+template <ffi::DataType DataType>
+class EigRealHost {
+  using Real = ffi::NativeType<DataType>;
+
+ public:
+ explicit EigRealHost() = default;
+ EigRealHost(EigRealHost&&) = default;
+
+ absl::StatusOr lwork(int n, bool left, bool right) {
+ n_ = n;
+ jobvl_ = left ? 'V' : 'N';
+ jobvr_ = right ? 'V' : 'N';
+    int64_t lwork = EigenvalueDecomposition<DataType>::GetWorkspaceSize(
+        n, static_cast<eig::ComputationMode>(jobvl_),
+        static_cast<eig::ComputationMode>(jobvr_));
+    return MaybeCastNoOverflow<int>(lwork);
+ }
+
+ void compute(Real* x, Real* wr, Real* wi, Real* vl, Real* vr, Real* work,
+ int lwork, int* info) {
+    EigenvalueDecomposition<DataType>::fn(&jobvl_, &jobvr_, &n_, x, &n_, wr, wi,
+                                          vl, &n_, vr, &n_, work, &lwork, info);
+ }
+
+ private:
+ int n_;
+ char jobvl_, jobvr_;
+};
+
+template <ffi::DataType DataType>
+class EigRealMagma {
+  using Real = ffi::NativeType<DataType>;
+ using Fn = int(magma_vec_t, magma_vec_t, int, Real*, int, Real*, Real*, Real*,
+ int, Real*, int, Real*, int, int*);
+
+ public:
+ explicit EigRealMagma() = default;
+ EigRealMagma(EigRealMagma&&) = default;
+
+ absl::StatusOr lwork(int n, bool left, bool right) {
+ n_ = n;
+ jobvl_ = left ? MagmaVec : MagmaNoVec;
+ jobvr_ = right ? MagmaVec : MagmaNoVec;
+
+    auto maybe_ptr = FindMagmaSymbol(MagmaGeev<DataType>::name);
+ if (!maybe_ptr.ok()) return maybe_ptr.status();
+ fn_ = reinterpret_cast(*maybe_ptr);
+
+ int query_info;
+ Real query_host;
+ fn_(jobvl_, jobvr_, n, nullptr, n, nullptr, nullptr, nullptr, n, nullptr, n,
+ &query_host, -1, &query_info);
+ return static_cast(query_host);
+ }
+
+ void compute(Real* x, Real* wr, Real* wi, Real* vl, Real* vr, Real* work,
+ int lwork, int* info) {
+ fn_(jobvl_, jobvr_, n_, x, n_, wr, wi, vl, n_, vr, n_, work, lwork, info);
+ }
+
+ private:
+ int n_;
+ magma_vec_t jobvl_, jobvr_;
+ Fn* fn_ = nullptr;
+};
+
+template <ffi::DataType DataType, typename Impl>
+ffi::Error EigReal(Impl impl, int64_t batch, int64_t cols, gpuStream_t stream,
+                   bool left, bool right, ffi::AnyBuffer x,
+                   ffi::Result<ffi::AnyBuffer> wr,
+                   ffi::Result<ffi::AnyBuffer> wi,
+                   ffi::Result<ffi::AnyBuffer> vl,
+                   ffi::Result<ffi::AnyBuffer> vr,
+                   ffi::Result<ffi::Buffer<ffi::S32>> info) {
+  using Real = ffi::NativeType<DataType>;
+  using Complex = ffi::NativeType<ffi::ToComplex(DataType)>;
+
+  auto x_host = HostBuffer<Real>(x.element_count());
+  FFI_RETURN_IF_ERROR_STATUS(
+      x_host.CopyFromDevice(stream, x.typed_data<Real>()));
+
+  auto wr_host = HostBuffer<Real>(batch * cols);
+  auto wi_host = HostBuffer<Real>(batch * cols);
+  auto vl_host = HostBuffer<Complex>(batch * cols * cols);
+  auto vr_host = HostBuffer<Complex>(batch * cols * cols);
+  auto info_host = HostBuffer<int>(batch);
+
+  FFI_ASSIGN_OR_RETURN(int n, MaybeCastNoOverflow<int>(cols));
+  FFI_ASSIGN_OR_RETURN(int lwork, impl.lwork(n, left, right));
+  auto work_host = AllocateScratchMemory<DataType>(lwork);
+  auto work_left = AllocateScratchMemory<DataType>(cols * cols);
+  auto work_right = AllocateScratchMemory<DataType>(cols * cols);
+
+ FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuStreamSynchronize(stream)));
+
+ const auto is_finite = [](auto* data, int64_t size) {
+ return absl::c_all_of(absl::MakeSpan(data, size),
+ [](auto value) { return std::isfinite(value); });
+ };
+
+ for (int64_t i = 0; i < batch; ++i) {
+ if (is_finite(x_host.get() + i * cols * cols, cols * cols)) {
+ impl.compute(x_host.get() + i * cols * cols, wr_host.get() + i * cols,
+ wi_host.get() + i * cols, work_left.get(), work_right.get(),
+ work_host.get(), lwork, info_host.get() + i);
+ if (info_host.get()[i] == 0) {
+ if (left) {
+ UnpackEigenvectors(n, wi_host.get() + i * cols, work_left.get(),
+ vl_host.get() + i * cols * cols);
+ }
+ if (right) {
+ UnpackEigenvectors(n, wi_host.get() + i * cols, work_right.get(),
+ vr_host.get() + i * cols * cols);
+ }
+ }
+ } else {
+ info_host.get()[i] = -4;
+ }
+ }
+
+  FFI_RETURN_IF_ERROR_STATUS(
+      wr_host.CopyToDevice(stream, wr->typed_data<Real>()));
+  FFI_RETURN_IF_ERROR_STATUS(
+      wi_host.CopyToDevice(stream, wi->typed_data<Real>()));
+  if (left) {
+    FFI_RETURN_IF_ERROR_STATUS(
+        vl_host.CopyToDevice(stream, vl->typed_data<Complex>()));
+  }
+  if (right) {
+    FFI_RETURN_IF_ERROR_STATUS(
+        vr_host.CopyToDevice(stream, vr->typed_data<Complex>()));
+  }
+ FFI_RETURN_IF_ERROR_STATUS(
+ info_host.CopyToDevice(stream, info->typed_data()));
+ FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuStreamSynchronize(stream)));
+
+ return ffi::Error::Success();
+}
+
+ffi::Error EigRealDispatch(gpuStream_t stream, bool magma, bool left,
+ bool right, ffi::AnyBuffer x,
+                           ffi::Result<ffi::AnyBuffer> wr,
+                           ffi::Result<ffi::AnyBuffer> wi,
+                           ffi::Result<ffi::AnyBuffer> vl,
+                           ffi::Result<ffi::AnyBuffer> vr,
+                           ffi::Result<ffi::Buffer<ffi::S32>> info) {
+ auto dataType = x.element_type();
+ if (dataType != wr->element_type() || dataType != wi->element_type() ||
+ ffi::ToComplex(dataType) != vl->element_type() ||
+ ffi::ToComplex(dataType) != vr->element_type()) {
+ return ffi::Error::InvalidArgument(
+ "The inputs and outputs to eig must have the same element type");
+ }
+
+ FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]),
+ SplitBatch2D(x.dimensions()));
+ if (rows != cols) {
+ return ffi::Error::InvalidArgument(
+ "The input matrix to eig must be square");
+ }
+ FFI_RETURN_IF_ERROR(CheckShape(wr->dimensions(), {batch, cols}, "wr", "eig"));
+ FFI_RETURN_IF_ERROR(CheckShape(wi->dimensions(), {batch, cols}, "wi", "eig"));
+ if (left) {
+ FFI_RETURN_IF_ERROR(
+ CheckShape(vl->dimensions(), {batch, rows, cols}, "vl", "eig"));
+ }
+ if (right) {
+ FFI_RETURN_IF_ERROR(
+ CheckShape(vr->dimensions(), {batch, rows, cols}, "vr", "eig"));
+ }
+ FFI_RETURN_IF_ERROR(CheckShape(info->dimensions(), batch, "info", "eig"));
+
+ switch (dataType) {
+ case ffi::F32:
+ if (magma) {
+        return EigReal<ffi::F32>(EigRealMagma<ffi::F32>(), batch, cols, stream,
+                                 left, right, x, wr, wi, vl, vr, info);
+ } else {
+        return EigReal<ffi::F32>(EigRealHost<ffi::F32>(), batch, cols, stream,
+                                 left, right, x, wr, wi, vl, vr, info);
+ }
+ case ffi::F64:
+ if (magma) {
+        return EigReal<ffi::F64>(EigRealMagma<ffi::F64>(), batch, cols, stream,
+                                 left, right, x, wr, wi, vl, vr, info);
+ } else {
+        return EigReal<ffi::F64>(EigRealHost<ffi::F64>(), batch, cols, stream,
+                                 left, right, x, wr, wi, vl, vr, info);
+ }
+ default:
+ return ffi::Error::InvalidArgument(absl::StrFormat(
+ "Unsupported dtype %s in eig_real", absl::FormatStreamed(dataType)));
+ }
+}
+
+XLA_FFI_DEFINE_HANDLER_SYMBOL(kEigReal, EigRealDispatch,
+ ffi::Ffi::Bind()
+                                  .Ctx<ffi::PlatformStream<gpuStream_t>>()
+                                  .Attr<bool>("magma")
+                                  .Attr<bool>("left")
+                                  .Attr<bool>("right")
+                                  .Arg<ffi::AnyBuffer>()         // x
+                                  .Ret<ffi::AnyBuffer>()         // wr
+                                  .Ret<ffi::AnyBuffer>()         // wi
+                                  .Ret<ffi::AnyBuffer>()         // vl
+                                  .Ret<ffi::AnyBuffer>()         // vr
+                                  .Ret<ffi::Buffer<ffi::S32>>()  // info
+);
+
+// Complex-valued eigendecomposition
+
+template <ffi::DataType DataType>
+class EigCompHost {
+  using Real = ffi::NativeType<ffi::ToReal(DataType)>;
+  using Complex = ffi::NativeType<DataType>;
+
+ public:
+ explicit EigCompHost() = default;
+ EigCompHost(EigCompHost&&) = default;
+
+ absl::StatusOr lwork(int n, bool left, bool right) {
+ n_ = n;
+ jobvl_ = left ? 'V' : 'N';
+ jobvr_ = right ? 'V' : 'N';
+    int64_t lwork = EigenvalueDecompositionComplex<DataType>::GetWorkspaceSize(
+        n, static_cast<eig::ComputationMode>(jobvl_),
+        static_cast<eig::ComputationMode>(jobvr_));
+    return MaybeCastNoOverflow<int>(lwork);
+ }
+
+ void compute(Complex* x, Complex* w, Complex* vl, Complex* vr, Complex* work,
+ int lwork, Real* rwork, int* info) {
+    EigenvalueDecompositionComplex<DataType>::fn(&jobvl_, &jobvr_, &n_, x, &n_,
+                                                 w, vl, &n_, vr, &n_, work,
+                                                 &lwork, rwork, info);
+ }
+
+ private:
+ int n_;
+ char jobvl_, jobvr_;
+};
+
+template <ffi::DataType DataType>
+class EigCompMagma {
+  using Real = ffi::NativeType<ffi::ToReal(DataType)>;
+  using Complex = ffi::NativeType<DataType>;
+ using Fn = int(magma_vec_t, magma_vec_t, int, Complex*, int, Complex*,
+ Complex*, int, Complex*, int, Complex*, int, Real*, int*);
+
+ public:
+ explicit EigCompMagma() = default;
+ EigCompMagma(EigCompMagma&&) = default;
+
+ absl::StatusOr lwork(int n, bool left, bool right) {
+ n_ = n;
+ jobvl_ = left ? MagmaVec : MagmaNoVec;
+ jobvr_ = right ? MagmaVec : MagmaNoVec;
+ lda_ = std::max(n_, 1);
+ ldvl_ = left ? n_ : 1;
+ ldvr_ = right ? n_ : 1;
+
+    auto maybe_ptr = FindMagmaSymbol(MagmaGeev<DataType>::name);
+ if (!maybe_ptr.ok()) return maybe_ptr.status();
+ fn_ = reinterpret_cast(*maybe_ptr);
+
+ int query_info;
+ Complex query_host;
+ fn_(jobvl_, jobvr_, n_, nullptr, lda_, nullptr, nullptr, ldvl_, nullptr,
+ ldvr_, &query_host, -1, nullptr, &query_info);
+ return static_cast(query_host.real());
+ }
+
+ void compute(Complex* x, Complex* w, Complex* vl, Complex* vr, Complex* work,
+ int lwork, Real* rwork, int* info) {
+ fn_(jobvl_, jobvr_, n_, x, lda_, w, vl, ldvl_, vr, ldvr_, work, lwork,
+ rwork, info);
+ }
+
+ private:
+ int n_, lda_, ldvl_, ldvr_;
+ magma_vec_t jobvl_, jobvr_;
+ Fn* fn_ = nullptr;
+};
+
+template <ffi::DataType DataType, typename Impl>
+ffi::Error EigComp(Impl impl, int64_t batch, int64_t cols, gpuStream_t stream,
+                   bool left, bool right, ffi::AnyBuffer x,
+                   ffi::Result<ffi::AnyBuffer> w,
+                   ffi::Result<ffi::AnyBuffer> vl,
+                   ffi::Result<ffi::AnyBuffer> vr,
+                   ffi::Result<ffi::Buffer<ffi::S32>> info) {
+  using Complex = ffi::NativeType<DataType>;
+
+  auto x_host = HostBuffer<Complex>(x.element_count());
+  FFI_RETURN_IF_ERROR_STATUS(
+      x_host.CopyFromDevice(stream, x.typed_data<Complex>()));
+
+  auto w_host = HostBuffer<Complex>(batch * cols);
+  auto vl_host = HostBuffer<Complex>(batch * cols * cols);
+  auto vr_host = HostBuffer<Complex>(batch * cols * cols);
+  auto info_host = HostBuffer<int>(batch);
+
+  FFI_ASSIGN_OR_RETURN(int n, MaybeCastNoOverflow<int>(cols));
+  FFI_ASSIGN_OR_RETURN(int lwork, impl.lwork(n, left, right));
+  auto work_host = AllocateScratchMemory<DataType>(lwork);
+  auto rwork_host =
+      AllocateScratchMemory<ffi::ToReal(DataType)>(2 * cols * cols);
+
+ FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuStreamSynchronize(stream)));
+
+ const auto is_finite = [](auto* data, int64_t size) {
+ return absl::c_all_of(absl::MakeSpan(data, size), [](const auto& z) {
+ return std::isfinite(z.real()) && std::isfinite(z.imag());
+ });
+ };
+
+ for (int64_t i = 0; i < batch; ++i) {
+ if (is_finite(x_host.get() + i * cols * cols, cols * cols)) {
+ impl.compute(x_host.get() + i * cols * cols, w_host.get() + i * cols,
+ vl_host.get() + i * cols * cols,
+ vr_host.get() + i * cols * cols, work_host.get(), lwork,
+ rwork_host.get(), info_host.get() + i);
+ } else {
+ info_host.get()[i] = -4;
+ }
+ }
+
+  FFI_RETURN_IF_ERROR_STATUS(
+      w_host.CopyToDevice(stream, w->typed_data<Complex>()));
+  if (left) {
+    FFI_RETURN_IF_ERROR_STATUS(
+        vl_host.CopyToDevice(stream, vl->typed_data<Complex>()));
+  }
+  if (right) {
+    FFI_RETURN_IF_ERROR_STATUS(
+        vr_host.CopyToDevice(stream, vr->typed_data<Complex>()));
+  }
+ FFI_RETURN_IF_ERROR_STATUS(
+ info_host.CopyToDevice(stream, info->typed_data()));
+ FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuStreamSynchronize(stream)));
+
+ return ffi::Error::Success();
+}
+
+ffi::Error EigCompDispatch(gpuStream_t stream, bool magma, bool left,
+ bool right, ffi::AnyBuffer x,
+                           ffi::Result<ffi::AnyBuffer> w,
+                           ffi::Result<ffi::AnyBuffer> vl,
+                           ffi::Result<ffi::AnyBuffer> vr,
+                           ffi::Result<ffi::Buffer<ffi::S32>> info) {
+ auto dataType = x.element_type();
+ if (dataType != w->element_type() || dataType != vl->element_type() ||
+ dataType != vr->element_type()) {
+ return ffi::Error::InvalidArgument(
+ "The inputs and outputs to eig must have the same element type");
+ }
+
+ FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]),
+ SplitBatch2D(x.dimensions()));
+ if (rows != cols) {
+ return ffi::Error::InvalidArgument(
+ "The input matrix to eig must be square");
+ }
+ FFI_RETURN_IF_ERROR(CheckShape(w->dimensions(), {batch, cols}, "w", "eig"));
+ if (left) {
+ FFI_RETURN_IF_ERROR(
+ CheckShape(vl->dimensions(), {batch, rows, cols}, "vl", "eig"));
+ }
+ if (right) {
+ FFI_RETURN_IF_ERROR(
+ CheckShape(vr->dimensions(), {batch, rows, cols}, "vr", "eig"));
+ }
+ FFI_RETURN_IF_ERROR(CheckShape(info->dimensions(), batch, "info", "eig"));
+
+ switch (dataType) {
+ case ffi::C64:
+ if (magma) {
+        return EigComp<ffi::C64>(EigCompMagma<ffi::C64>(), batch, cols, stream,
+                                 left, right, x, w, vl, vr, info);
+ } else {
+        return EigComp<ffi::C64>(EigCompHost<ffi::C64>(), batch, cols, stream,
+                                 left, right, x, w, vl, vr, info);
+ }
+ case ffi::C128:
+ if (magma) {
+ return ffi::Error::InvalidArgument(
+ "Using MAGMA as the backend for eig_comp is not supported for "
+ "complex128 data types");
+ } else {
+        return EigComp<ffi::C128>(EigCompHost<ffi::C128>(), batch, cols,
+                                  stream, left, right, x, w, vl, vr, info);
+ }
+ default:
+ return ffi::Error::InvalidArgument(absl::StrFormat(
+ "Unsupported dtype %s in eig_comp", absl::FormatStreamed(dataType)));
+ }
+}
+
+XLA_FFI_DEFINE_HANDLER_SYMBOL(kEigComp, EigCompDispatch,
+ ffi::Ffi::Bind()
+                                  .Ctx<ffi::PlatformStream<gpuStream_t>>()
+                                  .Attr<bool>("magma")
+                                  .Attr<bool>("left")
+                                  .Attr<bool>("right")
+                                  .Arg<ffi::AnyBuffer>()         // x
+                                  .Ret<ffi::AnyBuffer>()         // w
+                                  .Ret<ffi::AnyBuffer>()         // vl
+                                  .Ret<ffi::AnyBuffer>()         // vr
+                                  .Ret<ffi::Buffer<ffi::S32>>()  // info
+);
+
+} // namespace JAX_GPU_NAMESPACE
+} // namespace jax
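
These kernels are "hybrid" in the sense that the matrix is copied to the host, factorized there (either by the SciPy-provided LAPACK ``*geev`` or by a dlopened MAGMA ``magma_*geev``), and the results are copied back to the device. If MAGMA lives in a non-standard location, the path can be set before the first call; the path below is only a placeholder:

.. code-block:: python

    import os

    # Placeholder path: point this at your libmagma shared library before the
    # first eig call so MagmaLookup::Initialize can dlopen it.
    os.environ['JAX_GPU_MAGMA_PATH'] = '/opt/magma/lib/libmagma.so.2'

    import jax
    jax.config.update('jax_gpu_use_magma', 'on')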
diff --git a/jaxlib/gpu/hybrid_kernels.h b/jaxlib/gpu/hybrid_kernels.h
new file mode 100644
index 000000000000..a70e72c745aa
--- /dev/null
+++ b/jaxlib/gpu/hybrid_kernels.h
@@ -0,0 +1,31 @@
+/* Copyright 2024 The JAX Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef JAXLIB_GPU_HYBRID_KERNELS_H_
+#define JAXLIB_GPU_HYBRID_KERNELS_H_
+
+#include "jaxlib/gpu/vendor.h"
+#include "xla/ffi/api/ffi.h"
+
+namespace jax {
+namespace JAX_GPU_NAMESPACE {
+
+XLA_FFI_DECLARE_HANDLER_SYMBOL(kEigReal);
+XLA_FFI_DECLARE_HANDLER_SYMBOL(kEigComp);
+
+} // namespace JAX_GPU_NAMESPACE
+} // namespace jax
+
+#endif // JAXLIB_GPU_HYBRID_KERNELS_H_
diff --git a/jaxlib/gpu_solver.py b/jaxlib/gpu_solver.py
index 03fd43e9ef89..434e724de902 100644
--- a/jaxlib/gpu_solver.py
+++ b/jaxlib/gpu_solver.py
@@ -56,6 +56,21 @@
xla_client.register_custom_call_target(_name, _value, platform="CUDA",
api_version=api_version)
+for cuda_module_name in [".cuda", "jax_cuda12_plugin"]:
+ try:
+ _cuhybrid = importlib.import_module(
+ f"{cuda_module_name}._hybrid", package="jaxlib"
+ )
+ except ImportError:
+ _cuhybrid = None
+ else:
+ break
+
+if _cuhybrid:
+ for _name, _value in _cuhybrid.registrations().items():
+ xla_client.register_custom_call_target(_name, _value, platform="CUDA",
+ api_version=1)
+
try:
from .rocm import _blas as _hipblas # pytype: disable=import-error
except ImportError:
@@ -88,6 +103,27 @@
xla_client.register_custom_call_target(_name, _value, platform="ROCM",
api_version=api_version)
+for rocm_module_name in [".rocm", "jax_rocm60_plugin"]:
+ try:
+ _hiphybrid = importlib.import_module(
+ f"{rocm_module_name}._hybrid", package="jaxlib"
+ )
+ except ImportError:
+ _hiphybrid = None
+ else:
+ break
+
+if _hiphybrid:
+ for _name, _value in _hiphybrid.registrations().items():
+ xla_client.register_custom_call_target(_name, _value, platform="ROCM",
+ api_version=1)
+
+def initialize_hybrid_kernels():
+ if _cuhybrid:
+ _cuhybrid.initialize()
+ if _hiphybrid:
+ _hiphybrid.initialize()
+
def _real_type(dtype):
"""Returns the real equivalent of 'dtype'."""
return np.finfo(dtype).dtype
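
Registration of the FFI targets happens at import time, while the LAPACK function pointers are pulled out of SciPy's ``cython_lapack`` capsules lazily: the lowering calls ``initialize_hybrid_kernels()`` before building the ``ffi_call``. A sketch of that call sequence:

.. code-block:: python

    from jaxlib import gpu_solver

    # Resolves sgeev/dgeev/cgeev/zgeev from scipy.linalg.cython_lapack on
    # first use; subsequent calls are no-ops.
    gpu_solver.initialize_hybrid_kernels()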
diff --git a/jaxlib/rocm/BUILD b/jaxlib/rocm/BUILD
index c9b73a5785f1..1076f9a77bf8 100644
--- a/jaxlib/rocm/BUILD
+++ b/jaxlib/rocm/BUILD
@@ -389,6 +389,48 @@ pybind_extension(
],
)
+cc_library(
+ name = "hip_hybrid_kernels",
+ srcs = ["//jaxlib/gpu:hybrid_kernels.cc"],
+ hdrs = ["//jaxlib/gpu:hybrid_kernels.h"],
+ deps = [
+ ":hip_gpu_kernel_helpers",
+ ":hip_vendor",
+ "//jaxlib:ffi_helpers",
+ "//jaxlib/cpu:lapack_kernels",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_absl//absl/types:span",
+ "@xla//xla/ffi/api:ffi",
+ ],
+)
+
+pybind_extension(
+ name = "_hybrid",
+ srcs = ["//jaxlib/gpu:hybrid.cc"],
+ copts = [
+ "-fexceptions",
+ "-fno-strict-aliasing",
+ ],
+ features = ["-use_header_modules"],
+ module_name = "_hybrid",
+ deps = [
+ ":hip_gpu_kernel_helpers",
+ ":hip_hybrid_kernels",
+ ":hip_vendor",
+ "//jaxlib:kernel_nanobind_helpers",
+ "//jaxlib/cpu:lapack_kernels",
+ "@local_config_rocm//rocm:rocm_headers",
+ "@nanobind",
+ "@xla//xla/ffi/api:ffi",
+ ],
+)
+
cc_library(
name = "triton_kernels",
srcs = ["//jaxlib/gpu:triton_kernels.cc"],
@@ -456,6 +498,7 @@ py_library(
name = "rocm_gpu_support",
deps = [
":_blas",
+ ":_hybrid",
":_linalg",
":_prng",
":_solver",
diff --git a/jaxlib/tools/build_gpu_kernels_wheel.py b/jaxlib/tools/build_gpu_kernels_wheel.py
index 5b3ac636303a..9a47c6ad5409 100644
--- a/jaxlib/tools/build_gpu_kernels_wheel.py
+++ b/jaxlib/tools/build_gpu_kernels_wheel.py
@@ -108,6 +108,7 @@ def prepare_wheel_cuda(
f"__main__/jaxlib/cuda/_rnn.{pyext}",
f"__main__/jaxlib/cuda/_sparse.{pyext}",
f"__main__/jaxlib/cuda/_triton.{pyext}",
+ f"__main__/jaxlib/cuda/_hybrid.{pyext}",
f"__main__/jaxlib/cuda/_versions.{pyext}",
f"__main__/jaxlib/cuda_plugin_extension.{pyext}",
f"__main__/jaxlib/mosaic/gpu/_mosaic_gpu_ext.{pyext}",
@@ -144,6 +145,7 @@ def prepare_wheel_rocm(
f"__main__/jaxlib/rocm/_linalg.{pyext}",
f"__main__/jaxlib/rocm/_prng.{pyext}",
f"__main__/jaxlib/rocm/_sparse.{pyext}",
+    f"__main__/jaxlib/rocm/_hybrid.{pyext}",
f"__main__/jaxlib/rocm/_triton.{pyext}",
f"__main__/jaxlib/rocm_plugin_extension.{pyext}",
"__main__/jaxlib/version.py",
diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py
index 438cebca2b06..4db36fa0ea97 100644
--- a/jaxlib/tools/build_wheel.py
+++ b/jaxlib/tools/build_wheel.py
@@ -231,6 +231,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, skip_gpu_kernels):
f"__main__/jaxlib/cuda/_rnn.{pyext}",
f"__main__/jaxlib/cuda/_sparse.{pyext}",
f"__main__/jaxlib/cuda/_triton.{pyext}",
+ f"__main__/jaxlib/cuda/_hybrid.{pyext}",
f"__main__/jaxlib/cuda/_versions.{pyext}",
],
)
@@ -244,6 +245,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, skip_gpu_kernels):
f"__main__/jaxlib/rocm/_prng.{pyext}",
f"__main__/jaxlib/rocm/_sparse.{pyext}",
f"__main__/jaxlib/rocm/_triton.{pyext}",
+ f"__main__/jaxlib/rocm/_hybrid.{pyext}",
],
)
diff --git a/tests/linalg_test.py b/tests/linalg_test.py
index 5ace4b5ecf18..cac018ef8978 100644
--- a/tests/linalg_test.py
+++ b/tests/linalg_test.py
@@ -34,6 +34,7 @@
from jax._src.lax import linalg as lax_linalg
from jax._src import test_util as jtu
from jax._src import xla_bridge
+from jax._src.lib import version as jaxlib_version
from jax._src.numpy.util import promote_dtypes_inexact
config.parse_flags_with_absl()
@@ -234,11 +235,13 @@ def testIssue1213(self):
compute_left_eigenvectors=[False, True],
compute_right_eigenvectors=[False, True],
)
- # TODO(phawkins): enable when there is an eigendecomposition implementation
- # for GPU/TPU.
- @jtu.run_on_devices("cpu")
+ # TODO(danfm): enable when there is an eigendecomposition implementation
+ # for TPU.
+ @jtu.run_on_devices("cpu", "gpu")
def testEig(self, shape, dtype, compute_left_eigenvectors,
compute_right_eigenvectors):
+ if jtu.test_device_matches(["gpu"]) and jaxlib_version <= (0, 4, 35):
+ self.skipTest("eig on GPU requires jaxlib version > 0.4.35")
rng = jtu.rand_default(self.rng())
n = shape[-1]
args_maker = lambda: [rng(shape, dtype)]
@@ -277,12 +280,14 @@ def check_left_eigenvectors(a, w, vl):
compute_left_eigenvectors=[False, True],
compute_right_eigenvectors=[False, True],
)
- # TODO(phawkins): enable when there is an eigendecomposition implementation
- # for GPU/TPU.
- @jtu.run_on_devices("cpu")
+ # TODO(danfm): enable when there is an eigendecomposition implementation
+ # for TPU.
+ @jtu.run_on_devices("cpu", "gpu")
def testEigHandlesNanInputs(self, shape, dtype, compute_left_eigenvectors,
compute_right_eigenvectors):
"""Verifies that `eig` fails gracefully if given non-finite inputs."""
+ if jtu.test_device_matches(["gpu"]) and jaxlib_version <= (0, 4, 35):
+ self.skipTest("eig on GPU requires jaxlib version > 0.4.35")
a = jnp.full(shape, jnp.nan, dtype)
results = lax.linalg.eig(
a, compute_left_eigenvectors=compute_left_eigenvectors,
@@ -294,14 +299,16 @@ def testEigHandlesNanInputs(self, shape, dtype, compute_left_eigenvectors,
shape=[(4, 4), (5, 5), (8, 8), (7, 6, 6)],
dtype=float_types + complex_types,
)
- # TODO(phawkins): enable when there is an eigendecomposition implementation
- # for GPU/TPU.
- @jtu.run_on_devices("cpu")
+ # TODO(danfm): enable when there is an eigendecomposition implementation
+ # for TPU.
+ @jtu.run_on_devices("cpu", "gpu")
def testEigvalsGrad(self, shape, dtype):
# This test sometimes fails for large matrices. I (@j-towns) suspect, but
# haven't checked, that might be because of perturbations causing the
# ordering of eigenvalues to change, which will trip up check_grads. So we
# just test on small-ish matrices.
+ if jtu.test_device_matches(["gpu"]) and jaxlib_version <= (0, 4, 35):
+ self.skipTest("eig on GPU requires jaxlib version > 0.4.35")
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
a, = args_maker()
@@ -313,10 +320,12 @@ def testEigvalsGrad(self, shape, dtype):
shape=[(4, 4), (5, 5), (50, 50)],
dtype=float_types + complex_types,
)
- # TODO: enable when there is an eigendecomposition implementation
- # for GPU/TPU.
- @jtu.run_on_devices("cpu")
+ # TODO(danfm): enable when there is an eigendecomposition implementation
+ # for TPU.
+ @jtu.run_on_devices("cpu", "gpu")
def testEigvals(self, shape, dtype):
+ if jtu.test_device_matches(["gpu"]) and jaxlib_version <= (0, 4, 35):
+ self.skipTest("eig on GPU requires jaxlib version > 0.4.35")
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
a, = args_maker()
@@ -324,9 +333,11 @@ def testEigvals(self, shape, dtype):
w2 = jnp.linalg.eigvals(a)
self.assertAllClose(w1, w2, rtol={np.complex64: 1e-5, np.complex128: 2e-14})
- @jtu.run_on_devices("cpu")
+ @jtu.run_on_devices("cpu", "gpu")
def testEigvalsInf(self):
# https://github.com/jax-ml/jax/issues/2661
+ if jtu.test_device_matches(["gpu"]) and jaxlib_version <= (0, 4, 35):
+ self.skipTest("eig on GPU requires jaxlib version > 0.4.35")
x = jnp.array([[jnp.inf]])
self.assertTrue(jnp.all(jnp.isnan(jnp.linalg.eigvals(x))))
@@ -334,8 +345,10 @@ def testEigvalsInf(self):
shape=[(1, 1), (4, 4), (5, 5)],
dtype=float_types + complex_types,
)
- @jtu.run_on_devices("cpu")
+ @jtu.run_on_devices("cpu", "gpu")
def testEigBatching(self, shape, dtype):
+ if jtu.test_device_matches(["gpu"]) and jaxlib_version <= (0, 4, 35):
+ self.skipTest("eig on GPU requires jaxlib version > 0.4.35")
rng = jtu.rand_default(self.rng())
shape = (10,) + shape
args = rng(shape, dtype)
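
A hypothetical follow-up test (not part of this diff) could exercise the MAGMA code path explicitly; the sketch below assumes a MAGMA installation is available and uses the ``gpu_use_magma`` state as a context manager:

.. code-block:: python

      @jtu.run_on_devices("gpu")
      def testEigMagma(self):
        if jaxlib_version <= (0, 4, 35):
          self.skipTest("eig on GPU requires jaxlib version > 0.4.35")
        rng = jtu.rand_default(self.rng())
        a = rng((4, 4), np.float32)
        with config.gpu_use_magma("on"):
          w, vr = lax.linalg.eig(a, compute_left_eigenvectors=False)
        # Check the eigendecomposition residual A @ V ~= V @ diag(w).
        self.assertAllClose(a @ vr, w[None, :] * vr, rtol=2e-3, atol=2e-3)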