Skip to content

Commit

Permalink
Add a GPU implementation of lax.linalg.eig.
Browse files Browse the repository at this point in the history
This feature has been in the queue for a long time (see #1259), and some folks have found that they can use `pure_callback` to call the CPU version as a workaround. It has recently come up that there can be issues when using `pure_callback` with JAX calls in the body (#24255; this should be investigated separately).

This change adds a native solution for computing `lax.linalg.eig` on GPU. By default, this is implemented by calling LAPACK on host directly because this has good performance for small to moderately sized problems (less than about 2048^2). For larger matrices, a GPU-backed implementation based on [MAGMA](https://icl.utk.edu/magma/) can have significantly better performance. (I should note that I haven't done a huge amount of benchmarking yet, but this was the breakeven point used by PyTorch, and I find roughly similar behavior so far.)

We don't want to add MAGMA as a required dependency, but if a user has installed it, JAX can use it when the `jax_gpu_eig_magma` configuration variable is set to `"on"`. By default, we try to dlopen `libmagma.so`, but the path to a non-standard installation location can be specified using the `JAX_GPU_MAGMA_PATH` environment variable.

For reasons that I don't yet totally understand, the MAGMA implementation segfaults deep in the MAGMA internals for complex128 inputs, so I've disabled that configuration for now.

PiperOrigin-RevId: 691072237
  • Loading branch information
dfm authored and Google-ML-Automation committed Nov 6, 2024
1 parent dc33a28 commit d24721b
Show file tree
Hide file tree
Showing 15 changed files with 1,023 additions and 44 deletions.
17 changes: 17 additions & 0 deletions jax/_src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1953,3 +1953,20 @@ def _update_garbage_collection_guard(state, key, val):
),
include_in_jit_key=True,
)

