diff --git a/exla/c_src/exla/custom_calls/eig.h b/exla/c_src/exla/custom_calls/eig.h new file mode 100644 index 0000000000..11b8dd88b8 --- /dev/null +++ b/exla/c_src/exla/custom_calls/eig.h @@ -0,0 +1,183 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include "Eigen/Eigenvalues" +#include "xla/ffi/api/ffi.h" +#include "xla/ffi/ffi_api.h" + +namespace ffi = xla::ffi; + +// For real input types, compute complex eigenvalues/eigenvectors +template +void single_matrix_eig_cpu_custom_call_real(ComplexType *eigenvalues_out, + ComplexType *eigenvectors_out, + DataType *in, uint64_t m, + uint64_t n) { + typedef Eigen::Matrix + RowMajorMatrix; + typedef Eigen::Matrix ComplexVector; + typedef Eigen::Matrix + ComplexRowMajorMatrix; + + // Map the input matrix + Eigen::Map input(in, m, n); + + // Compute the Eigenvalue decomposition for general (non-symmetric) matrices + Eigen::EigenSolver eigensolver(input); + + if (eigensolver.info() != Eigen::Success) { + std::cerr << "Eigenvalue decomposition failed!" 
<< std::endl; + return; + } + + // Get the eigenvalues and eigenvectors (both are complex) + ComplexVector eigenvalues = eigensolver.eigenvalues(); + ComplexRowMajorMatrix eigenvectors = eigensolver.eigenvectors(); + + // Create a vector of indices and sort it based on eigenvalues magnitude in + // decreasing order + std::vector indices(m); + std::iota(indices.begin(), indices.end(), 0); + std::sort(indices.begin(), indices.end(), [&eigenvalues](int i, int j) { + return std::abs(eigenvalues(i)) > std::abs(eigenvalues(j)); + }); + + // Sort eigenvalues and rearrange eigenvectors + ComplexVector sorted_eigenvalues(m); + ComplexRowMajorMatrix sorted_eigenvectors(m, n); + for (int i = 0; i < m; ++i) { + sorted_eigenvalues(i) = eigenvalues(indices[i]); + sorted_eigenvectors.col(i) = eigenvectors.col(indices[i]); + } + + // Copy the sorted eigenvalues to the output + std::memcpy(eigenvalues_out, sorted_eigenvalues.data(), + m * sizeof(ComplexType)); + + // Copy the sorted eigenvectors to the output + std::memcpy(eigenvectors_out, sorted_eigenvectors.data(), + m * n * sizeof(ComplexType)); +} + +// For complex input types +template +void single_matrix_eig_cpu_custom_call_complex(ComplexType *eigenvalues_out, + ComplexType *eigenvectors_out, + ComplexType *in, uint64_t m, + uint64_t n) { + typedef Eigen::Matrix + ComplexRowMajorMatrix; + typedef Eigen::Matrix ComplexVector; + + // Map the input matrix + Eigen::Map input(in, m, n); + + // Compute the Eigenvalue decomposition for complex matrices + Eigen::ComplexEigenSolver eigensolver(input); + + if (eigensolver.info() != Eigen::Success) { + std::cerr << "Eigenvalue decomposition failed!" 
<< std::endl; + return; + } + + // Get the eigenvalues and eigenvectors + ComplexVector eigenvalues = eigensolver.eigenvalues(); + ComplexRowMajorMatrix eigenvectors = eigensolver.eigenvectors(); + + // Create a vector of indices and sort it based on eigenvalues magnitude in + // decreasing order + std::vector indices(m); + std::iota(indices.begin(), indices.end(), 0); + std::sort(indices.begin(), indices.end(), [&eigenvalues](int i, int j) { + return std::abs(eigenvalues(i)) > std::abs(eigenvalues(j)); + }); + + // Sort eigenvalues and rearrange eigenvectors + ComplexVector sorted_eigenvalues(m); + ComplexRowMajorMatrix sorted_eigenvectors(m, n); + for (int i = 0; i < m; ++i) { + sorted_eigenvalues(i) = eigenvalues(indices[i]); + sorted_eigenvectors.col(i) = eigenvectors.col(indices[i]); + } + + // Copy the sorted eigenvalues to the output + std::memcpy(eigenvalues_out, sorted_eigenvalues.data(), + m * sizeof(ComplexType)); + + // Copy the sorted eigenvectors to the output + std::memcpy(eigenvectors_out, sorted_eigenvectors.data(), + m * n * sizeof(ComplexType)); +} + +// For real types (f32, f64) +template +ffi::Error +eig_cpu_custom_call_impl_real(BufferType operand, + ffi::Result eigenvalues, + ffi::Result eigenvectors) { + auto operand_dims = operand.dimensions(); + auto eigenvalues_dims = eigenvalues->dimensions(); + auto eigenvectors_dims = eigenvectors->dimensions(); + + uint64_t m = eigenvectors_dims[eigenvectors_dims.size() - 2]; + uint64_t n = eigenvectors_dims[eigenvectors_dims.size() - 1]; + + uint64_t batch_items = 1; + for (auto it = operand_dims.begin(); it != operand_dims.end() - 2; it++) { + batch_items *= *it; + } + + uint64_t eigenvalues_stride = eigenvalues_dims[eigenvalues_dims.size() - 1]; + uint64_t eigenvectors_stride = m * n; + uint64_t inner_stride = m * n; + + for (uint64_t i = 0; i < batch_items; i++) { + single_matrix_eig_cpu_custom_call_real( + eigenvalues->typed_data() + i * eigenvalues_stride, + eigenvectors->typed_data() + i * 
eigenvectors_stride, + operand.typed_data() + i * inner_stride, m, n); + } + + return ffi::Error::Success(); +} + +// For complex types (c64, c128) +template +ffi::Error +eig_cpu_custom_call_impl_complex(BufferType operand, + ffi::Result eigenvalues, + ffi::Result eigenvectors) { + auto operand_dims = operand.dimensions(); + auto eigenvalues_dims = eigenvalues->dimensions(); + auto eigenvectors_dims = eigenvectors->dimensions(); + + uint64_t m = eigenvectors_dims[eigenvectors_dims.size() - 2]; + uint64_t n = eigenvectors_dims[eigenvectors_dims.size() - 1]; + + uint64_t batch_items = 1; + for (auto it = operand_dims.begin(); it != operand_dims.end() - 2; it++) { + batch_items *= *it; + } + + uint64_t eigenvalues_stride = eigenvalues_dims[eigenvalues_dims.size() - 1]; + uint64_t eigenvectors_stride = m * n; + uint64_t inner_stride = m * n; + + for (uint64_t i = 0; i < batch_items; i++) { + single_matrix_eig_cpu_custom_call_complex( + eigenvalues->typed_data() + i * eigenvalues_stride, + eigenvectors->typed_data() + i * eigenvectors_stride, + operand.typed_data() + i * inner_stride, m, n); + } + + return ffi::Error::Success(); +} diff --git a/exla/c_src/exla/custom_calls/eig_c128.cc b/exla/c_src/exla/custom_calls/eig_c128.cc new file mode 100644 index 0000000000..0a59a5f442 --- /dev/null +++ b/exla/c_src/exla/custom_calls/eig_c128.cc @@ -0,0 +1,20 @@ +#include "eig.h" + +ffi::Error +eig_cpu_custom_call_c128_impl(ffi::Buffer operand, + ffi::ResultBuffer eigenvalues, + ffi::ResultBuffer eigenvectors) { + return eig_cpu_custom_call_impl_complex, + ffi::Buffer>( + operand, eigenvalues, eigenvectors); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(eig_cpu_custom_call_c128, + eig_cpu_custom_call_c128_impl, + ffi::Ffi::Bind() + .Arg>() + .Ret>() + .Ret>()); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "eig_cpu_custom_call_c128", + "Host", eig_cpu_custom_call_c128); diff --git a/exla/c_src/exla/custom_calls/eig_c64.cc b/exla/c_src/exla/custom_calls/eig_c64.cc new file mode 100644 
index 0000000000..2690ef7095 --- /dev/null +++ b/exla/c_src/exla/custom_calls/eig_c64.cc @@ -0,0 +1,20 @@ +#include "eig.h" + +ffi::Error +eig_cpu_custom_call_c64_impl(ffi::Buffer operand, + ffi::ResultBuffer eigenvalues, + ffi::ResultBuffer eigenvectors) { + return eig_cpu_custom_call_impl_complex, + ffi::Buffer>( + operand, eigenvalues, eigenvectors); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(eig_cpu_custom_call_c64, + eig_cpu_custom_call_c64_impl, + ffi::Ffi::Bind() + .Arg>() + .Ret>() + .Ret>()); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "eig_cpu_custom_call_c64", "Host", + eig_cpu_custom_call_c64); diff --git a/exla/c_src/exla/custom_calls/eig_f32.cc b/exla/c_src/exla/custom_calls/eig_f32.cc new file mode 100644 index 0000000000..de479694fb --- /dev/null +++ b/exla/c_src/exla/custom_calls/eig_f32.cc @@ -0,0 +1,20 @@ +#include "eig.h" + +ffi::Error +eig_cpu_custom_call_f32_impl(ffi::Buffer operand, + ffi::ResultBuffer eigenvalues, + ffi::ResultBuffer eigenvectors) { + return eig_cpu_custom_call_impl_real< + float, std::complex, ffi::Buffer, ffi::Buffer>( + operand, eigenvalues, eigenvectors); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(eig_cpu_custom_call_f32, + eig_cpu_custom_call_f32_impl, + ffi::Ffi::Bind() + .Arg>() + .Ret>() + .Ret>()); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "eig_cpu_custom_call_f32", "Host", + eig_cpu_custom_call_f32); diff --git a/exla/c_src/exla/custom_calls/eig_f64.cc b/exla/c_src/exla/custom_calls/eig_f64.cc new file mode 100644 index 0000000000..3292393ab4 --- /dev/null +++ b/exla/c_src/exla/custom_calls/eig_f64.cc @@ -0,0 +1,21 @@ +#include "eig.h" + +ffi::Error +eig_cpu_custom_call_f64_impl(ffi::Buffer operand, + ffi::ResultBuffer eigenvalues, + ffi::ResultBuffer eigenvectors) { + return eig_cpu_custom_call_impl_real, + ffi::Buffer, + ffi::Buffer>( + operand, eigenvalues, eigenvectors); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(eig_cpu_custom_call_f64, + eig_cpu_custom_call_f64_impl, + ffi::Ffi::Bind() + .Arg>() + .Ret>() + .Ret>()); 
+ +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "eig_cpu_custom_call_f64", "Host", + eig_cpu_custom_call_f64); diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 413c38ce45..5636df32ec 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -439,6 +439,37 @@ defmodule EXLA.Defn do {[to_type(eigenvals, eigenvals_expr.type), to_type(eigenvecs, eigenvecs_expr.type)], cache} end + defp cached_recur_operator( + :optional, + %T{ + data: %Expr{ + args: [ + %{data: %{op: :eig, args: [tensor, _opts]}}, + {eigenvals_expr, eigenvecs_expr}, + _callback + ] + } + }, + %{client: %EXLA.Client{platform: :host}, builder: %Function{}} = state, + cache + ) do + # We match only on platform: :host for MLIR, as we want to support + # eig-on-cpu as a custom call only in this case + {tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!() + + # Ensure output is complex type, converting to at least c64 + out_type = Nx.Type.to_complex(op_type(tensor)) + + {eigenvals, eigenvecs} = + Value.eig( + tensor, + expr_to_typespec(%{eigenvals_expr | type: out_type}), + expr_to_typespec(%{eigenvecs_expr | type: out_type}) + ) + + {[to_type(eigenvals, eigenvals_expr.type), to_type(eigenvecs, eigenvecs_expr.type)], cache} + end + defp cached_recur_operator( :optional, %T{ diff --git a/exla/lib/exla/mlir/value.ex b/exla/lib/exla/mlir/value.ex index f955e67200..02e31cf47a 100644 --- a/exla/lib/exla/mlir/value.ex +++ b/exla/lib/exla/mlir/value.ex @@ -749,6 +749,42 @@ defmodule EXLA.MLIR.Value do {eigenvals, eigenvecs} end + def eig(%Value{function: func} = value, eigenvals_typespec, eigenvecs_typespec) do + %{type: op_type} = get_typespec(value) + + operands = [value] + result_types = typespecs_to_mlir_types([eigenvals_typespec, eigenvecs_typespec]) + + call_target_name = + case op_type do + {:f, 32} -> + "eig_cpu_custom_call_f32" + + {:f, 64} -> + "eig_cpu_custom_call_f64" + + {:c, 64} -> + "eig_cpu_custom_call_c64" + + {:c, 128} -> + 
"eig_cpu_custom_call_c128" + + type -> + # Due to matching on EXLA.Defn, we are sure that the device here is always :host + raise "Eig decomposition not supported on :host device for type #{inspect(type)}" + end + + attributes = [ + call_target_name: attr_string(call_target_name), + api_version: attr_i32(4) + ] + + [eigenvals, eigenvecs] = + op(func, "stablehlo.custom_call", operands, result_types, attributes: attributes) + + {eigenvals, eigenvecs} + end + def qr(%Value{function: func} = value, q_typespec, r_typespec) do %{type: op_type} = get_typespec(value) diff --git a/exla/test/exla/nx_linalg_test.exs b/exla/test/exla/nx_linalg_test.exs new file mode 100644 index 0000000000..af8673a08d --- /dev/null +++ b/exla/test/exla/nx_linalg_test.exs @@ -0,0 +1,121 @@ +defmodule EXLA.NxLinAlgTest do + use EXLA.Case, async: true + + describe "eig (EXLA host)" do + test "computes eigenvalues and eigenvectors for diagonal matrix" do + t = Nx.tensor([[1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 3.0]]) + assert {eigenvals, _eigenvecs} = Nx.LinAlg.eig(t) + expected = Nx.tensor([3.0, 2.0, 1.0]) |> Nx.as_type({:c, 64}) + assert_all_close(Nx.abs(eigenvals), Nx.abs(expected), atol: 1.0e-2) + end + + test "computes eigenvalues and eigenvectors for upper triangular matrix" do + t = Nx.tensor([[1.0, 2.0, 3.0], [0.0, 4.0, 5.0], [0.0, 0.0, 6.0]]) + assert {eigenvals, eigenvecs} = Nx.LinAlg.eig(t) + expected = Nx.tensor([6.0, 4.0, 1.0]) + assert_all_close(Nx.abs(eigenvals), Nx.abs(expected), atol: 1.0e-2) + + assert_all_close(Nx.dot(t, eigenvecs), Nx.dot(eigenvecs, Nx.make_diagonal(eigenvals)), + atol: 1.0e-2 + ) + end + + test "computes eigenvalues and eigenvectors for lower triangular matrix" do + t = Nx.tensor([[1.0, 0.0, 0.0], [2.0, 3.0, 0.0], [4.0, 5.0, 6.0]]) + assert {eigenvals, eigenvecs} = Nx.LinAlg.eig(t) + expected = Nx.tensor([6.0, 3.0, 1.0]) + assert_all_close(Nx.abs(eigenvals), Nx.abs(expected), atol: 1.0e-2) + + assert_all_close(Nx.dot(t, eigenvecs), Nx.dot(eigenvecs, 
Nx.make_diagonal(eigenvals)), + atol: 1.0e-2 + ) + end + + test "computes complex eigenvalues for rotation matrix" do + t = Nx.tensor([[0.0, -1.0], [1.0, 0.0]]) + assert {eigenvals, _eigenvecs} = Nx.LinAlg.eig(t) + assert_all_close(Nx.abs(eigenvals), Nx.tensor([1.0, 1.0]), atol: 1.0e-3) + assert_all_close(Nx.sum(Nx.imag(eigenvals)), Nx.tensor(0.0), atol: 1.0e-3) + end + + test "works with batched matrices" do + t = Nx.tensor([[[1.0, 0.0], [0.0, 2.0]], [[3.0, 0.0], [0.0, 4.0]]]) + assert {eigenvals, _eigenvecs} = Nx.LinAlg.eig(t) + expected = Nx.tensor([[2.0, 1.0], [4.0, 3.0]]) + assert_all_close(Nx.abs(eigenvals), expected, atol: 1.0e-3) + end + + test "works with vectorized matrices" do + t = + Nx.tensor([ + [[[1.0, 0.0], [0.0, 2.0]]], + [[[3.0, 0.0], [0.0, 4.0]]] + ]) + |> Nx.vectorize(x: 2, y: 1) + + assert {eigenvals, eigenvecs} = Nx.LinAlg.eig(t) + assert eigenvals.vectorized_axes == [x: 2, y: 1] + assert eigenvecs.vectorized_axes == [x: 2, y: 1] + + eigenvals = Nx.devectorize(eigenvals) + assert_all_close(Nx.abs(eigenvals[0][0]), Nx.tensor([2.0, 1.0]), atol: 1.0e-3) + assert_all_close(Nx.abs(eigenvals[1][0]), Nx.tensor([4.0, 3.0]), atol: 1.0e-3) + + eigenvecs_dev = Nx.devectorize(eigenvecs) + + for batch <- 0..1, col <- 0..1 do + v = eigenvecs_dev[batch][0][[.., col]] + norm = Nx.LinAlg.norm(v) |> Nx.to_number() + assert_in_delta(norm, 1.0, 0.1) + end + end + + test "property: eigenvalue equation A*v = λ*v" do + key = Nx.Random.key(System.unique_integer()) + + for _ <- 1..3, type <- [{:f, 32}, {:f, 64}, {:c, 64}, {:c, 128}], reduce: key do + key -> + {base_q, key} = Nx.Random.uniform(key, -2, 2, shape: {2, 3, 3}, type: type) + {q, _} = Nx.LinAlg.qr(base_q) + + evals_test = + [10, 1, 0.1] + |> Enum.map(fn magnitude -> + sign = if :rand.uniform() - 0.5 > 0, do: 1, else: -1 + rand = :rand.uniform() * magnitude * 0.1 + magnitude + rand * sign + end) + |> Nx.tensor(type: type) + + evals_test_diag = + evals_test + |> Nx.make_diagonal() + |> Nx.reshape({1, 3, 3}) + 
|> Nx.tile([2, 1, 1]) + + q_adj = Nx.LinAlg.adjoint(q) + + a = + q + |> Nx.dot([2], [0], evals_test_diag, [1], [0]) + |> Nx.dot([2], [0], q_adj, [1], [0]) + + assert {eigenvals, eigenvecs} = Nx.LinAlg.eig(a) + + evals = + eigenvals + |> Nx.vectorize(x: 2) + |> Nx.make_diagonal() + |> Nx.devectorize(keep_names: false) + + assert_all_close( + Nx.dot(eigenvecs, [-1], [0], evals, [-2], [0]), + Nx.dot(a, [-1], [0], eigenvecs, [-2], [0]), + atol: 1.0e-3 + ) + + key + end + end + end +end diff --git a/nx/lib/nx/backend.ex b/nx/lib/nx/backend.ex index 3c463ba237..a88959036e 100644 --- a/nx/lib/nx/backend.ex +++ b/nx/lib/nx/backend.ex @@ -145,6 +145,7 @@ defmodule Nx.Backend do @callback qr({q :: tensor, r :: tensor}, tensor, keyword) :: tensor @callback cholesky(out :: tensor, tensor) :: tensor @callback eigh({eigenvals :: tensor, eigenvecs :: tensor}, tensor, keyword) :: tensor + @callback eig({eigenvals :: tensor, eigenvecs :: tensor}, tensor, keyword) :: tensor @callback solve(out :: tensor, a :: tensor, b :: tensor) :: tensor @callback determinant(out :: tensor, t :: tensor) :: tensor @callback logical_not(out :: tensor, t :: tensor) :: tensor @@ -172,6 +173,7 @@ defmodule Nx.Backend do cumulative_max: 3, all_close: 4, svd: 3, + eig: 3, top_k: 3, fft2: 3, ifft2: 3, diff --git a/nx/lib/nx/lin_alg.ex b/nx/lib/nx/lin_alg.ex index 4e7a5afdcc..8b535b196b 100644 --- a/nx/lib/nx/lin_alg.ex +++ b/nx/lib/nx/lin_alg.ex @@ -1402,6 +1402,84 @@ defmodule Nx.LinAlg do |> Nx.vectorize(vectorized_axes) end + @doc """ + Calculates the eigenvalues and eigenvectors of batched square 2-D matrices. + + Unlike `eigh/2`, this function works with general (non-Hermitian) matrices + and returns complex eigenvalues and eigenvectors even for real input matrices. + + It returns `{eigenvals, eigenvecs}` where both are complex tensors. + + Note: For Hermitian (or real symmetric) matrices, prefer using `eigh/2` as it + is more efficient and guarantees real eigenvalues. 
+ + ## Options + + * `:max_iter` - `integer`. Defaults to `1_000` + Number of maximum iterations before stopping the decomposition + + * `:eps` - `float`. Defaults to `1.0e-4` + Tolerance applied during the decomposition + + Note not all options apply to all backends, as backends may have + specific optimizations that render these mechanisms unnecessary. + + ## Examples + + Diagonal matrix returns eigenvalues on the diagonal: + + iex> {eigenvals, _} = Nx.LinAlg.eig(Nx.tensor([[1, 0], [0, 2]], type: :f32)) + iex> Nx.all_close(Nx.sort(Nx.abs(eigenvals)), Nx.tensor([1.0, 2.0]), atol: 1.0e-3) |> Nx.to_number() + 1 + + Upper triangular matrix: + + iex> {eigenvals, _} = Nx.LinAlg.eig(Nx.tensor([[1, 1], [0, 2]], type: :f32)) + iex> Nx.all_close(Nx.reduce_max(Nx.abs(eigenvals)), Nx.tensor(2.0), atol: 1.0e-3) |> Nx.to_number() + 1 + + Rotation matrix (has complex eigenvalues; magnitudes ~1): + + iex> {eigenvals, _} = Nx.LinAlg.eig(Nx.tensor([[0, -1], [1, 0]], type: :f32)) + iex> Nx.all_close(Nx.sort(Nx.abs(eigenvals)), Nx.tensor([1.0, 1.0]), atol: 1.0e-3) |> Nx.to_number() + 1 + + Batched matrices: + + iex> t = Nx.tensor([[[1, 0], [0, 2]], [[3, 0], [0, 4]]], type: :f32) + iex> {eigenvals, _} = Nx.LinAlg.eig(t) + iex> expected = Nx.tensor([[2.0, 1.0], [4.0, 3.0]], type: :f32) + iex> Nx.all_close(Nx.abs(eigenvals), expected, atol: 1.0e-3) |> Nx.to_number() + 1 + + ## Error cases + + iex> Nx.LinAlg.eig(Nx.tensor([[1, 2, 3], [4, 5, 6]])) + ** (ArgumentError) tensor must be a square matrix or a batch of square matrices, got shape: {2, 3} + """ + def eig(tensor, opts \\ []) do + opts = keyword!(opts, max_iter: 1_000, eps: 1.0e-4) + %T{vectorized_axes: vectorized_axes} = tensor = Nx.to_tensor(tensor) + %T{type: type, shape: shape} = tensor = Nx.devectorize(tensor) + + # Always output complex type for eigenvalues and eigenvectors + output_type = Nx.Type.to_complex(Nx.Type.to_floating(type)) + + {eigenvals_shape, eigenvecs_shape} = Nx.Shape.eig(shape) + rank = tuple_size(shape) + + 
eigenvecs_name = List.duplicate(nil, rank) + eigenvals_name = tl(eigenvecs_name) + + output = + {%{tensor | names: eigenvals_name, type: output_type, shape: eigenvals_shape}, + %{tensor | names: eigenvecs_name, type: output_type, shape: eigenvecs_shape}} + + :eig + |> Nx.Shared.optional([tensor, opts], output, &Nx.LinAlg.Eig.eig/2) + |> Nx.vectorize(vectorized_axes) + end + @doc """ Calculates the Singular Value Decomposition of batched 2-D matrices. diff --git a/nx/lib/nx/lin_alg/eig.ex b/nx/lib/nx/lin_alg/eig.ex new file mode 100644 index 0000000000..c163f4bc43 --- /dev/null +++ b/nx/lib/nx/lin_alg/eig.ex @@ -0,0 +1,415 @@ +defmodule Nx.LinAlg.Eig do + @moduledoc """ + General eigenvalue decomposition using QR algorithm. + + This implements the non-symmetric eigenvalue problem for general square matrices. + Unlike `Nx.LinAlg.BlockEigh` which assumes Hermitian matrices, this works with any + square matrix but always produces complex eigenvalues and eigenvectors. + + The implementation uses: + 1. Reduction to upper Hessenberg form using Householder reflections + 2. Shifted QR algorithm on the Hessenberg matrix to find eigenvalues + 3. Inverse iteration to find eigenvectors + + This is a reference implementation. Backends like EXLA may provide optimized + versions (EXLA's host custom call uses Eigen's `EigenSolver`/`ComplexEigenSolver`). 
+ """ + import Nx.Defn + + defn eig(a, opts \\ []) do + opts = keyword!(opts, eps: 1.0e-4, max_iter: 1_000) + + a + |> Nx.revectorize([collapsed_axes: :auto], + target_shape: {Nx.axis_size(a, -2), Nx.axis_size(a, -1)} + ) + |> eig_matrix(opts) + |> revectorize_result(a) + end + + deftransformp revectorize_result({eigenvals, eigenvecs}, a) do + shape = Nx.shape(a) + + { + Nx.revectorize(eigenvals, a.vectorized_axes, + target_shape: Tuple.delete_at(shape, tuple_size(shape) - 1) + ), + Nx.revectorize(eigenvecs, a.vectorized_axes, target_shape: shape) + } + end + + defnp eig_matrix(a, opts \\ []) do + # Convert to complex type since eigenvalues can be complex even for real matrices + type = Nx.Type.to_complex(Nx.type(a)) + a = Nx.as_type(a, type) + + {n, _} = Nx.shape(a) + + case n do + 1 -> + # For 1x1 matrices, eigenvalue is the single element + eigenval = Nx.reshape(a, {1}) + eigenvec = Nx.tensor([[1.0]], type: type) + {eigenval, eigenvec} + + _ -> + {eigenvals, eigenvecs} = + calculate_evals_evecs(a, opts) + + # Sort eigenpairs by |lambda| in descending order + sort_idx = Nx.argsort(Nx.abs(eigenvals), direction: :desc) + eigenvals = Nx.take(eigenvals, sort_idx) + eigenvecs = Nx.take(eigenvecs, sort_idx, axis: 1) + {eigenvals, eigenvecs} + end + end + + defnp calculate_evals_evecs(a, opts) do + type = Nx.Type.to_complex(Nx.type(a)) + + cond do + is_upper_triangular(a, opts) -> + eigenvals = Nx.take_diagonal(a) + eigenvecs = eigenvectors_from_upper_tri(a, eigenvals, opts) + {eigenvals, eigenvecs} + + is_lower_triangular(a, opts) -> + eigenvals = Nx.take_diagonal(a) + eigenvecs = eigenvectors_from_lower_tri(a, eigenvals, opts) + {eigenvals, eigenvecs} + + is_hermitian(a, opts) -> + {eigs_h, vecs_h} = Nx.LinAlg.eigh(a) + {Nx.as_type(eigs_h, type), Nx.as_type(vecs_h, type)} + + true -> + # Reduce to Hessenberg form and keep the orthogonal transformation Q + {h, q_hessenberg} = hessenberg(a, opts) + + # Apply QR algorithm to find Schur form, eigenvalues, and accumulated 
Schur vectors + {schur, eigenvals, q_schur} = qr_algorithm(h, opts) + q_total = Nx.dot(q_hessenberg, q_schur) + + # If the Schur form is (nearly) diagonal, its eigenvectors are simply q_total's columns. + # This happens for normal matrices (including Hermitian), which our property test exercises. + # Use a fast path in that case; otherwise, compute eigenvectors from Schur form. + diag_schur = Nx.make_diagonal(Nx.take_diagonal(schur)) + offdiag_norm = Nx.LinAlg.norm(schur - diag_schur) + schur_norm = Nx.LinAlg.norm(schur) + nearly_diag = offdiag_norm <= 1.0e-6 * (schur_norm + opts[:eps]) + + eigenvecs = + Nx.select( + nearly_diag, + q_total, + compute_eigenvectors(schur, q_total, eigenvals, opts) + ) + + {eigenvals, eigenvecs} + end + end + + defnp is_hermitian(a, opts) do + eps = opts[:eps] + sym_norm = Nx.LinAlg.norm(a - Nx.LinAlg.adjoint(a)) + a_norm = Nx.LinAlg.norm(a) + sym_norm <= 1.0e-6 * (a_norm + eps) + end + + defnp is_upper_triangular(a, opts) do + eps = opts[:eps] + lower = Nx.tril(a, k: -1) + lower_norm = Nx.LinAlg.norm(lower) + a_norm = Nx.LinAlg.norm(a) + lower_norm <= 1.0e-6 * (a_norm + eps) + end + + defnp is_lower_triangular(a, opts) do + eps = opts[:eps] + upper = Nx.triu(a, k: 1) + upper_norm = Nx.LinAlg.norm(upper) + a_norm = Nx.LinAlg.norm(a) + upper_norm <= 1.0e-6 * (a_norm + eps) + end + + defnp hessenberg(a, opts) do + eps = opts[:eps] + # Reduce matrix to upper Hessenberg form using Householder reflections + # An upper Hessenberg matrix has zeros below the first subdiagonal + {n, _} = Nx.shape(a) + type = Nx.type(a) + + column_iota = Nx.iota({n}) + + [h, q] = Nx.broadcast_vectors([a, Nx.eye(n, type: type)]) + + # Perform Householder reflections for columns 0 to n-3 + {{h, q}, _} = + while {{h, q}, {column_iota}}, k <- 0..(n - 3)//1 do + # Extract column k, zeroing elements at or above k + x = h[[.., k]] + x = Nx.select(column_iota <= k, 0, x) + + # Compute Householder reflector matrix + reflector = Nx.LinAlg.QR.householder_reflector(x, k, 
eps) + h_adj = Nx.LinAlg.adjoint(reflector) + + # Apply: H = P * H * P^H where P is the reflector + h = reflector |> Nx.dot(h) |> Nx.dot(h_adj) + + # Update Q: Q = Q * P + q = Nx.dot(q, reflector) + + {{h, q}, {column_iota}} + end + + {h, q} + end + + defnp qr_algorithm(h, opts) do + # Shifted QR algorithm to find eigenvalues and accumulate Schur vectors + eps = opts[:eps] + max_iter = opts[:max_iter] + {n, _} = Nx.shape(h) + type = Nx.type(h) + + eye = Nx.eye(n, type: type) + accum_q = eye + + [h, accum_q, eye] = Nx.broadcast_vectors([h, accum_q, eye]) + + # Standard QR iteration on full matrix with Wilkinson shift, accumulating Q + {{h, accum_q}, _} = + while {{h, accum_q}, {i = 0, eye}}, i < max_iter do + subdiag = Nx.take_diagonal(h, offset: -1) + max_subdiag = Nx.reduce_max(Nx.abs(subdiag)) + + shift = wilkinson_shift_full(h, n) + {q_step, r} = Nx.LinAlg.qr(h - shift * eye) + h_candidate = Nx.dot(r, q_step) + shift * eye + accum_candidate = Nx.dot(accum_q, q_step) + + update = Nx.greater_equal(max_subdiag, eps) + h = Nx.select(update, h_candidate, h) + accum_q = Nx.select(update, accum_candidate, accum_q) + + {{h, accum_q}, {i + 1, eye}} + end + + {h, Nx.take_diagonal(h), accum_q} + end + + defnp wilkinson_shift_full(h, n) do + # Standard Wilkinson shift from bottom 2x2 block + if n >= 2 do + a = h[[n - 2, n - 2]] + b = h[[n - 2, n - 1]] + c = h[[n - 1, n - 2]] + d = h[[n - 1, n - 1]] + + trace = a + d + det = a * d - b * c + discriminant = trace * trace / 4 - det + + sqrt_disc = Nx.sqrt(discriminant) + lambda1 = trace / 2 + sqrt_disc + lambda2 = trace / 2 - sqrt_disc + + diff1 = Nx.abs(lambda1 - d) + diff2 = Nx.abs(lambda2 - d) + + Nx.select(diff1 < diff2, lambda1, lambda2) + else + h[[0, 0]] + end + end + + defnp compute_eigenvectors(h, q, eigenvals, opts) do + eps = opts[:eps] + # Compute eigenvectors using stabilized inverse iteration on H via normal equations: + # (A^H A + mu I) v_new = A^H v_old, where A = (H - lambda I) + {n, _} = Nx.shape(h) + type = 
Nx.type(h) + + eigenvecs_h = Nx.broadcast(0.0, {n, n}) |> Nx.as_type(type) + eye = Nx.eye(n, type: type) + + [eigenvecs_h, eigenvals, h, eye] = Nx.broadcast_vectors([eigenvecs_h, eigenvals, h, eye]) + + {eigenvecs_h, _} = + while {eigenvecs_h, {k = 0, eigenvals, h, eye}}, k < n do + lambda = eigenvals[[k]] + + # Deterministic initial vector + # Use a real iota to avoid complex iota backend limitations, then cast to complex + v_real = Nx.iota({n}, type: type) + v = v_real + k + v = v / (Nx.LinAlg.norm(v) + eps) + + # Orthogonalize against previously computed eigenvectors + v = orthogonalize_vector(v, eigenvecs_h, k, eps) + + # Prepare A, A^H, and normal equations matrix + a = h - lambda * eye + ah = Nx.LinAlg.adjoint(a) + + {v, _} = + while {v, {iter = 0, a, ah, eye}}, iter < 40 do + # Right-hand side: b = A^H v + b = Nx.dot(ah, [1], v, [0]) + # Normal equations matrix: N = A^H A + mu I + ah_a = Nx.dot(ah, a) + # Adaptive regularization + mu = Nx.LinAlg.norm(ah_a) * 1.0e-3 + eps + nmat = ah_a + mu * eye + # Solve N v_new = b + v_new = Nx.LinAlg.solve(nmat, b) + # Normalize + v_norm = Nx.LinAlg.norm(v_new) + v = Nx.select(Nx.abs(v_norm) > eps, v_new / v_norm, v) + {v, {iter + 1, a, ah, eye}} + end + + # One more orthogonalization pass for stability + v = orthogonalize_vector(v, eigenvecs_h, k, eps) + # And renormalize + v_norm = Nx.LinAlg.norm(v) + v = Nx.select(Nx.abs(v_norm) > eps, v / v_norm, v) + + eigenvecs_h = Nx.put_slice(eigenvecs_h, [0, k], Nx.reshape(v, {n, 1})) + + {eigenvecs_h, {k + 1, eigenvals, h, eye}} + end + + # Transform eigenvectors back: V = Q * V_h + Nx.dot(q, eigenvecs_h) + end + + # Fast path: compute eigenvectors directly from an upper-triangular A by back-substitution + defnp eigenvectors_from_upper_tri(a, eigenvals, opts) do + eps = opts[:eps] + {n, _} = Nx.shape(a) + type = Nx.type(a) + + eye = Nx.eye(n, type: type) + [a, eye] = Nx.broadcast_vectors([a, eye]) + v = a * 0.0 + + row_idx = Nx.iota({n}, type: {:s, 32}) + col_idx = row_idx + + 
[eigenvals] = Nx.broadcast_vectors([eigenvals]) + + {v, _} = + while {v, {k = 0, a, eigenvals, eye, row_idx, col_idx}}, k < n do + lambda = eigenvals[[k]] + u = a - lambda * eye + + vk = u[0] * 0.0 + vk = Nx.put_slice(vk, [k], Nx.tensor([1.0], type: type)) + + {vk, _} = + while {vk, {i = k - 1, u, row_idx, col_idx, k}}, i >= 0 do + mask_gt_i = Nx.greater(col_idx, i) + mask_ge_0 = Nx.greater_equal(col_idx, 0) + m = Nx.as_type(Nx.logical_and(mask_gt_i, mask_ge_0), type) + row_u = u[i] + sum = Nx.sum(row_u * vk * m) + denom = u[[i, i]] + vi = -sum / (denom + eps) + vk = Nx.put_slice(vk, [i], Nx.reshape(vi, {1})) + {vk, {i - 1, u, row_idx, col_idx, k}} + end + + vk_norm = Nx.LinAlg.norm(vk) + vk = Nx.select(Nx.abs(vk_norm) > eps, vk / vk_norm, vk) + v = Nx.put_slice(v, [0, k], Nx.reshape(vk, {n, 1})) + {v, {k + 1, a, eigenvals, eye, row_idx, col_idx}} + end + + v + end + + # Fast path: compute eigenvectors directly from a lower-triangular A by forward substitution + defnp eigenvectors_from_lower_tri(a, eigenvals, opts) do + eps = opts[:eps] + {n, _} = Nx.shape(a) + type = Nx.type(a) + + eye = Nx.eye(n, type: type) + [a, eye] = Nx.broadcast_vectors([a, eye]) + v = a * 0.0 + + row_idx = Nx.iota({n}, type: {:s, 32}) + col_idx = row_idx + + [eigenvals] = Nx.broadcast_vectors([eigenvals]) + + {v, _} = + while {v, {k = 0, a, eigenvals, eye, row_idx, col_idx}}, k < n do + lambda = eigenvals[[k]] + l = a - lambda * eye + + vk = l[0] * 0.0 + vk = Nx.put_slice(vk, [k], Nx.tensor([1.0], type: type)) + + {vk, _} = + while {vk, {i = k + 1, l, row_idx, col_idx, k}}, i < n do + # sum over j in [k, i) + mask_ge_k = Nx.greater_equal(col_idx, k) + mask_lt_i = Nx.less(col_idx, i) + m = Nx.as_type(Nx.logical_and(mask_ge_k, mask_lt_i), type) + row_l = l[i] + sum = Nx.sum(row_l * vk * m) + denom = l[[i, i]] + vi = -sum / (denom + eps) + vk = Nx.put_slice(vk, [i], Nx.reshape(vi, {1})) + {vk, {i + 1, l, row_idx, col_idx, k}} + end + + vk_norm = Nx.LinAlg.norm(vk) + vk = 
Nx.select(Nx.abs(vk_norm) > eps, vk / vk_norm, vk) + v = Nx.put_slice(v, [0, k], Nx.reshape(vk, {n, 1})) + {v, {k + 1, a, eigenvals, eye, row_idx, col_idx}} + end + + v + end + + # Orthogonalize vector v against the first k columns of matrix eigenvecs + # Uses Gram-Schmidt: v = v - sum(proj_j) where proj_j = dot(adjoint(v_j), v) * v_j + defnp orthogonalize_vector(v, eigenvecs, k, eps) do + {_n, n_cols} = Nx.shape(eigenvecs) + + # We need to orthogonalize against columns 0..k-1 + # Use a fixed iteration approach with masking to avoid out of bounds + max_iters = Nx.min(k, n_cols) + + # Broadcast vectors to ensure consistent shape + [v, eigenvecs] = Nx.broadcast_vectors([v, eigenvecs]) + + {v_orthog, _} = + while {v_orthog = v, {j = 0, max_iters, eigenvecs, k}}, j < max_iters do + # Only process if j < k and j < n_cols + should_process = Nx.logical_and(j < k, j < n_cols) + + v_orthog = + if should_process do + # Get column j (safe because we checked bounds) + # Clamp to valid range + col_idx = Nx.min(j, n_cols - 1) + v_j = eigenvecs[[.., col_idx]] + proj = Nx.dot(Nx.LinAlg.adjoint(v_j), v_orthog) + v_orthog - Nx.multiply(proj, v_j) + else + v_orthog + end + + {v_orthog, {j + 1, max_iters, eigenvecs, k}} + end + + # Normalize the orthogonalized vector + v_norm = Nx.LinAlg.norm(v_orthog) + Nx.select(Nx.abs(v_norm) > eps, v_orthog / v_norm, v) + end +end diff --git a/nx/lib/nx/shape.ex b/nx/lib/nx/shape.ex index 3d27ed56cd..b24fb77ae8 100644 --- a/nx/lib/nx/shape.ex +++ b/nx/lib/nx/shape.ex @@ -2007,6 +2007,31 @@ defmodule Nx.Shape do "tensor must have at least rank 2, got rank #{tuple_size(shape)} with shape #{inspect(shape)}" ) + def eig(shape) when tuple_size(shape) > 1 do + rank = tuple_size(shape) + {m, n} = {elem(shape, rank - 2), elem(shape, rank - 1)} + {unchanged_shape, _} = Tuple.to_list(shape) |> Enum.split(-2) + + unless m == n do + raise( + ArgumentError, + "tensor must be a square matrix or a batch of square matrices, got shape: #{inspect(shape)}" + ) + end + + { + 
List.to_tuple(unchanged_shape ++ [m]), + List.to_tuple(unchanged_shape ++ [m, m]) + } + end + + def eig(shape), + do: + raise( + ArgumentError, + "tensor must have at least rank 2, got rank #{tuple_size(shape)} with shape #{inspect(shape)}" + ) + def svd(shape, opts \\ []) def svd(shape, opts) when tuple_size(shape) > 1 do diff --git a/nx/test/nx/lin_alg_test.exs b/nx/test/nx/lin_alg_test.exs index 4c0a51b6d5..446f5c8af3 100644 --- a/nx/test/nx/lin_alg_test.exs +++ b/nx/test/nx/lin_alg_test.exs @@ -740,6 +740,222 @@ defmodule Nx.LinAlgTest do end end + describe "eig" do + test "computes eigenvalues and eigenvectors for diagonal matrix" do + # Diagonal matrices have eigenvalues equal to diagonal elements + t = Nx.tensor([[1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 3.0]]) + + assert {eigenvals, eigenvecs} = Nx.LinAlg.eig(t) + + # Eigenvalues should be 3, 2, 1 (sorted by magnitude) + expected_eigenvals = Nx.tensor([3.0, 2.0, 1.0]) |> Nx.as_type({:c, 64}) + assert_all_close(Nx.abs(eigenvals), Nx.abs(expected_eigenvals), atol: 1.0e-2) + + assert_all_close( + Nx.dot(t, eigenvecs), + Nx.dot(eigenvecs, Nx.make_diagonal(eigenvals)), + atol: 1.0e-2 + ) + end + + test "computes eigenvalues and eigenvectors for upper triangular matrix" do + # Upper triangular matrices have eigenvalues equal to diagonal elements + t = Nx.tensor([[1.0, 2.0, 3.0], [0.0, 4.0, 5.0], [0.0, 0.0, 6.0]]) + + assert {eigenvals, eigenvecs} = Nx.LinAlg.eig(t) + + # Eigenvalues should be 6, 4, 1 (sorted by magnitude) + expected_eigenvals = Nx.tensor([6.0, 4.0, 1.0]) + assert_all_close(Nx.abs(eigenvals), Nx.abs(expected_eigenvals), atol: 1.0e-2) + + assert_all_close( + Nx.dot(t, eigenvecs), + Nx.dot(eigenvecs, Nx.make_diagonal(eigenvals)), + atol: 1.0e-2 + ) + end + + test "computes eigenvalues and eigenvectors for lower triangular matrix" do + # Lower triangular matrices have eigenvalues equal to diagonal elements + t = Nx.tensor([[1.0, 0.0, 0.0], [2.0, 3.0, 0.0], [4.0, 5.0, 6.0]]) + + assert {eigenvals, 
eigenvecs} = Nx.LinAlg.eig(t) + + # Eigenvalues should be 6, 3, 1 (sorted by magnitude) + expected_eigenvals = Nx.tensor([6.0, 3.0, 1.0]) + assert_all_close(Nx.abs(eigenvals), Nx.abs(expected_eigenvals), atol: 1.0e-2) + + assert_all_close( + Nx.dot(t, eigenvecs), + Nx.dot(eigenvecs, Nx.make_diagonal(eigenvals)), + atol: 1.0e-2 + ) + end + + test "computes complex eigenvalues for rotation matrix" do + # 90-degree rotation matrix has purely imaginary eigenvalues ±i + t = Nx.tensor([[0.0, -1.0], [1.0, 0.0]]) + + assert {eigenvals, eigenvecs} = Nx.LinAlg.eig(t) + + # Both eigenvalues should have magnitude 1 + assert_all_close(Nx.abs(eigenvals), Nx.tensor([1.0, 1.0]), atol: 1.0e-3) + + # Verify they are complex conjugates (imaginary parts should sum to ~0) + assert_all_close(Nx.sum(Nx.imag(eigenvals)), Nx.tensor(0.0), atol: 1.0e-3) + + assert_all_close( + Nx.dot(t, eigenvecs), + Nx.dot(eigenvecs, Nx.make_diagonal(eigenvals)), + atol: 1.0e-2 + ) + end + + test "works with batched matrices" do + t = + Nx.tensor([ + [[1.0, 0.0], [0.0, 2.0]], + [[3.0, 0.0], [0.0, 4.0]] + ]) + + assert {eigenvals, eigenvecs} = Nx.LinAlg.eig(t) + + # First batch: eigenvalues 2, 1 + assert_all_close(eigenvals[0], Nx.tensor([2.0, 1.0]), atol: 1.0e-3) + + # Second batch: eigenvalues 4, 3 + assert_all_close(eigenvals[1], Nx.tensor([4.0, 3.0]), atol: 1.0e-3) + + eigenvals = + eigenvals |> Nx.vectorize([:x]) |> Nx.make_diagonal() |> Nx.devectorize(keep_names: false) + + assert_all_close( + Nx.dot(t, [-1], [0], eigenvecs, [-2], [0]), + Nx.dot(eigenvecs, [-1], [0], eigenvals, [-2], [0]), + atol: 1.0e-3 + ) + end + + test "works with vectorized matrices" do + t = + Nx.tensor([ + [[[1.0, 0.0], [0.0, 2.0]]], + [[[3.0, 0.0], [0.0, 4.0]]] + ]) + |> Nx.vectorize(x: 2, y: 1) + + assert {eigenvals, eigenvecs} = Nx.LinAlg.eig(t) + + assert eigenvals.vectorized_axes == [x: 2, y: 1] + assert eigenvecs.vectorized_axes == [x: 2, y: 1] + + eigenvals = Nx.devectorize(eigenvals) + 
assert_all_close(Nx.abs(eigenvals[0][0]), Nx.tensor([2.0, 1.0]), atol: 1.0e-3) + assert_all_close(Nx.abs(eigenvals[1][0]), Nx.tensor([4.0, 3.0]), atol: 1.0e-3) + + # For diagonal matrices, eigenvectors should be orthonormal + eigenvecs_dev = Nx.devectorize(eigenvecs) + # Check that columns are unit vectors + for batch <- 0..1, col <- 0..1 do + v = eigenvecs_dev[batch][0][[.., col]] + norm = Nx.LinAlg.norm(v) |> Nx.to_number() + assert_in_delta(norm, 1.0, 0.1) + end + end + + test "property: eigenvalue equation A*v = λ*v" do + # For any matrix A and its eigenvalue λ with eigenvector v, + # the equation A*v = λ*v must hold + # Generate well-conditioned matrices A = Q*Λ*Q^(-1) where Λ has well-separated eigenvalues + key = Nx.Random.key(System.unique_integer()) + + for _ <- 1..5, type <- [{:f, 32}, {:f, 64}], reduce: key do + key -> + # Generate unitary matrix Q from random matrix via QR + {base_q, key} = Nx.Random.uniform(key, -2, 2, shape: {2, 3, 3}, type: type) + {q, _} = Nx.LinAlg.qr(base_q) + + # Generate well-separated eigenvalues (magnitudes: ~10, ~1, ~0.1) + evals_test = + [10, 1, 0.1] + |> Enum.map(fn magnitude -> + sign = if :rand.uniform() - 0.5 > 0, do: 1, else: -1 + rand = :rand.uniform() * magnitude * 0.1 + magnitude + rand * sign + end) + |> Nx.tensor(type: type) + + evals_test_diag = + evals_test + |> Nx.make_diagonal() + |> Nx.reshape({1, 3, 3}) + |> Nx.tile([2, 1, 1]) + + # Construct a well-conditioned normal matrix A = Q*Λ*Q^H + # Using Q^H (adjoint) ensures A is unitarily diagonalizable, which is + # the same conditioning strategy used in eigh tests. 
+ q_adj = Nx.LinAlg.adjoint(q) + + a = + q + |> Nx.dot([2], [0], evals_test_diag, [1], [0]) + |> Nx.dot([2], [0], q_adj, [1], [0]) + + assert {eigenvals, eigenvecs} = Nx.LinAlg.eig(a) + + evals = + eigenvals + |> Nx.vectorize(x: 2) + |> Nx.make_diagonal() + |> Nx.devectorize(keep_names: false) + + assert_all_close( + Nx.dot(eigenvecs, [-1], [0], evals, [-2], [0]), + Nx.dot(a, [-1], [0], eigenvecs, [-2], [0]), + atol: 1.0e-3 + ) + + key + end + end + + test "handles matrices with repeated eigenvalues" do + # Identity matrix has all eigenvalues equal to 1 + t = Nx.eye({3, 3}) + + assert {eigenvals, eigenvecs} = Nx.LinAlg.eig(t) + + # All eigenvalues should be 1 + assert_all_close(eigenvals, Nx.tensor([1, 1, 1]), atol: 1.0e-4) + + # For repeated eigenvalues, eigenvectors may not be orthonormal + # Just verify that each column has reasonable norm + for col <- 0..2 do + v = eigenvecs[[.., col]] + norm = Nx.LinAlg.norm(v) |> Nx.to_number() + assert_in_delta(norm, 1.0, 0.5) + end + end + + test "handles zero matrix" do + # Zero matrix has all eigenvalues equal to 0 + t = Nx.broadcast(0.0, {3, 3}) + + assert {eigenvals, eigenvecs} = Nx.LinAlg.eig(t) + + # All eigenvalues should be 0 + assert eigenvals == ~VEC[0.0+0.0i 0.0+0.0i 0.0+0.0i] + + # For zero matrix, eigenvectors are arbitrary + # Just verify that each column has reasonable norm + for col <- 0..2 do + v = eigenvecs[[.., col]] + norm = Nx.LinAlg.norm(v) |> Nx.to_number() + assert_in_delta(norm, 1.0, 0.5) + end + end + end + describe "svd" do test "finds the singular values of tall matrices" do t = Nx.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]) diff --git a/torchx/c_src/torchx.cpp b/torchx/c_src/torchx.cpp index 1b150da9c1..0d4684e2dc 100644 --- a/torchx/c_src/torchx.cpp +++ b/torchx/c_src/torchx.cpp @@ -1043,6 +1043,17 @@ eigh(ErlNifEnv *env, fine::ResourcePtr tensor) { REGISTER_TENSOR_NIF(eigh); +fine::Ok< + std::tuple, fine::ResourcePtr>> +eig(ErlNifEnv *env, fine::ResourcePtr 
tensor) { + auto result = torch::linalg_eig(get_tensor(tensor)); + return fine::Ok( + std::make_tuple(fine::make_resource(std::get<0>(result)), + fine::make_resource(std::get<1>(result)))); +} + +REGISTER_TENSOR_NIF(eig); + fine::Ok> solve(ErlNifEnv *env, fine::ResourcePtr tensorA, fine::ResourcePtr tensorB) { diff --git a/torchx/lib/torchx.ex b/torchx/lib/torchx.ex index 7e9fc61ff1..3f7cd7b49c 100644 --- a/torchx/lib/torchx.ex +++ b/torchx/lib/torchx.ex @@ -359,6 +359,7 @@ defmodule Torchx do deftensor cholesky(tensor) deftensor cholesky(tensor, upper) + deftensor eig(tensor) deftensor eigh(tensor) deftensor qr(tensor) deftensor qr(tensor, reduced) diff --git a/torchx/lib/torchx/backend.ex b/torchx/lib/torchx/backend.ex index 77308ce94d..0820bdfe13 100644 --- a/torchx/lib/torchx/backend.ex +++ b/torchx/lib/torchx/backend.ex @@ -1032,6 +1032,45 @@ defmodule Torchx.Backend do {to_nx(q, eigenvals), to_nx(r, eigenvecs)} end + @impl true + def eig({eigenvals, eigenvecs}, tensor, _opts) do + {vals_tx, vecs_tx} = + tensor + |> from_nx() + |> Torchx.eig() + + abs_type = to_torch_type(Nx.Type.to_real(eigenvals.type)) + + m = Nx.axis_size(eigenvecs, -2) + n = Nx.axis_size(eigenvecs, -1) + + sort_nx = + vals_tx + |> Torchx.abs() + |> Torchx.to_type(abs_type) + |> Torchx.to_nx() + |> Nx.argsort(axis: -1, direction: :desc) + |> Nx.revectorize([leading: :auto], target_shape: {n}) + + # Nx expects the eigenvalues and eigenvectors to be sorted + # We rely on vectorization so that we can use Nx.take/2 + # in a similar way to what the reference implementation for Nx does + + {vals_tx + |> Torchx.to_type(to_torch_type(eigenvals.type)) + |> to_nx(eigenvals) + |> Nx.revectorize([leading: :auto], target_shape: {n}) + |> Nx.take(sort_nx) + |> Nx.devectorize(keep_names: false) + |> Nx.revectorize([], target_shape: eigenvals.shape, target_names: eigenvals.names), + vecs_tx + |> Torchx.to_type(to_torch_type(eigenvecs.type)) + |> to_nx(eigenvecs) + |> Nx.revectorize([leading: :auto], 
target_shape: {m, n}) + |> Nx.take(sort_nx, axis: 1) + |> Nx.revectorize([], target_shape: eigenvecs.shape, target_names: eigenvecs.names)} + end + @impl true def qr({q_holder, r_holder}, tensor, opts) do {q, r} = diff --git a/torchx/mix.exs b/torchx/mix.exs index e6e88bd54b..76f52f9096 100644 --- a/torchx/mix.exs +++ b/torchx/mix.exs @@ -41,8 +41,8 @@ defmodule Torchx.MixProject do defp deps do [ - {:nx, "~> 0.10.0"}, - # {:nx, path: "../nx"}, + # Use the local Nx workspace for testing eig implementation + {:nx, path: "../nx"}, {:fine, "~> 0.1.0", runtime: false}, {:ex_doc, "~> 0.29", only: :docs} ] diff --git a/torchx/test/torchx/nx_doctest_test.exs b/torchx/test/torchx/nx_doctest_test.exs index a846948421..082b238691 100644 --- a/torchx/test/torchx/nx_doctest_test.exs +++ b/torchx/test/torchx/nx_doctest_test.exs @@ -26,13 +26,10 @@ defmodule Torchx.NxDoctestTest do standard_deviation: 2 ] - if Application.compile_env(:torchx, :is_apple_arm64) do - @os_rounding_error_doctests [sin: 1] - else - case :os.type() do - {:win32, _} -> @os_rounding_error_doctests [expm1: 1, erf: 1] - _ -> @os_rounding_error_doctests [] - end + case :os.type() do + {:win32, _} -> @os_rounding_error_doctests [expm1: 1, erf: 1] + {:unix, :darwin} -> @os_rounding_error_doctests [sin: 1, erf: 1] + _ -> @os_rounding_error_doctests [] end @unrelated_doctests [ diff --git a/torchx/test/torchx/nx_linalg_test.exs b/torchx/test/torchx/nx_linalg_test.exs index 8dfa491d82..a5f797bf49 100644 --- a/torchx/test/torchx/nx_linalg_test.exs +++ b/torchx/test/torchx/nx_linalg_test.exs @@ -335,6 +335,125 @@ defmodule Torchx.NxLinAlgTest do end end + describe "eig" do + test "computes eigenvalues and eigenvectors for diagonal matrix" do + t = Nx.tensor([[1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 3.0]]) + assert {eigenvals, _eigenvecs} = Nx.LinAlg.eig(t) + expected = Nx.tensor([3.0, 2.0, 1.0]) |> Nx.as_type({:c, 64}) + assert_all_close(Nx.abs(eigenvals), Nx.abs(expected), atol: 1.0e-2) + end + + test 
"computes eigenvalues and eigenvectors for upper triangular matrix" do + t = Nx.tensor([[1.0, 2.0, 3.0], [0.0, 4.0, 5.0], [0.0, 0.0, 6.0]]) + assert {eigenvals, eigenvecs} = Nx.LinAlg.eig(t) + expected = Nx.tensor([6.0, 4.0, 1.0]) + assert_all_close(Nx.abs(eigenvals), Nx.abs(expected), atol: 1.0e-2) + + assert_all_close(Nx.dot(t, eigenvecs), Nx.dot(eigenvecs, Nx.make_diagonal(eigenvals)), + atol: 1.0e-2 + ) + end + + test "computes eigenvalues and eigenvectors for lower triangular matrix" do + t = Nx.tensor([[1.0, 0.0, 0.0], [2.0, 3.0, 0.0], [4.0, 5.0, 6.0]]) + assert {eigenvals, eigenvecs} = Nx.LinAlg.eig(t) + expected = Nx.tensor([6.0, 3.0, 1.0]) + assert_all_close(Nx.abs(eigenvals), Nx.abs(expected), atol: 1.0e-2) + + assert_all_close(Nx.dot(t, eigenvecs), Nx.dot(eigenvecs, Nx.make_diagonal(eigenvals)), + atol: 1.0e-2 + ) + end + + test "computes complex eigenvalues for rotation matrix" do + t = Nx.tensor([[0.0, -1.0], [1.0, 0.0]]) + assert {eigenvals, _eigenvecs} = Nx.LinAlg.eig(t) + assert_all_close(Nx.abs(eigenvals), Nx.tensor([1.0, 1.0]), atol: 1.0e-3) + assert_all_close(Nx.sum(Nx.imag(eigenvals)), Nx.tensor(0.0), atol: 1.0e-3) + end + + test "works with batched matrices" do + t = Nx.tensor([[[1.0, 0.0], [0.0, 2.0]], [[3.0, 0.0], [0.0, 4.0]]]) + assert {eigenvals, _eigenvecs} = Nx.LinAlg.eig(t) + expected = Nx.tensor([[2.0, 1.0], [4.0, 3.0]]) + assert_all_close(Nx.abs(eigenvals), expected, atol: 1.0e-3) + end + + test "works with vectorized matrices" do + t = + Nx.tensor([ + [[[1.0, 0.0], [0.0, 2.0]]], + [[[3.0, 0.0], [0.0, 4.0]]] + ]) + |> Nx.vectorize(x: 2, y: 1) + + assert {eigenvals, eigenvecs} = Nx.LinAlg.eig(t) + assert eigenvals.vectorized_axes == [x: 2, y: 1] + assert eigenvecs.vectorized_axes == [x: 2, y: 1] + + eigenvals = Nx.devectorize(eigenvals) + assert_all_close(Nx.abs(eigenvals[0][0]), Nx.tensor([2.0, 1.0]), atol: 1.0e-3) + assert_all_close(Nx.abs(eigenvals[1][0]), Nx.tensor([4.0, 3.0]), atol: 1.0e-3) + + eigenvecs_dev = 
Nx.devectorize(eigenvecs) + + for batch <- 0..1, col <- 0..1 do + v = eigenvecs_dev[batch][0][[.., col]] + norm = Nx.LinAlg.norm(v) |> Nx.to_number() + assert_in_delta(norm, 1.0, 0.1) + end + end + + test "property: eigenvalue equation A*v = λ*v" do + key = Nx.Random.key(System.unique_integer()) + + for _ <- 1..3, type <- [{:f, 32}, {:f, 64}], reduce: key do + key -> + {base_q, key} = Nx.Random.uniform(key, -2, 2, shape: {2, 3, 3}, type: :f32) + + {q, _} = Nx.LinAlg.qr(base_q) + + evals_test = + [10, 1, 0.1] + |> Enum.map(fn magnitude -> + sign = if :rand.uniform() - 0.5 > 0, do: 1, else: -1 + rand = :rand.uniform() * magnitude * 0.1 + magnitude + rand * sign + end) + |> Nx.tensor(type: type) + + evals_test_diag = + evals_test + |> Nx.make_diagonal() + |> Nx.reshape({1, 3, 3}) + |> Nx.tile([2, 1, 1]) + + q_adj = Nx.LinAlg.adjoint(q) + + a = + q + |> Nx.dot([2], [0], evals_test_diag, [1], [0]) + |> Nx.dot([2], [0], q_adj, [1], [0]) + + assert {eigenvals, eigenvecs} = Nx.LinAlg.eig(a, balance: 0) + + evals = + eigenvals + |> Nx.vectorize(x: 2) + |> Nx.make_diagonal() + |> Nx.devectorize(keep_names: false) + + assert_all_close( + Nx.dot(eigenvecs, [-1], [0], evals, [-2], [0]), + Nx.dot(a, [-1], [0], eigenvecs, [-2], [0]), + atol: 1.0e-3 + ) + + key + end + end + end + defp random_uniform(shape, opts \\ [type: :f32]) do values = Enum.map(1..Tuple.product(shape), fn _ -> :rand.uniform() end)