diff --git a/jax/_src/config.py b/jax/_src/config.py index 215ef443c799..c90138df1fe9 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -1953,3 +1953,20 @@ def _update_garbage_collection_guard(state, key, val): ), include_in_jit_key=True, ) + +gpu_use_magma = enum_state( + name='jax_gpu_use_magma', + enum_values=[ + 'off', 'on', + # TODO(danfm): After doing more complete benchmarks, add an 'auto' + # option which will automatically enable MAGMA for problem sizes where + # it typically gets better performance. + # 'auto', + ], + default='off', + help=( + 'Enable experimental support for MAGMA-backed lax.linalg.eig on GPU. ' + 'See the documentation for lax.linalg.eig for more details about how ' + 'to use this feature.' + ), +) diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 0e0390abc78f..0240416ca29c 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -124,7 +124,25 @@ def eig(x: ArrayLike, *, compute_left_eigenvectors: bool = True, compute_right_eigenvectors: bool = True) -> list[Array]: """Eigendecomposition of a general matrix. - Nonsymmetric eigendecomposition is at present only implemented on CPU. + Nonsymmetric eigendecomposition is only implemented on CPU and GPU. On GPU, + the default implementation calls LAPACK directly on the host CPU, but an + experimental GPU implementation using `MAGMA `_ + is also available. The MAGMA implementation is typically slower than the + equivalent LAPACK implementation for small matrices (less than about 2048), + but it may perform better for larger matrices. + + To enable the MAGMA implementation, you must install MAGMA yourself (there + are Debian and conda-forge packages, or you can build from source). Then set + the `jax_gpu_use_magma` configuration variable to `"on"`: + + .. code-block:: python + + jax.config.update('jax_gpu_use_magma', 'on') + + JAX will try to ``dlopen`` the installed MAGMA shared library, raising an + error if it is not found. To explicitly specify the path to the MAGMA + library, set the environment variable `JAX_GPU_MAGMA_PATH` to the full + installation path. Args: x: A batch of square matrices with shape ``[..., n, n]``. @@ -763,6 +781,78 @@ def _eig_cpu_lowering(ctx, operand, *, compute_left_eigenvectors, return output +def _eig_gpu_impl(target_name_prefix, x, *, compute_left_eigenvectors, + compute_right_eigenvectors): + gpu_solver.initialize_hybrid_kernels() + dtype = x.dtype + is_real = dtype == np.float32 or dtype == np.float64 + if is_real: + target_name = f"{target_name_prefix}hybrid_eig_real" + complex_dtype = np.complex64 if dtype == np.float32 else np.complex128 + else: + target_name = f"{target_name_prefix}hybrid_eig_comp" + assert dtype == np.complex64 or dtype == np.complex128 + complex_dtype = dtype + + batch_dims = x.shape[:-2] + n, m = x.shape[-2:] + assert n == m + num_batch_dims = len(batch_dims) + + layout = tuple(range(num_batch_dims)) + (num_batch_dims + 1, num_batch_dims) + out_types = [ + api.ShapeDtypeStruct(batch_dims + (n,), dtype), + api.ShapeDtypeStruct(batch_dims + (n, n), complex_dtype), + api.ShapeDtypeStruct(batch_dims + (n, n), complex_dtype), + api.ShapeDtypeStruct(batch_dims, np.int32), + ] + out_layouts = [None, layout, layout, None] + if is_real: + out_types = [api.ShapeDtypeStruct(batch_dims + (n,), dtype)] + out_types + out_layouts = [None] + out_layouts + + magma = False + if config.gpu_use_magma.value == 'on': + magma = dtype != np.complex128 + + fun = ffi.ffi_call(target_name, out_types, input_layouts=[layout], + output_layouts=out_layouts) + *w, vl, vr, info = fun(x, magma=magma, left=compute_left_eigenvectors, + right=compute_right_eigenvectors) + if is_real: + assert len(w) == 2 + w = lax.complex(*w) + else: + assert len(w) == 1 + w = w[0] + ok = lax.eq(info, lax.zeros_like_array(info)) + ok = _broadcast_to(ok[..., None], w.shape) + w = lax.select(ok, w, lax.full_like(w, np.nan + np.nan * 1j)) + ok = _broadcast_to(ok[..., None], x.shape) + output = [w] + if compute_left_eigenvectors: + vl = lax.select(ok, vl, lax.full_like(vl, np.nan + np.nan * 1j)) + output.append(vl) + if compute_right_eigenvectors: + vr = lax.select(ok, vr, lax.full_like(vr, np.nan + np.nan * 1j)) + output.append(vr) + return output + + +def _eig_gpu_lowering(target_name_prefix, ctx, operand, *, + compute_left_eigenvectors, compute_right_eigenvectors): + if ctx.is_forward_compat(): + raise NotImplementedError( + "Export of nonsymmetric eigendecomposition on GPU is not supported " + "because of forward compatibility. The " + "'jax_export_ignore_forward_compatibility' configuration option can be " + "used to disable this check.") + rule = mlir.lower_fun(partial(_eig_gpu_impl, target_name_prefix), + multiple_results=True) + return rule(ctx, operand, compute_left_eigenvectors=compute_left_eigenvectors, + compute_right_eigenvectors=compute_right_eigenvectors) + + def eig_batching_rule(batched_args, batch_dims, *, compute_left_eigenvectors, compute_right_eigenvectors): x, = batched_args @@ -793,6 +883,10 @@ def eig_jvp_rule(primals, tangents, *, compute_left_eigenvectors, eig_p.def_abstract_eval(eig_abstract_eval) mlir.register_lowering(eig_p, eig_lower) mlir.register_lowering(eig_p, _eig_cpu_lowering, platform='cpu') +mlir.register_lowering(eig_p, partial(_eig_gpu_lowering, 'cu'), + platform='cuda') +mlir.register_lowering(eig_p, partial(_eig_gpu_lowering, 'hip'), + platform='rocm') batching.primitive_batchers[eig_p] = eig_batching_rule ad.primitive_jvps[eig_p] = eig_jvp_rule diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py index 03f864919887..76a4abff48ad 100644 --- a/jax/_src/numpy/linalg.py +++ b/jax/_src/numpy/linalg.py @@ -731,7 +731,9 @@ def eig(a: ArrayLike) -> tuple[Array, Array]: - This differs from :func:`numpy.linalg.eig` in that the return type of :func:`jax.numpy.linalg.eig` is always complex64 for 32-bit input, and complex128 for 64-bit input. - - At present, non-symmetric eigendecomposition is only implemented on the CPU backend. + - At present, non-symmetric eigendecomposition is only implemented on the CPU and + GPU backends. For more details about the GPU implementation, see the + documentation for :func:`jax.lax.linalg.eig`. See also: - :func:`jax.numpy.linalg.eigh`: eigenvectors and eigenvalues of a Hermitian matrix. diff --git a/jaxlib/cpu/lapack_kernels.cc b/jaxlib/cpu/lapack_kernels.cc index 19b82a5ce149..ed815e1b1bd2 100644 --- a/jaxlib/cpu/lapack_kernels.cc +++ b/jaxlib/cpu/lapack_kernels.cc @@ -1094,34 +1094,6 @@ template struct EigenvalueDecompositionSymmetric; template struct EigenvalueDecompositionHermitian; template struct EigenvalueDecompositionHermitian; -// LAPACK uses a packed representation to represent a mixture of real -// eigenvectors and complex conjugate pairs. This helper unpacks the -// representation into regular complex matrices. -template -static void UnpackEigenvectors(lapack_int n, const T* eigenvals_imag, - const T* packed, std::complex* unpacked) { - for (int j = 0; j < n;) { - if (eigenvals_imag[j] == 0. || std::isnan(eigenvals_imag[j])) { - // Real values in each row without imaginary part - // Second row of the imaginary part is not provided - for (int i = 0; i < n; ++i) { - unpacked[j * n + i] = {packed[j * n + i], 0.}; - } - ++j; - } else { - // Complex values where the real part is in the jth row - // and the imaginary part is in the next row (j + 1) - for (int i = 0; i < n; ++i) { - const T real_part = packed[j * n + i]; - const T imag_part = packed[(j + 1) * n + i]; - unpacked[j * n + i] = {real_part, imag_part}; - unpacked[(j + 1) * n + i] = {real_part, -imag_part}; - } - j += 2; - } - } -} - // lapack geev template diff --git a/jaxlib/cpu/lapack_kernels.h b/jaxlib/cpu/lapack_kernels.h index 7d15e494fffc..cddcb1162120 100644 --- a/jaxlib/cpu/lapack_kernels.h +++ b/jaxlib/cpu/lapack_kernels.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef JAXLIB_CPU_LAPACK_KERNELS_H_ #define JAXLIB_CPU_LAPACK_KERNELS_H_ +#include #include #include #include @@ -462,6 +463,34 @@ struct EigenvalueDecompositionHermitian { // lapack geev +// LAPACK uses a packed representation to represent a mixture of real +// eigenvectors and complex conjugate pairs. This helper unpacks the +// representation into regular complex matrices. +template +static void UnpackEigenvectors(Int n, const T* eigenvals_imag, + const T* packed, std::complex* unpacked) { + for (int j = 0; j < n;) { + if (eigenvals_imag[j] == 0. || std::isnan(eigenvals_imag[j])) { + // Real values in each row without imaginary part + // Second row of the imaginary part is not provided + for (int i = 0; i < n; ++i) { + unpacked[j * n + i] = {packed[j * n + i], 0.}; + } + ++j; + } else { + // Complex values where the real part is in the jth row + // and the imaginary part is in the next row (j + 1) + for (int i = 0; i < n; ++i) { + const T real_part = packed[j * n + i]; + const T imag_part = packed[(j + 1) * n + i]; + unpacked[j * n + i] = {real_part, imag_part}; + unpacked[(j + 1) * n + i] = {real_part, -imag_part}; + } + j += 2; + } + } +} + template struct RealGeev { using FnType = void(char* jobvl, char* jobvr, lapack_int* n, T* a, diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index 34e40d12d5be..afce2c000ecc 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -476,6 +476,55 @@ pybind_extension( ], ) +cc_library( + name = "cuda_hybrid_kernels", + srcs = ["//jaxlib/gpu:hybrid_kernels.cc"], + hdrs = ["//jaxlib/gpu:hybrid_kernels.h"], + deps = [ + ":cuda_gpu_kernel_helpers", + ":cuda_vendor", + "//jaxlib:ffi_helpers", + "//jaxlib/cpu:lapack_kernels", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@xla//xla/ffi/api:ffi", + ], +) + +pybind_extension( + name = "_hybrid", + srcs = ["//jaxlib/gpu:hybrid.cc"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + linkopts = select({ + "@xla//xla/python:use_jax_cuda_pip_rpaths": [ + "-Wl,-rpath,$$ORIGIN/../../nvidia/cuda_runtime/lib", + ], + "//conditions:default": [], + }), + module_name = "_hybrid", + deps = [ + ":cuda_gpu_kernel_helpers", + ":cuda_hybrid_kernels", + ":cuda_vendor", + "//jaxlib:kernel_nanobind_helpers", + "//jaxlib/cpu:lapack_kernels", + "@local_config_cuda//cuda:cuda_headers", + "@nanobind", + "@xla//xla/ffi/api:ffi", + "@xla//xla/tsl/cuda:cudart", + ], +) + cc_library( name = "cuda_gpu_kernels", srcs = ["//jaxlib/gpu:gpu_kernels.cc"], @@ -633,6 +682,7 @@ py_library( name = "cuda_gpu_support", deps = [ ":_blas", + ":_hybrid", ":_linalg", ":_prng", ":_rnn", diff --git a/jaxlib/gpu/BUILD b/jaxlib/gpu/BUILD index 7d50a91cfcda..e888f6a42a9b 100644 --- a/jaxlib/gpu/BUILD +++ b/jaxlib/gpu/BUILD @@ -37,6 +37,9 @@ exports_files(srcs = [ "gpu_kernel_helpers.cc", "gpu_kernel_helpers.h", "gpu_kernels.cc", + "hybrid.cc", + "hybrid_kernels.cc", + "hybrid_kernels.h", "linalg.cc", "linalg_kernels.cc", "linalg_kernels.cu.cc", diff --git a/jaxlib/gpu/hybrid.cc b/jaxlib/gpu/hybrid.cc new file mode 100644 index 000000000000..b7e0becdcc5b --- /dev/null +++ b/jaxlib/gpu/hybrid.cc @@ -0,0 +1,59 @@ +/* Copyright 2021 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "nanobind/nanobind.h" +#include "jaxlib/cpu/lapack_kernels.h" +#include "jaxlib/gpu/hybrid_kernels.h" +#include "jaxlib/gpu/vendor.h" +#include "jaxlib/kernel_nanobind_helpers.h" +#include "xla/ffi/api/ffi.h" + +namespace jax { +namespace JAX_GPU_NAMESPACE { +namespace { +namespace ffi = xla::ffi; +namespace nb = nanobind; + +void GetLapackKernelsFromScipy() { + static bool initialized = false; // Protected by GIL + if (initialized) return; + nb::module_ cython_blas = nb::module_::import_("scipy.linalg.cython_blas"); + nb::module_ cython_lapack = + nb::module_::import_("scipy.linalg.cython_lapack"); + nb::dict lapack_capi = cython_lapack.attr("__pyx_capi__"); + auto lapack_ptr = [&](const char* name) { + return nb::cast(lapack_capi[name]).data(); + }; + + AssignKernelFn>(lapack_ptr("sgeev")); + AssignKernelFn>(lapack_ptr("dgeev")); + AssignKernelFn>(lapack_ptr("cgeev")); + AssignKernelFn>( + lapack_ptr("zgeev")); +} + +NB_MODULE(_hybrid, m) { + m.def("initialize", GetLapackKernelsFromScipy); + m.def("registrations", []() { + nb::dict dict; + dict[JAX_GPU_PREFIX "hybrid_eig_real"] = EncapsulateFfiHandler(kEigReal); + dict[JAX_GPU_PREFIX "hybrid_eig_comp"] = EncapsulateFfiHandler(kEigComp); + return dict; + }); +} + +} // namespace +} // namespace JAX_GPU_NAMESPACE +} // namespace jax diff --git a/jaxlib/gpu/hybrid_kernels.cc b/jaxlib/gpu/hybrid_kernels.cc new file mode 100644 index 000000000000..c93d7d350d50 --- /dev/null +++ b/jaxlib/gpu/hybrid_kernels.cc @@ -0,0 +1,626 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "jaxlib/gpu/hybrid_kernels.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "jaxlib/cpu/lapack_kernels.h" +#include "jaxlib/ffi_helpers.h" +#include "jaxlib/gpu/gpu_kernel_helpers.h" +#include "jaxlib/gpu/vendor.h" +#include "xla/ffi/api/ffi.h" + +namespace jax { +namespace JAX_GPU_NAMESPACE { + +namespace ffi = ::xla::ffi; + +// This helper class is used to define a host buffer that can be copied to and +// from a device buffer. +template +class HostBuffer { + public: + HostBuffer(std::size_t size) : size_(size) { + data_ = std::unique_ptr(new T[size]); + } + + absl::Status CopyFromDevice(gpuStream_t stream, const T* buffer) { + return JAX_AS_STATUS(gpuMemcpyAsync(data_.get(), buffer, size_ * sizeof(T), + gpuMemcpyDeviceToHost, stream)); + } + + absl::Status CopyToDevice(gpuStream_t stream, T* buffer) { + return JAX_AS_STATUS(gpuMemcpyAsync(buffer, data_.get(), size_ * sizeof(T), + gpuMemcpyHostToDevice, stream)); + } + + T* get() const { return data_.get(); } + + private: + std::unique_ptr data_; + size_t size_; +}; + +// Forwarded from MAGMA for use as an input parameter. +typedef enum { + MagmaNoVec = 301, + MagmaVec = 302, +} magma_vec_t; + +// Compile time lookup of MAGMA function names depending on the data type. +template +struct always_false : std::false_type {}; + +template +struct MagmaGeev { + static_assert(always_false::value, "unsupported data type"); +}; +template <> +struct MagmaGeev { + static constexpr char name[] = "magma_sgeev"; +}; +template <> +struct MagmaGeev { + static constexpr char name[] = "magma_dgeev"; +}; +template <> +struct MagmaGeev { + static constexpr char name[] = "magma_cgeev"; +}; +template <> +struct MagmaGeev { + static constexpr char name[] = "magma_zgeev"; +}; + +// This class is used for dlopening the MAGMA shared library, initializing it, +// and looking up MAGMA symbols. When used via the `FindMagmaSymbol` function +// (defined below), the library will only be loaded and initialized once per +// process. +class MagmaLookup { + public: + explicit MagmaLookup() = default; + ~MagmaLookup() { + if (initialized_) { + void* magma_finalize = dlsym(handle_, "magma_finalize"); + if (magma_finalize != nullptr) { + reinterpret_cast(magma_finalize)(); + } + } + if (handle_ != nullptr) { + dlclose(handle_); + } + } + + absl::Status Initialize() { + if (!initialized_) { + std::vector paths; + const char* magma_lib_path = std::getenv("JAX_GPU_MAGMA_PATH"); + if (magma_lib_path != nullptr) { + paths.push_back(magma_lib_path); + } else { + paths.push_back("libmagma.so.2"); + paths.push_back("libmagma.so"); + paths.push_back(nullptr); + } + void* magma_init = nullptr; + for (const auto& path : paths) { + handle_ = dlopen(path, RTLD_LAZY); + if (handle_ != nullptr) { + magma_init = dlsym(handle_, "magma_init"); + if (magma_init != nullptr) { + if (path != nullptr) { + lib_path_ = std::string(path); + } + break; + } + } + } + if (handle_ == nullptr || magma_init == nullptr) { + return absl::InternalError( + "Unable to dlopen a MAGMA shared library that defines a magma_init " + "symbol. Use the JAX_GPU_MAGMA_PATH environment variable to " + "specify an explicit path to the library."); + } + + reinterpret_cast(magma_init)(); + initialized_ = true; + } + return absl::OkStatus(); + } + + absl::StatusOr Find(const char name[]) { + if (!initialized_) { + return absl::InternalError("MAGMA support has not been initialized."); + } + + auto it = symbols_.find(name); + if (it != symbols_.end()) return it->second; + + void* symbol = dlsym(handle_, name); + if (symbol == nullptr) { + if (lib_path_.has_value()) { + return absl::InternalError(absl::StrFormat( + "Unable to load the symbol '%s' from the MAGMA library at '%s'.", + name, lib_path_.value())); + + } else { + return absl::InternalError(absl::StrFormat( + "Unable to load a globally defined symbol called '%s'. Use the " + "JAX_GPU_MAGMA_PATH environment variable to specify an explicit " + "path to the library.", + name)); + } + } + + symbols_.insert({name, symbol}); + return symbol; + } + + private: + bool initialized_ = false; + void* handle_ = nullptr; + std::optional lib_path_ = std::nullopt; + absl::flat_hash_map symbols_; +}; + +// Lookup the MAGMA symbol for the given function name. This function only +// dlopen the MAGMA library once per process. +absl::StatusOr FindMagmaSymbol(const char name[]) { + static absl::Mutex mu; + static MagmaLookup& lookup = *new MagmaLookup ABSL_GUARDED_BY(mu); + absl::MutexLock lock(&mu); + auto status = lookup.Initialize(); + if (!status.ok()) { + return status; + } + return lookup.Find(name); +} + +// Real-valued eigendecomposition + +template +class EigRealHost { + using Real = ffi::NativeType; + + public: + explicit EigRealHost() = default; + EigRealHost(EigRealHost&&) = default; + + absl::StatusOr lwork(int n, bool left, bool right) { + n_ = n; + jobvl_ = left ? 'V' : 'N'; + jobvr_ = right ? 'V' : 'N'; + int64_t lwork = EigenvalueDecomposition::GetWorkspaceSize( + n, static_cast(jobvl_), + static_cast(jobvr_)); + return MaybeCastNoOverflow(lwork); + } + + void compute(Real* x, Real* wr, Real* wi, Real* vl, Real* vr, Real* work, + int lwork, int* info) { + EigenvalueDecomposition::fn(&jobvl_, &jobvr_, &n_, x, &n_, wr, wi, + vl, &n_, vr, &n_, work, &lwork, info); + } + + private: + int n_; + char jobvl_, jobvr_; +}; + +template +class EigRealMagma { + using Real = ffi::NativeType; + using Fn = int(magma_vec_t, magma_vec_t, int, Real*, int, Real*, Real*, Real*, + int, Real*, int, Real*, int, int*); + + public: + explicit EigRealMagma() = default; + EigRealMagma(EigRealMagma&&) = default; + + absl::StatusOr lwork(int n, bool left, bool right) { + n_ = n; + jobvl_ = left ? MagmaVec : MagmaNoVec; + jobvr_ = right ? MagmaVec : MagmaNoVec; + + auto maybe_ptr = FindMagmaSymbol(MagmaGeev::name); + if (!maybe_ptr.ok()) return maybe_ptr.status(); + fn_ = reinterpret_cast(*maybe_ptr); + + int query_info; + Real query_host; + fn_(jobvl_, jobvr_, n, nullptr, n, nullptr, nullptr, nullptr, n, nullptr, n, + &query_host, -1, &query_info); + return static_cast(query_host); + } + + void compute(Real* x, Real* wr, Real* wi, Real* vl, Real* vr, Real* work, + int lwork, int* info) { + fn_(jobvl_, jobvr_, n_, x, n_, wr, wi, vl, n_, vr, n_, work, lwork, info); + } + + private: + int n_; + magma_vec_t jobvl_, jobvr_; + Fn* fn_ = nullptr; +}; + +template +ffi::Error EigReal(Impl impl, int64_t batch, int64_t cols, gpuStream_t stream, + bool left, bool right, ffi::AnyBuffer x, + ffi::Result wr, + ffi::Result wi, + ffi::Result vl, + ffi::Result vr, + ffi::Result> info) { + using Real = ffi::NativeType; + using Complex = ffi::NativeType; + + auto x_host = HostBuffer(x.element_count()); + FFI_RETURN_IF_ERROR_STATUS( + x_host.CopyFromDevice(stream, x.typed_data())); + + auto wr_host = HostBuffer(batch * cols); + auto wi_host = HostBuffer(batch * cols); + auto vl_host = HostBuffer(batch * cols * cols); + auto vr_host = HostBuffer(batch * cols * cols); + auto info_host = HostBuffer(batch); + + FFI_ASSIGN_OR_RETURN(int n, MaybeCastNoOverflow(cols)); + FFI_ASSIGN_OR_RETURN(int lwork, impl.lwork(n, left, right)); + auto work_host = AllocateScratchMemory(lwork); + auto work_left = AllocateScratchMemory(cols * cols); + auto work_right = AllocateScratchMemory(cols * cols); + + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuStreamSynchronize(stream))); + + const auto is_finite = [](auto* data, int64_t size) { + return absl::c_all_of(absl::MakeSpan(data, size), + [](auto value) { return std::isfinite(value); }); + }; + + for (int64_t i = 0; i < batch; ++i) { + if (is_finite(x_host.get() + i * cols * cols, cols * cols)) { + impl.compute(x_host.get() + i * cols * cols, wr_host.get() + i * cols, + wi_host.get() + i * cols, work_left.get(), work_right.get(), + work_host.get(), lwork, info_host.get() + i); + if (info_host.get()[i] == 0) { + if (left) { + UnpackEigenvectors(n, wi_host.get() + i * cols, work_left.get(), + vl_host.get() + i * cols * cols); + } + if (right) { + UnpackEigenvectors(n, wi_host.get() + i * cols, work_right.get(), + vr_host.get() + i * cols * cols); + } + } + } else { + info_host.get()[i] = -4; + } + } + + FFI_RETURN_IF_ERROR_STATUS( + wr_host.CopyToDevice(stream, wr->typed_data())); + FFI_RETURN_IF_ERROR_STATUS( + wi_host.CopyToDevice(stream, wi->typed_data())); + if (left) { + FFI_RETURN_IF_ERROR_STATUS( + vl_host.CopyToDevice(stream, vl->typed_data())); + } + if (right) { + FFI_RETURN_IF_ERROR_STATUS( + vr_host.CopyToDevice(stream, vr->typed_data())); + } + FFI_RETURN_IF_ERROR_STATUS( + info_host.CopyToDevice(stream, info->typed_data())); + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuStreamSynchronize(stream))); + + return ffi::Error::Success(); +} + +ffi::Error EigRealDispatch(gpuStream_t stream, bool magma, bool left, + bool right, ffi::AnyBuffer x, + ffi::Result wr, + ffi::Result wi, + ffi::Result vl, + ffi::Result vr, + ffi::Result> info) { + auto dataType = x.element_type(); + if (dataType != wr->element_type() || dataType != wi->element_type() || + ffi::ToComplex(dataType) != vl->element_type() || + ffi::ToComplex(dataType) != vr->element_type()) { + return ffi::Error::InvalidArgument( + "The inputs and outputs to eig must have the same element type"); + } + + FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]), + SplitBatch2D(x.dimensions())); + if (rows != cols) { + return ffi::Error::InvalidArgument( + "The input matrix to eig must be square"); + } + FFI_RETURN_IF_ERROR(CheckShape(wr->dimensions(), {batch, cols}, "wr", "eig")); + FFI_RETURN_IF_ERROR(CheckShape(wi->dimensions(), {batch, cols}, "wi", "eig")); + if (left) { + FFI_RETURN_IF_ERROR( + CheckShape(vl->dimensions(), {batch, rows, cols}, "vl", "eig")); + } + if (right) { + FFI_RETURN_IF_ERROR( + CheckShape(vr->dimensions(), {batch, rows, cols}, "vr", "eig")); + } + FFI_RETURN_IF_ERROR(CheckShape(info->dimensions(), batch, "info", "eig")); + + switch (dataType) { + case ffi::F32: + if (magma) { + return EigReal(EigRealMagma(), batch, cols, stream, + left, right, x, wr, wi, vl, vr, info); + } else { + return EigReal(EigRealHost(), batch, cols, stream, + left, right, x, wr, wi, vl, vr, info); + } + case ffi::F64: + if (magma) { + return EigReal(EigRealMagma(), batch, cols, stream, + left, right, x, wr, wi, vl, vr, info); + } else { + return EigReal(EigRealHost(), batch, cols, stream, + left, right, x, wr, wi, vl, vr, info); + } + default: + return ffi::Error::InvalidArgument(absl::StrFormat( + "Unsupported dtype %s in eig_real", absl::FormatStreamed(dataType))); + } +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(kEigReal, EigRealDispatch, + ffi::Ffi::Bind() + .Ctx>() + .Attr("magma") + .Attr("left") + .Attr("right") + .Arg() // x + .Ret() // wr + .Ret() // wi + .Ret() // vl + .Ret() // vr + .Ret>() // info +); + +// Complex-valued eigendecomposition + +template +class EigCompHost { + using Real = ffi::NativeType; + using Complex = ffi::NativeType; + + public: + explicit EigCompHost() = default; + EigCompHost(EigCompHost&&) = default; + + absl::StatusOr lwork(int n, bool left, bool right) { + n_ = n; + jobvl_ = left ? 'V' : 'N'; + jobvr_ = right ? 'V' : 'N'; + int64_t lwork = EigenvalueDecompositionComplex::GetWorkspaceSize( + n, static_cast(jobvl_), + static_cast(jobvr_)); + return MaybeCastNoOverflow(lwork); + } + + void compute(Complex* x, Complex* w, Complex* vl, Complex* vr, Complex* work, + int lwork, Real* rwork, int* info) { + EigenvalueDecompositionComplex::fn(&jobvl_, &jobvr_, &n_, x, &n_, + w, vl, &n_, vr, &n_, work, + &lwork, rwork, info); + } + + private: + int n_; + char jobvl_, jobvr_; +}; + +template +class EigCompMagma { + using Real = ffi::NativeType; + using Complex = ffi::NativeType; + using Fn = int(magma_vec_t, magma_vec_t, int, Complex*, int, Complex*, + Complex*, int, Complex*, int, Complex*, int, Real*, int*); + + public: + explicit EigCompMagma() = default; + EigCompMagma(EigCompMagma&&) = default; + + absl::StatusOr lwork(int n, bool left, bool right) { + n_ = n; + jobvl_ = left ? MagmaVec : MagmaNoVec; + jobvr_ = right ? MagmaVec : MagmaNoVec; + lda_ = std::max(n_, 1); + ldvl_ = left ? n_ : 1; + ldvr_ = right ? n_ : 1; + + auto maybe_ptr = FindMagmaSymbol(MagmaGeev::name); + if (!maybe_ptr.ok()) return maybe_ptr.status(); + fn_ = reinterpret_cast(*maybe_ptr); + + int query_info; + Complex query_host; + fn_(jobvl_, jobvr_, n_, nullptr, lda_, nullptr, nullptr, ldvl_, nullptr, + ldvr_, &query_host, -1, nullptr, &query_info); + return static_cast(query_host.real()); + } + + void compute(Complex* x, Complex* w, Complex* vl, Complex* vr, Complex* work, + int lwork, Real* rwork, int* info) { + fn_(jobvl_, jobvr_, n_, x, lda_, w, vl, ldvl_, vr, ldvr_, work, lwork, + rwork, info); + } + + private: + int n_, lda_, ldvl_, ldvr_; + magma_vec_t jobvl_, jobvr_; + Fn* fn_ = nullptr; +}; + +template +ffi::Error EigComp(Impl impl, int64_t batch, int64_t cols, gpuStream_t stream, + bool left, bool right, ffi::AnyBuffer x, + ffi::Result w, + ffi::Result vl, + ffi::Result vr, + ffi::Result> info) { + using Complex = ffi::NativeType; + + auto x_host = HostBuffer(x.element_count()); + FFI_RETURN_IF_ERROR_STATUS( + x_host.CopyFromDevice(stream, x.typed_data())); + + auto w_host = HostBuffer(batch * cols); + auto vl_host = HostBuffer(batch * cols * cols); + auto vr_host = HostBuffer(batch * cols * cols); + auto info_host = HostBuffer(batch); + + FFI_ASSIGN_OR_RETURN(int n, MaybeCastNoOverflow(cols)); + FFI_ASSIGN_OR_RETURN(int lwork, impl.lwork(n, left, right)); + auto work_host = AllocateScratchMemory(lwork); + auto rwork_host = + AllocateScratchMemory(2 * cols * cols); + + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuStreamSynchronize(stream))); + + const auto is_finite = [](auto* data, int64_t size) { + return absl::c_all_of(absl::MakeSpan(data, size), [](const auto& z) { + return std::isfinite(z.real()) && std::isfinite(z.imag()); + }); + }; + + for (int64_t i = 0; i < batch; ++i) { + if (is_finite(x_host.get() + i * cols * cols, cols * cols)) { + impl.compute(x_host.get() + i * cols * cols, w_host.get() + i * cols, + vl_host.get() + i * cols * cols, + vr_host.get() + i * cols * cols, work_host.get(), lwork, + rwork_host.get(), info_host.get() + i); + } else { + info_host.get()[i] = -4; + } + } + + FFI_RETURN_IF_ERROR_STATUS( + w_host.CopyToDevice(stream, w->typed_data())); + if (left) { + FFI_RETURN_IF_ERROR_STATUS( + vl_host.CopyToDevice(stream, vl->typed_data())); + } + if (right) { + FFI_RETURN_IF_ERROR_STATUS( + vr_host.CopyToDevice(stream, vr->typed_data())); + } + FFI_RETURN_IF_ERROR_STATUS( + info_host.CopyToDevice(stream, info->typed_data())); + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuStreamSynchronize(stream))); + + return ffi::Error::Success(); +} + +ffi::Error EigCompDispatch(gpuStream_t stream, bool magma, bool left, + bool right, ffi::AnyBuffer x, + ffi::Result w, + ffi::Result vl, + ffi::Result vr, + ffi::Result> info) { + auto dataType = x.element_type(); + if (dataType != w->element_type() || dataType != vl->element_type() || + dataType != vr->element_type()) { + return ffi::Error::InvalidArgument( + "The inputs and outputs to eig must have the same element type"); + } + + FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]), + SplitBatch2D(x.dimensions())); + if (rows != cols) { + return ffi::Error::InvalidArgument( + "The input matrix to eig must be square"); + } + FFI_RETURN_IF_ERROR(CheckShape(w->dimensions(), {batch, cols}, "w", "eig")); + if (left) { + FFI_RETURN_IF_ERROR( + CheckShape(vl->dimensions(), {batch, rows, cols}, "vl", "eig")); + } + if (right) { + FFI_RETURN_IF_ERROR( + CheckShape(vr->dimensions(), {batch, rows, cols}, "vr", "eig")); + } + FFI_RETURN_IF_ERROR(CheckShape(info->dimensions(), batch, "info", "eig")); + + switch (dataType) { + case ffi::C64: + if (magma) { + return EigComp(EigCompMagma(), batch, cols, stream, + left, right, x, w, vl, vr, info); + } else { + return EigComp(EigCompHost(), batch, cols, stream, + left, right, x, w, vl, vr, info); + } + case ffi::C128: + if (magma) { + return ffi::Error::InvalidArgument( + "Using MAGMA as the backend for eig_comp is not supported for " + "complex128 data types"); + } else { + return EigComp(EigCompHost(), batch, cols, stream, + left, right, x, w, vl, vr, info); + } + default: + return ffi::Error::InvalidArgument(absl::StrFormat( + "Unsupported dtype %s in eig_comp", absl::FormatStreamed(dataType))); + } +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(kEigComp, EigCompDispatch, + ffi::Ffi::Bind() + .Ctx>() + .Attr("magma") + .Attr("left") + .Attr("right") + .Arg() // x + .Ret() // w + .Ret() // vl + .Ret() // vr + .Ret>() // info +); + +} // namespace JAX_GPU_NAMESPACE +} // namespace jax diff --git a/jaxlib/gpu/hybrid_kernels.h b/jaxlib/gpu/hybrid_kernels.h new file mode 100644 index 000000000000..a70e72c745aa --- /dev/null +++ b/jaxlib/gpu/hybrid_kernels.h @@ -0,0 +1,31 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef JAXLIB_GPU_HYBRID_KERNELS_H_ +#define JAXLIB_GPU_HYBRID_KERNELS_H_ + +#include "jaxlib/gpu/vendor.h" +#include "xla/ffi/api/ffi.h" + +namespace jax { +namespace JAX_GPU_NAMESPACE { + +XLA_FFI_DECLARE_HANDLER_SYMBOL(kEigReal); +XLA_FFI_DECLARE_HANDLER_SYMBOL(kEigComp); + +} // namespace JAX_GPU_NAMESPACE +} // namespace jax + +#endif // JAXLIB_GPU_HYBRID_KERNELS_H_ diff --git a/jaxlib/gpu_solver.py b/jaxlib/gpu_solver.py index 03fd43e9ef89..434e724de902 100644 --- a/jaxlib/gpu_solver.py +++ b/jaxlib/gpu_solver.py @@ -56,6 +56,21 @@ xla_client.register_custom_call_target(_name, _value, platform="CUDA", api_version=api_version) +for cuda_module_name in [".cuda", "jax_cuda12_plugin"]: + try: + _cuhybrid = importlib.import_module( + f"{cuda_module_name}._hybrid", package="jaxlib" + ) + except ImportError: + _cuhybrid = None + else: + break + +if _cuhybrid: + for _name, _value in _cuhybrid.registrations().items(): + xla_client.register_custom_call_target(_name, _value, platform="CUDA", + api_version=1) + try: from .rocm import _blas as _hipblas # pytype: disable=import-error except ImportError: @@ -88,6 +103,27 @@ xla_client.register_custom_call_target(_name, _value, platform="ROCM", api_version=api_version) +for rocm_module_name in [".rocm", "jax_rocm60_plugin"]: + try: + _hiphybrid = importlib.import_module( + f"{rocm_module_name}._hybrid", package="jaxlib" + ) + except ImportError: + _hiphybrid = None + else: + break + +if _hiphybrid: + for _name, _value in _hiphybrid.registrations().items(): + xla_client.register_custom_call_target(_name, _value, platform="ROCM", + api_version=1) + +def initialize_hybrid_kernels(): + if _cuhybrid: + _cuhybrid.initialize() + if _hiphybrid: + _hiphybrid.initialize() + def _real_type(dtype): """Returns the real equivalent of 'dtype'.""" return np.finfo(dtype).dtype diff --git a/jaxlib/rocm/BUILD b/jaxlib/rocm/BUILD index c9b73a5785f1..1076f9a77bf8 100644 --- a/jaxlib/rocm/BUILD +++ b/jaxlib/rocm/BUILD @@ -389,6 +389,48 @@ pybind_extension( ], ) +cc_library( + name = "hip_hybrid_kernels", + srcs = ["//jaxlib/gpu:hybrid_kernels.cc"], + hdrs = ["//jaxlib/gpu:hybrid_kernels.h"], + deps = [ + ":hip_gpu_kernel_helpers", + ":hip_vendor", + "//jaxlib:ffi_helpers", + "//jaxlib/cpu:lapack_kernels", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@xla//xla/ffi/api:ffi", + ], +) + +pybind_extension( + name = "_hybrid", + srcs = ["//jaxlib/gpu:hybrid.cc"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + module_name = "_hybrid", + deps = [ + ":hip_gpu_kernel_helpers", + ":hip_hybrid_kernels", + ":hip_vendor", + "//jaxlib:kernel_nanobind_helpers", + "//jaxlib/cpu:lapack_kernels", + "@local_config_rocm//rocm:rocm_headers", + "@nanobind", + "@xla//xla/ffi/api:ffi", + ], +) + cc_library( name = "triton_kernels", srcs = ["//jaxlib/gpu:triton_kernels.cc"], @@ -456,6 +498,7 @@ py_library( name = "rocm_gpu_support", deps = [ ":_blas", + ":_hybrid", ":_linalg", ":_prng", ":_solver", diff --git a/jaxlib/tools/build_gpu_kernels_wheel.py b/jaxlib/tools/build_gpu_kernels_wheel.py index 5b3ac636303a..9a47c6ad5409 100644 --- a/jaxlib/tools/build_gpu_kernels_wheel.py +++ b/jaxlib/tools/build_gpu_kernels_wheel.py @@ -108,6 +108,7 @@ def prepare_wheel_cuda( f"__main__/jaxlib/cuda/_rnn.{pyext}", f"__main__/jaxlib/cuda/_sparse.{pyext}", f"__main__/jaxlib/cuda/_triton.{pyext}", + f"__main__/jaxlib/cuda/_hybrid.{pyext}", f"__main__/jaxlib/cuda/_versions.{pyext}", f"__main__/jaxlib/cuda_plugin_extension.{pyext}", f"__main__/jaxlib/mosaic/gpu/_mosaic_gpu_ext.{pyext}", @@ -144,6 +145,7 @@ def prepare_wheel_rocm( f"__main__/jaxlib/rocm/_linalg.{pyext}", f"__main__/jaxlib/rocm/_prng.{pyext}", f"__main__/jaxlib/rocm/_sparse.{pyext}", + f"__main__/jaxlib/cuda/_hybrid.{pyext}", f"__main__/jaxlib/rocm/_triton.{pyext}", f"__main__/jaxlib/rocm_plugin_extension.{pyext}", "__main__/jaxlib/version.py", diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py index 438cebca2b06..4db36fa0ea97 100644 --- a/jaxlib/tools/build_wheel.py +++ b/jaxlib/tools/build_wheel.py @@ -231,6 +231,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, skip_gpu_kernels): f"__main__/jaxlib/cuda/_rnn.{pyext}", f"__main__/jaxlib/cuda/_sparse.{pyext}", f"__main__/jaxlib/cuda/_triton.{pyext}", + f"__main__/jaxlib/cuda/_hybrid.{pyext}", f"__main__/jaxlib/cuda/_versions.{pyext}", ], ) @@ -244,6 +245,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, skip_gpu_kernels): f"__main__/jaxlib/rocm/_prng.{pyext}", f"__main__/jaxlib/rocm/_sparse.{pyext}", f"__main__/jaxlib/rocm/_triton.{pyext}", + f"__main__/jaxlib/rocm/_hybrid.{pyext}", ], ) diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 5ace4b5ecf18..cac018ef8978 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -34,6 +34,7 @@ from jax._src.lax import linalg as lax_linalg from jax._src import test_util as jtu from jax._src import xla_bridge +from jax._src.lib import version as jaxlib_version from jax._src.numpy.util import promote_dtypes_inexact config.parse_flags_with_absl() @@ -234,11 +235,13 @@ def testIssue1213(self): compute_left_eigenvectors=[False, True], compute_right_eigenvectors=[False, True], ) - # TODO(phawkins): enable when there is an eigendecomposition implementation - # for GPU/TPU. - @jtu.run_on_devices("cpu") + # TODO(danfm): enable when there is an eigendecomposition implementation + # for TPU. + @jtu.run_on_devices("cpu", "gpu") def testEig(self, shape, dtype, compute_left_eigenvectors, compute_right_eigenvectors): + if jtu.test_device_matches(["gpu"]) and jaxlib_version <= (0, 4, 35): + self.skipTest("eig on GPU requires jaxlib version > 0.4.35") rng = jtu.rand_default(self.rng()) n = shape[-1] args_maker = lambda: [rng(shape, dtype)] @@ -277,12 +280,14 @@ def check_left_eigenvectors(a, w, vl): compute_left_eigenvectors=[False, True], compute_right_eigenvectors=[False, True], ) - # TODO(phawkins): enable when there is an eigendecomposition implementation - # for GPU/TPU. - @jtu.run_on_devices("cpu") + # TODO(danfm): enable when there is an eigendecomposition implementation + # for TPU. + @jtu.run_on_devices("cpu", "gpu") def testEigHandlesNanInputs(self, shape, dtype, compute_left_eigenvectors, compute_right_eigenvectors): """Verifies that `eig` fails gracefully if given non-finite inputs.""" + if jtu.test_device_matches(["gpu"]) and jaxlib_version <= (0, 4, 35): + self.skipTest("eig on GPU requires jaxlib version > 0.4.35") a = jnp.full(shape, jnp.nan, dtype) results = lax.linalg.eig( a, compute_left_eigenvectors=compute_left_eigenvectors, @@ -294,14 +299,16 @@ def testEigHandlesNanInputs(self, shape, dtype, compute_left_eigenvectors, shape=[(4, 4), (5, 5), (8, 8), (7, 6, 6)], dtype=float_types + complex_types, ) - # TODO(phawkins): enable when there is an eigendecomposition implementation - # for GPU/TPU. - @jtu.run_on_devices("cpu") + # TODO(danfm): enable when there is an eigendecomposition implementation + # for TPU. + @jtu.run_on_devices("cpu", "gpu") def testEigvalsGrad(self, shape, dtype): # This test sometimes fails for large matrices. I (@j-towns) suspect, but # haven't checked, that might be because of perturbations causing the # ordering of eigenvalues to change, which will trip up check_grads. So we # just test on small-ish matrices. + if jtu.test_device_matches(["gpu"]) and jaxlib_version <= (0, 4, 35): + self.skipTest("eig on GPU requires jaxlib version > 0.4.35") rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype)] a, = args_maker() @@ -313,10 +320,12 @@ def testEigvalsGrad(self, shape, dtype): shape=[(4, 4), (5, 5), (50, 50)], dtype=float_types + complex_types, ) - # TODO: enable when there is an eigendecomposition implementation - # for GPU/TPU. - @jtu.run_on_devices("cpu") + # TODO(danfm): enable when there is an eigendecomposition implementation + # for TPU. + @jtu.run_on_devices("cpu", "gpu") def testEigvals(self, shape, dtype): + if jtu.test_device_matches(["gpu"]) and jaxlib_version <= (0, 4, 35): + self.skipTest("eig on GPU requires jaxlib version > 0.4.35") rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(shape, dtype)] a, = args_maker() @@ -324,9 +333,11 @@ def testEigvals(self, shape, dtype): w2 = jnp.linalg.eigvals(a) self.assertAllClose(w1, w2, rtol={np.complex64: 1e-5, np.complex128: 2e-14}) - @jtu.run_on_devices("cpu") + @jtu.run_on_devices("cpu", "gpu") def testEigvalsInf(self): # https://github.com/jax-ml/jax/issues/2661 + if jtu.test_device_matches(["gpu"]) and jaxlib_version <= (0, 4, 35): + self.skipTest("eig on GPU requires jaxlib version > 0.4.35") x = jnp.array([[jnp.inf]]) self.assertTrue(jnp.all(jnp.isnan(jnp.linalg.eigvals(x)))) @@ -334,8 +345,10 @@ def testEigvalsInf(self): shape=[(1, 1), (4, 4), (5, 5)], dtype=float_types + complex_types, ) - @jtu.run_on_devices("cpu") + @jtu.run_on_devices("cpu", "gpu") def testEigBatching(self, shape, dtype): + if jtu.test_device_matches(["gpu"]) and jaxlib_version <= (0, 4, 35): + self.skipTest("eig on GPU requires jaxlib version > 0.4.35") rng = jtu.rand_default(self.rng()) shape = (10,) + shape args = rng(shape, dtype)