-
Notifications
You must be signed in to change notification settings - Fork 2.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add a GPU implementation of
lax.linalg.eig
.
This feature has been in the queue for a long time (see #1259), and some folks have found that they can use `pure_callback` to call the CPU version as a workaround. It has recently come up that there can be issues when using `pure_callback` with JAX calls in the body (#24255; this should be investigated separately). This change adds a native solution for computing `lax.linalg.eig` on GPU. By default, this is implemented by calling LAPACK on host directly because this has good performance for small to moderately sized problems (less than about 2048^2). For larger matrices, a GPU-backed implementation based on [MAGMA](https://icl.utk.edu/magma/) can have significantly better performance. (I should note that I haven't done a huge amount of benchmarking yet, but this was the breakeven point used by PyTorch, and I find roughly similar behavior so far.) We don't want to add MAGMA as a required dependency, but if a user has installed it, JAX can use it when the `jax_gpu_eig_magma` configuration variable is set to `"on"`. By default, we try to dlopen `libmagma.so`, but the path to a non-standard installation location can be specified using the `JAX_GPU_MAGMA_PATH` environment variable. For reasons that I don't yet totally understand, the MAGMA implementation segfaults deep in the MAGMA internals for complex128 inputs, so I've disabled that configuration for now. PiperOrigin-RevId: 691072237
- Loading branch information
1 parent
dc33a28
commit d24721b
Showing
15 changed files
with
1,023 additions
and
44 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
/* Copyright 2021 The JAX Authors. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
==============================================================================*/ | ||
|
||
#include "nanobind/nanobind.h" | ||
#include "jaxlib/cpu/lapack_kernels.h" | ||
#include "jaxlib/gpu/hybrid_kernels.h" | ||
#include "jaxlib/gpu/vendor.h" | ||
#include "jaxlib/kernel_nanobind_helpers.h" | ||
#include "xla/ffi/api/ffi.h" | ||
|
||
namespace jax { | ||
namespace JAX_GPU_NAMESPACE { | ||
namespace { | ||
namespace ffi = xla::ffi; | ||
namespace nb = nanobind; | ||
|
||
void GetLapackKernelsFromScipy() { | ||
static bool initialized = false; // Protected by GIL | ||
if (initialized) return; | ||
nb::module_ cython_blas = nb::module_::import_("scipy.linalg.cython_blas"); | ||
nb::module_ cython_lapack = | ||
nb::module_::import_("scipy.linalg.cython_lapack"); | ||
nb::dict lapack_capi = cython_lapack.attr("__pyx_capi__"); | ||
auto lapack_ptr = [&](const char* name) { | ||
return nb::cast<nb::capsule>(lapack_capi[name]).data(); | ||
}; | ||
|
||
AssignKernelFn<EigenvalueDecomposition<ffi::F32>>(lapack_ptr("sgeev")); | ||
AssignKernelFn<EigenvalueDecomposition<ffi::F64>>(lapack_ptr("dgeev")); | ||
AssignKernelFn<EigenvalueDecompositionComplex<ffi::C64>>(lapack_ptr("cgeev")); | ||
AssignKernelFn<EigenvalueDecompositionComplex<ffi::C128>>( | ||
lapack_ptr("zgeev")); | ||
} | ||
|
||
NB_MODULE(_hybrid, m) { | ||
m.def("initialize", GetLapackKernelsFromScipy); | ||
m.def("registrations", []() { | ||
nb::dict dict; | ||
dict[JAX_GPU_PREFIX "hybrid_eig_real"] = EncapsulateFfiHandler(kEigReal); | ||
dict[JAX_GPU_PREFIX "hybrid_eig_comp"] = EncapsulateFfiHandler(kEigComp); | ||
return dict; | ||
}); | ||
} | ||
|
||
} // namespace | ||
} // namespace JAX_GPU_NAMESPACE | ||
} // namespace jax |
Oops, something went wrong.