Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 183 additions & 0 deletions exla/c_src/exla/custom_calls/eig.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
#pragma once

#include <algorithm>
#include <complex>
#include <iostream>
#include <numeric>
#include <vector>

#include "Eigen/Eigenvalues"
#include "xla/ffi/api/ffi.h"
#include "xla/ffi/ffi_api.h"

namespace ffi = xla::ffi;

// For real input types, compute complex eigenvalues/eigenvectors
template <typename DataType, typename ComplexType>
void single_matrix_eig_cpu_custom_call_real(ComplexType *eigenvalues_out,
ComplexType *eigenvectors_out,
DataType *in, uint64_t m,
uint64_t n) {
typedef Eigen::Matrix<DataType, Eigen::Dynamic, Eigen::Dynamic,
Eigen::RowMajor>
RowMajorMatrix;
typedef Eigen::Matrix<ComplexType, Eigen::Dynamic, 1> ComplexVector;
typedef Eigen::Matrix<ComplexType, Eigen::Dynamic, Eigen::Dynamic,
Eigen::RowMajor>
ComplexRowMajorMatrix;

// Map the input matrix
Eigen::Map<RowMajorMatrix> input(in, m, n);

// Compute the Eigenvalue decomposition for general (non-symmetric) matrices
Eigen::EigenSolver<RowMajorMatrix> eigensolver(input);

if (eigensolver.info() != Eigen::Success) {
std::cerr << "Eigenvalue decomposition failed!" << std::endl;
return;
}

// Get the eigenvalues and eigenvectors (both are complex)
ComplexVector eigenvalues = eigensolver.eigenvalues();
ComplexRowMajorMatrix eigenvectors = eigensolver.eigenvectors();

// Create a vector of indices and sort it based on eigenvalues magnitude in
// decreasing order
std::vector<int> indices(m);
std::iota(indices.begin(), indices.end(), 0);
std::sort(indices.begin(), indices.end(), [&eigenvalues](int i, int j) {
return std::abs(eigenvalues(i)) > std::abs(eigenvalues(j));
});

// Sort eigenvalues and rearrange eigenvectors
ComplexVector sorted_eigenvalues(m);
ComplexRowMajorMatrix sorted_eigenvectors(m, n);
for (int i = 0; i < m; ++i) {
sorted_eigenvalues(i) = eigenvalues(indices[i]);
sorted_eigenvectors.col(i) = eigenvectors.col(indices[i]);
}

// Copy the sorted eigenvalues to the output
std::memcpy(eigenvalues_out, sorted_eigenvalues.data(),
m * sizeof(ComplexType));

// Copy the sorted eigenvectors to the output
std::memcpy(eigenvectors_out, sorted_eigenvectors.data(),
m * n * sizeof(ComplexType));
}

// For complex input types
template <typename ComplexType>
void single_matrix_eig_cpu_custom_call_complex(ComplexType *eigenvalues_out,
ComplexType *eigenvectors_out,
ComplexType *in, uint64_t m,
uint64_t n) {
typedef Eigen::Matrix<ComplexType, Eigen::Dynamic, Eigen::Dynamic,
Eigen::RowMajor>
ComplexRowMajorMatrix;
typedef Eigen::Matrix<ComplexType, Eigen::Dynamic, 1> ComplexVector;

// Map the input matrix
Eigen::Map<ComplexRowMajorMatrix> input(in, m, n);

// Compute the Eigenvalue decomposition for complex matrices
Eigen::ComplexEigenSolver<ComplexRowMajorMatrix> eigensolver(input);

if (eigensolver.info() != Eigen::Success) {
std::cerr << "Eigenvalue decomposition failed!" << std::endl;
return;
}

// Get the eigenvalues and eigenvectors
ComplexVector eigenvalues = eigensolver.eigenvalues();
ComplexRowMajorMatrix eigenvectors = eigensolver.eigenvectors();

// Create a vector of indices and sort it based on eigenvalues magnitude in
// decreasing order
std::vector<int> indices(m);
std::iota(indices.begin(), indices.end(), 0);
std::sort(indices.begin(), indices.end(), [&eigenvalues](int i, int j) {
return std::abs(eigenvalues(i)) > std::abs(eigenvalues(j));
});

// Sort eigenvalues and rearrange eigenvectors
ComplexVector sorted_eigenvalues(m);
ComplexRowMajorMatrix sorted_eigenvectors(m, n);
for (int i = 0; i < m; ++i) {
sorted_eigenvalues(i) = eigenvalues(indices[i]);
sorted_eigenvectors.col(i) = eigenvectors.col(indices[i]);
}

// Copy the sorted eigenvalues to the output
std::memcpy(eigenvalues_out, sorted_eigenvalues.data(),
m * sizeof(ComplexType));

// Copy the sorted eigenvectors to the output
std::memcpy(eigenvectors_out, sorted_eigenvectors.data(),
m * n * sizeof(ComplexType));
}