gpu_use_magma = enum_state(
name='jax_gpu_use_magma',
enum_values=[
'off', 'on',
# TODO(danfm): After doing more complete benchmarks, add an 'auto'
# option which will automatically enable MAGMA for problem sizes where
# it typically gets better performance.
# 'auto',
],
default='off',
help=(
'Enable experimental support for MAGMA-backed lax.linalg.eig on GPU. '
'See the documentation for lax.linalg.eig for more details about how '
'to use this feature.'
),
)
96 changes: 95 additions & 1 deletion jax/_src/lax/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,25 @@ def eig(x: ArrayLike, *, compute_left_eigenvectors: bool = True,
compute_right_eigenvectors: bool = True) -> list[Array]:
"""Eigendecomposition of a general matrix.
Nonsymmetric eigendecomposition is at present only implemented on CPU.
Nonsymmetric eigendecomposition is only implemented on CPU and GPU. On GPU,
the default implementation calls LAPACK directly on the host CPU, but an
experimental GPU implementation using `MAGMA <https://icl.utk.edu/magma/>`_
is also available. The MAGMA implementation is typically slower than the
equivalent LAPACK implementation for small matrices (less than about 2048),
but it may perform better for larger matrices.
To enable the MAGMA implementation, you must install MAGMA yourself (there
are Debian and conda-forge packages, or you can build from source). Then set
the `jax_gpu_use_magma` configuration variable to `"on"`:
.. code-block:: python
jax.config.update('jax_gpu_use_magma', 'on')
JAX will try to ``dlopen`` the installed MAGMA shared library, raising an
error if it is not found. To explicitly specify the path to the MAGMA
library, set the environment variable `JAX_GPU_MAGMA_PATH` to the full
installation path.
Args:
x: A batch of square matrices with shape ``[..., n, n]``.
Expand Down Expand Up @@ -763,6 +781,78 @@ def _eig_cpu_lowering(ctx, operand, *, compute_left_eigenvectors,
return output


def _eig_gpu_impl(target_name_prefix, x, *, compute_left_eigenvectors,
compute_right_eigenvectors):
gpu_solver.initialize_hybrid_kernels()
dtype = x.dtype
is_real = dtype == np.float32 or dtype == np.float64
if is_real:
target_name = f"{target_name_prefix}hybrid_eig_real"
complex_dtype = np.complex64 if dtype == np.float32 else np.complex128
else:
target_name = f"{target_name_prefix}hybrid_eig_comp"
assert dtype == np.complex64 or dtype == np.complex128
complex_dtype = dtype

batch_dims = x.shape[:-2]
n, m = x.shape[-2:]
assert n == m
num_batch_dims = len(batch_dims)

layout = tuple(range(num_batch_dims)) + (num_batch_dims + 1, num_batch_dims)
out_types = [
api.ShapeDtypeStruct(batch_dims + (n,), dtype),
api.ShapeDtypeStruct(batch_dims + (n, n), complex_dtype),
api.ShapeDtypeStruct(batch_dims + (n, n), complex_dtype),
api.ShapeDtypeStruct(batch_dims, np.int32),
]
out_layouts = [None, layout, layout, None]
if is_real:
out_types = [api.ShapeDtypeStruct(batch_dims + (n,), dtype)] + out_types
out_layouts = [None] + out_layouts

magma = False
if config.gpu_use_magma.value == 'on':
magma = dtype != np.complex128

fun = ffi.ffi_call(target_name, out_types, input_layouts=[layout],
output_layouts=out_layouts)
*w, vl, vr, info = fun(x, magma=magma, left=compute_left_eigenvectors,
right=compute_right_eigenvectors)
if is_real:
assert len(w) == 2
w = lax.complex(*w)
else:
assert len(w) == 1
w = w[0]
ok = lax.eq(info, lax.zeros_like_array(info))
ok = _broadcast_to(ok[..., None], w.shape)
w = lax.select(ok, w, lax.full_like(w, np.nan + np.nan * 1j))
ok = _broadcast_to(ok[..., None], x.shape)
output = [w]
if compute_left_eigenvectors:
vl = lax.select(ok, vl, lax.full_like(vl, np.nan + np.nan * 1j))
output.append(vl)
if compute_right_eigenvectors:
vr = lax.select(ok, vr, lax.full_like(vr, np.nan + np.nan * 1j))
output.append(vr)
return output


def _eig_gpu_lowering(target_name_prefix, ctx, operand, *,
compute_left_eigenvectors, compute_right_eigenvectors):
if ctx.is_forward_compat():
raise NotImplementedError(
"Export of nonsymmetric eigendecomposition on GPU is not supported "
"because of forward compatibility. The "
"'jax_export_ignore_forward_compatibility' configuration option can be "
"used to disable this check.")
rule = mlir.lower_fun(partial(_eig_gpu_impl, target_name_prefix),
multiple_results=True)
return rule(ctx, operand, compute_left_eigenvectors=compute_left_eigenvectors,
compute_right_eigenvectors=compute_right_eigenvectors)


def eig_batching_rule(batched_args, batch_dims, *, compute_left_eigenvectors,
compute_right_eigenvectors):
x, = batched_args
Expand Down Expand Up @@ -793,6 +883,10 @@ def eig_jvp_rule(primals, tangents, *, compute_left_eigenvectors,
eig_p.def_abstract_eval(eig_abstract_eval)
mlir.register_lowering(eig_p, eig_lower)
mlir.register_lowering(eig_p, _eig_cpu_lowering, platform='cpu')
mlir.register_lowering(eig_p, partial(_eig_gpu_lowering, 'cu'),
platform='cuda')
mlir.register_lowering(eig_p, partial(_eig_gpu_lowering, 'hip'),
platform='rocm')
batching.primitive_batchers[eig_p] = eig_batching_rule
ad.primitive_jvps[eig_p] = eig_jvp_rule

Expand Down
4 changes: 3 additions & 1 deletion jax/_src/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,7 +731,9 @@ def eig(a: ArrayLike) -> tuple[Array, Array]:
- This differs from :func:`numpy.linalg.eig` in that the return type of
:func:`jax.numpy.linalg.eig` is always complex64 for 32-bit input, and complex128
for 64-bit input.
- At present, non-symmetric eigendecomposition is only implemented on the CPU backend.
- At present, non-symmetric eigendecomposition is only implemented on the CPU and
GPU backends. For more details about the GPU implementation, see the
documentation for :func:`jax.lax.linalg.eig`.
See also:
- :func:`jax.numpy.linalg.eigh`: eigenvectors and eigenvalues of a Hermitian matrix.
Expand Down
28 changes: 0 additions & 28 deletions jaxlib/cpu/lapack_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1094,34 +1094,6 @@ template struct EigenvalueDecompositionSymmetric<ffi::DataType::F64>;
template struct EigenvalueDecompositionHermitian<ffi::DataType::C64>;
template struct EigenvalueDecompositionHermitian<ffi::DataType::C128>;

// LAPACK uses a packed representation to represent a mixture of real
// eigenvectors and complex conjugate pairs. This helper unpacks the
// representation into regular complex matrices.
template <typename T>
static void UnpackEigenvectors(lapack_int n, const T* eigenvals_imag,
const T* packed, std::complex<T>* unpacked) {
for (int j = 0; j < n;) {
if (eigenvals_imag[j] == 0. || std::isnan(eigenvals_imag[j])) {
// Real values in each row without imaginary part
// Second row of the imaginary part is not provided
for (int i = 0; i < n; ++i) {
unpacked[j * n + i] = {packed[j * n + i], 0.};
}
++j;
} else {
// Complex values where the real part is in the jth row
// and the imaginary part is in the next row (j + 1)
for (int i = 0; i < n; ++i) {
const T real_part = packed[j * n + i];
const T imag_part = packed[(j + 1) * n + i];
unpacked[j * n + i] = {real_part, imag_part};
unpacked[(j + 1) * n + i] = {real_part, -imag_part};
}
j += 2;
}
}
}

// lapack geev

template <typename T>
Expand Down
29 changes: 29 additions & 0 deletions jaxlib/cpu/lapack_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License.
#ifndef JAXLIB_CPU_LAPACK_KERNELS_H_
#define JAXLIB_CPU_LAPACK_KERNELS_H_

#include <complex>
#include <cstdint>
#include <optional>
#include <type_traits>
Expand Down Expand Up @@ -462,6 +463,34 @@ struct EigenvalueDecompositionHermitian {

// lapack geev

// LAPACK uses a packed representation to represent a mixture of real
// eigenvectors and complex conjugate pairs. This helper unpacks the
// representation into regular complex matrices.
template <typename T, typename Int=lapack_int>
static void UnpackEigenvectors(Int n, const T* eigenvals_imag,
const T* packed, std::complex<T>* unpacked) {
for (int j = 0; j < n;) {
if (eigenvals_imag[j] == 0. || std::isnan(eigenvals_imag[j])) {
// Real values in each row without imaginary part
// Second row of the imaginary part is not provided
for (int i = 0; i < n; ++i) {
unpacked[j * n + i] = {packed[j * n + i], 0.};
}
++j;
} else {
// Complex values where the real part is in the jth row
// and the imaginary part is in the next row (j + 1)
for (int i = 0; i < n; ++i) {
const T real_part = packed[j * n + i];
const T imag_part = packed[(j + 1) * n + i];
unpacked[j * n + i] = {real_part, imag_part};
unpacked[(j + 1) * n + i] = {real_part, -imag_part};
}
j += 2;
}
}
}

template <typename T>
struct RealGeev {
using FnType = void(char* jobvl, char* jobvr, lapack_int* n, T* a,
Expand Down
50 changes: 50 additions & 0 deletions jaxlib/cuda/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,55 @@ pybind_extension(
],
)

cc_library(
name = "cuda_hybrid_kernels",
srcs = ["//jaxlib/gpu:hybrid_kernels.cc"],
hdrs = ["//jaxlib/gpu:hybrid_kernels.h"],
deps = [
":cuda_gpu_kernel_helpers",
":cuda_vendor",
"//jaxlib:ffi_helpers",
"//jaxlib/cpu:lapack_kernels",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:span",
"@xla//xla/ffi/api:ffi",
],
)

pybind_extension(
name = "_hybrid",
srcs = ["//jaxlib/gpu:hybrid.cc"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
linkopts = select({
"@xla//xla/python:use_jax_cuda_pip_rpaths": [
"-Wl,-rpath,$$ORIGIN/../../nvidia/cuda_runtime/lib",
],
"//conditions:default": [],
}),
module_name = "_hybrid",
deps = [
":cuda_gpu_kernel_helpers",
":cuda_hybrid_kernels",
":cuda_vendor",
"//jaxlib:kernel_nanobind_helpers",
"//jaxlib/cpu:lapack_kernels",
"@local_config_cuda//cuda:cuda_headers",
"@nanobind",
"@xla//xla/ffi/api:ffi",
"@xla//xla/tsl/cuda:cudart",
],
)

cc_library(
name = "cuda_gpu_kernels",
srcs = ["//jaxlib/gpu:gpu_kernels.cc"],
Expand Down Expand Up @@ -633,6 +682,7 @@ py_library(
name = "cuda_gpu_support",
deps = [
":_blas",
":_hybrid",
":_linalg",
":_prng",
":_rnn",
Expand Down
3 changes: 3 additions & 0 deletions jaxlib/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ exports_files(srcs = [
"gpu_kernel_helpers.cc",
"gpu_kernel_helpers.h",
"gpu_kernels.cc",
"hybrid.cc",
"hybrid_kernels.cc",
"hybrid_kernels.h",
"linalg.cc",
"linalg_kernels.cc",
"linalg_kernels.cu.cc",
Expand Down
59 changes: 59 additions & 0 deletions jaxlib/gpu/hybrid.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/* Copyright 2021 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "nanobind/nanobind.h"
#include "jaxlib/cpu/lapack_kernels.h"
#include "jaxlib/gpu/hybrid_kernels.h"
#include "jaxlib/gpu/vendor.h"
#include "jaxlib/kernel_nanobind_helpers.h"
#include "xla/ffi/api/ffi.h"

namespace jax {
namespace JAX_GPU_NAMESPACE {
namespace {
namespace ffi = xla::ffi;
namespace nb = nanobind;

void GetLapackKernelsFromScipy() {
static bool initialized = false; // Protected by GIL
if (initialized) return;
nb::module_ cython_blas = nb::module_::import_("scipy.linalg.cython_blas");
nb::module_ cython_lapack =
nb::module_::import_("scipy.linalg.cython_lapack");
nb::dict lapack_capi = cython_lapack.attr("__pyx_capi__");
auto lapack_ptr = [&](const char* name) {
return nb::cast<nb::capsule>(lapack_capi[name]).data();
};

AssignKernelFn<EigenvalueDecomposition<ffi::F32>>(lapack_ptr("sgeev"));
AssignKernelFn<EigenvalueDecomposition<ffi::F64>>(lapack_ptr("dgeev"));
AssignKernelFn<EigenvalueDecompositionComplex<ffi::C64>>(lapack_ptr("cgeev"));
AssignKernelFn<EigenvalueDecompositionComplex<ffi::C128>>(
lapack_ptr("zgeev"));
}

NB_MODULE(_hybrid, m) {
m.def("initialize", GetLapackKernelsFromScipy);
m.def("registrations", []() {
nb::dict dict;
dict[JAX_GPU_PREFIX "hybrid_eig_real"] = EncapsulateFfiHandler(kEigReal);
dict[JAX_GPU_PREFIX "hybrid_eig_comp"] = EncapsulateFfiHandler(kEigComp);
return dict;
});
}

} // namespace
} // namespace JAX_GPU_NAMESPACE
} // namespace jax
Loading

0 comments on commit d24721b

Please sign in to comment.