// For real types (f32, f64)
template <typename DataType, typename ComplexType, typename BufferType,
typename ComplexBufferType>
ffi::Error
eig_cpu_custom_call_impl_real(BufferType operand,
ffi::Result<ComplexBufferType> eigenvalues,
ffi::Result<ComplexBufferType> eigenvectors) {
auto operand_dims = operand.dimensions();
auto eigenvalues_dims = eigenvalues->dimensions();
auto eigenvectors_dims = eigenvectors->dimensions();

uint64_t m = eigenvectors_dims[eigenvectors_dims.size() - 2];
uint64_t n = eigenvectors_dims[eigenvectors_dims.size() - 1];

uint64_t batch_items = 1;
for (auto it = operand_dims.begin(); it != operand_dims.end() - 2; it++) {
batch_items *= *it;
}

uint64_t eigenvalues_stride = eigenvalues_dims[eigenvalues_dims.size() - 1];
uint64_t eigenvectors_stride = m * n;
uint64_t inner_stride = m * n;

for (uint64_t i = 0; i < batch_items; i++) {
single_matrix_eig_cpu_custom_call_real<DataType, ComplexType>(
eigenvalues->typed_data() + i * eigenvalues_stride,
eigenvectors->typed_data() + i * eigenvectors_stride,
operand.typed_data() + i * inner_stride, m, n);
}

return ffi::Error::Success();
}

// For complex types (c64, c128)
template <typename ComplexType, typename BufferType>
ffi::Error
eig_cpu_custom_call_impl_complex(BufferType operand,
ffi::Result<BufferType> eigenvalues,
ffi::Result<BufferType> eigenvectors) {
auto operand_dims = operand.dimensions();
auto eigenvalues_dims = eigenvalues->dimensions();
auto eigenvectors_dims = eigenvectors->dimensions();

uint64_t m = eigenvectors_dims[eigenvectors_dims.size() - 2];
uint64_t n = eigenvectors_dims[eigenvectors_dims.size() - 1];

uint64_t batch_items = 1;
for (auto it = operand_dims.begin(); it != operand_dims.end() - 2; it++) {
batch_items *= *it;
}

uint64_t eigenvalues_stride = eigenvalues_dims[eigenvalues_dims.size() - 1];
uint64_t eigenvectors_stride = m * n;
uint64_t inner_stride = m * n;

for (uint64_t i = 0; i < batch_items; i++) {
single_matrix_eig_cpu_custom_call_complex<ComplexType>(
eigenvalues->typed_data() + i * eigenvalues_stride,
eigenvectors->typed_data() + i * eigenvectors_stride,
operand.typed_data() + i * inner_stride, m, n);
}

return ffi::Error::Success();
}
20 changes: 20 additions & 0 deletions exla/c_src/exla/custom_calls/eig_c128.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#include "eig.h"

ffi::Error
eig_cpu_custom_call_c128_impl(ffi::Buffer<ffi::C128> operand,
ffi::ResultBuffer<ffi::C128> eigenvalues,
ffi::ResultBuffer<ffi::C128> eigenvectors) {
return eig_cpu_custom_call_impl_complex<std::complex<double>,
ffi::Buffer<ffi::C128>>(
operand, eigenvalues, eigenvectors);
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(eig_cpu_custom_call_c128,
eig_cpu_custom_call_c128_impl,
ffi::Ffi::Bind()
.Arg<ffi::Buffer<ffi::C128>>()
.Ret<ffi::Buffer<ffi::C128>>()
.Ret<ffi::Buffer<ffi::C128>>());

XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "eig_cpu_custom_call_c128",
"Host", eig_cpu_custom_call_c128);
20 changes: 20 additions & 0 deletions exla/c_src/exla/custom_calls/eig_c64.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#include "eig.h"

ffi::Error
eig_cpu_custom_call_c64_impl(ffi::Buffer<ffi::C64> operand,
ffi::ResultBuffer<ffi::C64> eigenvalues,
ffi::ResultBuffer<ffi::C64> eigenvectors) {
return eig_cpu_custom_call_impl_complex<std::complex<float>,
ffi::Buffer<ffi::C64>>(
operand, eigenvalues, eigenvectors);
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(eig_cpu_custom_call_c64,
eig_cpu_custom_call_c64_impl,
ffi::Ffi::Bind()
.Arg<ffi::Buffer<ffi::C64>>()
.Ret<ffi::Buffer<ffi::C64>>()
.Ret<ffi::Buffer<ffi::C64>>());

XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "eig_cpu_custom_call_c64", "Host",
eig_cpu_custom_call_c64);
20 changes: 20 additions & 0 deletions exla/c_src/exla/custom_calls/eig_f32.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#include "eig.h"

ffi::Error
eig_cpu_custom_call_f32_impl(ffi::Buffer<ffi::F32> operand,
ffi::ResultBuffer<ffi::C64> eigenvalues,
ffi::ResultBuffer<ffi::C64> eigenvectors) {
return eig_cpu_custom_call_impl_real<
float, std::complex<float>, ffi::Buffer<ffi::F32>, ffi::Buffer<ffi::C64>>(
operand, eigenvalues, eigenvectors);
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(eig_cpu_custom_call_f32,
eig_cpu_custom_call_f32_impl,
ffi::Ffi::Bind()
.Arg<ffi::Buffer<ffi::F32>>()
.Ret<ffi::Buffer<ffi::C64>>()
.Ret<ffi::Buffer<ffi::C64>>());

XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "eig_cpu_custom_call_f32", "Host",
eig_cpu_custom_call_f32);
21 changes: 21 additions & 0 deletions exla/c_src/exla/custom_calls/eig_f64.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#include "eig.h"

ffi::Error
eig_cpu_custom_call_f64_impl(ffi::Buffer<ffi::F64> operand,
ffi::ResultBuffer<ffi::C128> eigenvalues,
ffi::ResultBuffer<ffi::C128> eigenvectors) {
return eig_cpu_custom_call_impl_real<double, std::complex<double>,
ffi::Buffer<ffi::F64>,
ffi::Buffer<ffi::C128>>(
operand, eigenvalues, eigenvectors);
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(eig_cpu_custom_call_f64,
eig_cpu_custom_call_f64_impl,
ffi::Ffi::Bind()
.Arg<ffi::Buffer<ffi::F64>>()
.Ret<ffi::Buffer<ffi::C128>>()
.Ret<ffi::Buffer<ffi::C128>>());

XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "eig_cpu_custom_call_f64", "Host",
eig_cpu_custom_call_f64);
31 changes: 31 additions & 0 deletions exla/lib/exla/defn.ex
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,37 @@ defmodule EXLA.Defn do
{[to_type(eigenvals, eigenvals_expr.type), to_type(eigenvecs, eigenvecs_expr.type)], cache}
end

defp cached_recur_operator(
:optional,
%T{
data: %Expr{
args: [
%{data: %{op: :eig, args: [tensor, _opts]}},
{eigenvals_expr, eigenvecs_expr},
_callback
]
}
},
%{client: %EXLA.Client{platform: :host}, builder: %Function{}} = state,
cache
) do
# We match only on platform: :host for MLIR, as we want to support
# eig-on-cpu as a custom call only in this case
{tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!()

# Ensure output is complex type, converting to at least c64
out_type = Nx.Type.to_complex(op_type(tensor))

{eigenvals, eigenvecs} =
Value.eig(
tensor,
expr_to_typespec(%{eigenvals_expr | type: out_type}),
expr_to_typespec(%{eigenvecs_expr | type: out_type})
)

{[to_type(eigenvals, eigenvals_expr.type), to_type(eigenvecs, eigenvecs_expr.type)], cache}
end

defp cached_recur_operator(
:optional,
%T{
Expand Down
36 changes: 36 additions & 0 deletions exla/lib/exla/mlir/value.ex
Original file line number Diff line number Diff line change
Expand Up @@ -749,6 +749,42 @@ defmodule EXLA.MLIR.Value do
{eigenvals, eigenvecs}
end

def eig(%Value{function: func} = value, eigenvals_typespec, eigenvecs_typespec) do
%{type: op_type} = get_typespec(value)

operands = [value]
result_types = typespecs_to_mlir_types([eigenvals_typespec, eigenvecs_typespec])

call_target_name =
case op_type do
{:f, 32} ->
"eig_cpu_custom_call_f32"

{:f, 64} ->
"eig_cpu_custom_call_f64"

{:c, 64} ->
"eig_cpu_custom_call_c64"

{:c, 128} ->
"eig_cpu_custom_call_c128"

type ->
# Due to matching on EXLA.Defn, we are sure that the device here is always :host
raise "Eig decomposition not supported on :host device for type #{inspect(type)}"
end

attributes = [
call_target_name: attr_string(call_target_name),
api_version: attr_i32(4)
]

[eigenvals, eigenvecs] =
op(func, "stablehlo.custom_call", operands, result_types, attributes: attributes)

{eigenvals, eigenvecs}
end

def qr(%Value{function: func} = value, q_typespec, r_typespec) do
%{type: op_type} = get_typespec(value)

Expand Down
Loading