ENH: extend power_iteration to accept a matrix in implicit form #858

Merged 1 commit on Mar 27, 2024
136 changes: 94 additions & 42 deletions optax/_src/linear_algebra.py
@@ -14,70 +14,122 @@
# ==============================================================================
"""Linear algebra utilities used in optimisation."""

import functools
from typing import Callable, Optional, Union

import chex
import jax
from jax import lax
import jax.numpy as jnp
import numpy as np

from optax import tree_utils as otu
from optax._src import base
from optax._src import numerics


def _normalize_tree(x):
# divide by the L2 norm of the tree weights.
return otu.tree_scalar_mul(1.0 / otu.tree_l2_norm(x), x)


def global_norm(updates: base.PyTree) -> chex.Array:
"""Compute the global norm across a nested structure of tensors."""
return jnp.sqrt(sum(
jnp.sum(numerics.abs_sq(x)) for x in jax.tree_util.tree_leaves(updates)))


def power_iteration(matrix: chex.Array,
num_iters: int = 100,
error_tolerance: float = 1e-6,
precision: lax.Precision = lax.Precision.HIGHEST):
def _power_iteration_cond_fun(error_tolerance, num_iters, loop_vars):
normalized_eigvec, unnormalized_eigvec, eig, iter_num = loop_vars
residual = otu.tree_sub(
unnormalized_eigvec, otu.tree_scalar_mul(eig, normalized_eigvec)
)
residual_norm = otu.tree_l2_norm(residual)
converged = jnp.abs(residual_norm / eig) < error_tolerance
return ~converged & (iter_num < num_iters)
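
For reviewers scanning the diff: the stopping rule above can be restated mathematically (this only restates the code, it adds no behaviour). With v_k the current normalized eigenvector estimate, z_k = A v_k the unnormalized iterate, and \lambda_k = v_k^T z_k the eigenvalue estimate, the loop keeps running while

\[
  \frac{\lVert z_k - \lambda_k v_k \rVert_2}{\lvert \lambda_k \rvert}
  \ge \texttt{error\_tolerance}
  \quad \text{and} \quad
  k < \texttt{num\_iters}.
\]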


def power_iteration(
matrix: Union[chex.Array, Callable[[chex.ArrayTree], chex.ArrayTree]],
*,
v0: Optional[chex.ArrayTree] = None,
num_iters: int = 100,
error_tolerance: float = 1e-6,
precision: lax.Precision = lax.Precision.HIGHEST,
key: Optional[chex.PRNGKey] = None,
) -> tuple[chex.Numeric, chex.ArrayTree]:
r"""Power iteration algorithm.

The power iteration algorithm takes a symmetric PSD matrix `A`, and produces
a scalar `\lambda` , which is the greatest (in absolute value) eigenvalue
of `A`, and a vector v, which is the corresponding eigenvector of `A`.
This algorithm computes the dominant eigenvalue and its associated eigenvector
of a diagonalizable matrix. This matrix can be given as an array or as a
callable implementing a matrix-vector product.

References:
[Wikipedia, 2021](https://en.wikipedia.org/wiki/Power_iteration)
Wikipedia contributors. `Power iteration
<https://en.wikipedia.org/w/index.php?title=Power_iteration>`_.

Args:
matrix: the symmetric PSD matrix.
num_iters: Number of iterations.
error_tolerance: Iterative exit condition.
precision: precision XLA related flag, the available options are:
a) lax.Precision.DEFAULT (better step time, but not precise);
b) lax.Precision.HIGH (increased precision, slower);
c) lax.Precision.HIGHEST (best possible precision, slowest).
matrix: a square matrix, either as an array or a callable implementing a
matrix-vector product.
v0: initial vector approximating the dominant eigenvector. If ``matrix``
is an array of size (n, n), v0 must be a vector of size (n,). If instead
``matrix`` is a callable, then v0 must be a tree with the same structure
as the input of this callable. If this argument is None and ``matrix`` is
an array, then a random vector sampled from a uniform distribution in
[-1, 1] is used as initial vector.
num_iters: Number of power iterations.
error_tolerance: Iterative exit condition. The procedure stops when the
relative error of the estimate of the dominant eigenvalue is below this
threshold.
precision: XLA precision flag, the available options are: a)
lax.Precision.DEFAULT (better step time, but not precise); b)
lax.Precision.HIGH (increased precision, slower); c) lax.Precision.HIGHEST
(best possible precision, slowest).
key: random key for the initialization of ``v0`` when not given
explicitly. When this argument is None, `jax.random.PRNGKey(0)` is used.

Returns:
eigen vector, eigen value
A pair (eigenvalue, eigenvector), where eigenvalue is the dominant
eigenvalue of ``matrix`` and eigenvector is its associated eigenvector.
"""
matrix_size = matrix.shape[-1]
def _iter_condition(state):
i, unused_v, unused_s, unused_s_v, run_step = state
return jnp.logical_and(i < num_iters, run_step)

def _iter_body(state):
"""One step of power iteration."""
i, new_v, s, s_v, unused_run_step = state
new_v = new_v / jnp.linalg.norm(new_v)

s_v = jnp.einsum('ij,j->i', matrix, new_v, precision=precision)
s_new = jnp.einsum('i,i->', new_v, s_v, precision=precision)
return (i + 1, s_v, s_new, s_v,
jnp.greater(jnp.abs(s_new - s), error_tolerance))

# Figure out how to use step as seed for random.
v_0 = np.random.uniform(-1.0, 1.0, matrix_size).astype(matrix.dtype)

init_state = tuple([0, v_0, jnp.zeros([], dtype=matrix.dtype), v_0, True])
_, v_out, s_out, _, _ = lax.while_loop(
_iter_condition, _iter_body, init_state)
v_out = v_out / jnp.linalg.norm(v_out)
return v_out, s_out
if callable(matrix):
mvp = matrix
if v0 is None:
# v0 must be given as we don't know the underlying pytree structure.
raise ValueError('v0 must be provided when `matrix` is a callable.')
else:
mvp = lambda v: jnp.matmul(matrix, v, precision=precision)
if v0 is None:
if key is None:
key = jax.random.PRNGKey(0)
# v0 is uniformly distributed in [-1, 1]
v0 = jax.random.uniform(
key,
shape=matrix.shape[-1:],
dtype=matrix.dtype,
minval=-1.0,
maxval=1.0,
)

v0 = _normalize_tree(v0)

cond_fun = functools.partial(
_power_iteration_cond_fun,
error_tolerance,
num_iters,
)

def _body_fun(loop_vars):
_, z, _, iter_num = loop_vars
eigvec = _normalize_tree(z)
z = mvp(eigvec)
eig = otu.tree_vdot(eigvec, z)
return eigvec, z, eig, iter_num + 1

init_vars = (v0, mvp(v0), jnp.asarray(0.0), jnp.asarray(0))
_, unnormalized_eigenvector, dominant_eigenvalue, _ = (
jax.lax.while_loop(cond_fun, _body_fun, init_vars)
)
normalized_eigenvector = _normalize_tree(unnormalized_eigenvector)
return dominant_eigenvalue, normalized_eigenvector
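
As a usage illustration of the extended interface (a sketch written for this review, not code from the PR; the 8x8 random matrix, seed, and tolerances are arbitrary choices), the same dominant eigenpair can be computed from a dense array or from a matrix-vector product callable:

import jax
import jax.numpy as jnp

from optax._src import linear_algebra

# Small symmetric PSD matrix, so power iteration converges to the top eigenpair.
a = jax.random.normal(jax.random.PRNGKey(0), (8, 8))
mat = a @ a.T

# Explicit form: v0 is optional and defaults to a uniform sample in [-1, 1].
eigval_dense, eigvec_dense = linear_algebra.power_iteration(mat)

# Implicit form: the matrix is only accessed through matrix-vector products,
# and v0 must be supplied because the pytree structure cannot be inferred.
eigval_mvp, eigvec_mvp = linear_algebra.power_iteration(
    lambda v: mat @ v, v0=jnp.ones((8,))
)

# Both estimates should agree up to a loose tolerance.
assert jnp.allclose(eigval_dense, eigval_mvp, atol=1e-2)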


def matrix_inverse_pth_root(matrix: chex.Array,
@@ -117,7 +169,7 @@ def matrix_inverse_pth_root(matrix: chex.Array,
matrix_size = matrix.shape[0]
alpha = jnp.asarray(-1.0 / p, jnp.float32)
identity = jnp.eye(matrix_size, dtype=jnp.float32)
_, max_ev = power_iteration(
max_ev, _ = power_iteration(
matrix=matrix, num_iters=100,
error_tolerance=1e-6, precision=precision)
ridge_epsilon = ridge_epsilon * jnp.maximum(max_ev, 1e-16)
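
Note that the hunk above also reflects the flipped return order: power_iteration now returns (eigenvalue, eigenvector) instead of (eigenvector, eigenvalue), hence `max_ev, _` where the old code unpacked `_, max_ev`. A hypothetical external call site (an illustration only, not taken from the PR) would be updated as follows:

import jax.numpy as jnp

from optax._src import linear_algebra

mat = jnp.diag(jnp.array([3.0, 1.0, 0.5]))  # dominant eigenvalue is 3.0

# Before this change: eigvec, eigval = linear_algebra.power_iteration(mat)
# After this change, the eigenvalue comes first:
eigval, eigvec = linear_algebra.power_iteration(mat)
assert abs(float(eigval) - 3.0) < 1e-3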
177 changes: 166 additions & 11 deletions optax/_src/linear_algebra_test.py
@@ -12,45 +12,200 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for optax._src.linear_algebra."""

from absl.testing import absltest
from typing import Iterable

from absl.testing import absltest
from absl.testing import parameterized
import chex
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from optax import tree_utils
from optax._src import linear_algebra
import scipy.stats


class LinearAlgebraTest(absltest.TestCase):
class MLP(nn.Module):
# Multi-layer perceptron (MLP).
num_outputs: int
hidden_sizes: Iterable[int]

@nn.compact
def __call__(self, x):
for num_hidden in self.hidden_sizes:
x = nn.Dense(num_hidden)(x)
x = nn.gelu(x)
return nn.Dense(self.num_outputs)(x)


class LinearAlgebraTest(chex.TestCase):

def test_global_norm(self):
flat_updates = jnp.array([2., 4., 3., 5.], dtype=jnp.float32)
flat_updates = jnp.array([2.0, 4.0, 3.0, 5.0], dtype=jnp.float32)
nested_updates = dict(
a=jnp.array([2., 4.], dtype=jnp.float32),
b=jnp.array([3., 5.], dtype=jnp.float32))
a=jnp.array([2.0, 4.0], dtype=jnp.float32),
b=jnp.array([3.0, 5.0], dtype=jnp.float32),
)
np.testing.assert_array_equal(
jnp.sqrt(jnp.sum(flat_updates**2)),
linear_algebra.global_norm(nested_updates))
linear_algebra.global_norm(nested_updates),
)

def test_power_iteration_cond_fun(self, dim=6):
"""Test the condition function for power iteration."""
matrix = jax.random.normal(jax.random.PRNGKey(0), (dim, dim))
matrix = matrix @ matrix.T
all_eigenval, all_eigenvec = jax.numpy.linalg.eigh(matrix)
dominant_eigenval = all_eigenval[-1]
dominant_eigenvec = all_eigenvec[:, -1] * jnp.sign(all_eigenvec[:, -1][0])
# loop variables for _power_iteration_cond_fun
loop_vars = (
dominant_eigenvec,
dominant_eigenval * dominant_eigenvec,
dominant_eigenval,
10,
)
# when given the correct dominant eigenvector, the condition function
# should stop and return False.
cond_fun_result = linear_algebra._power_iteration_cond_fun(
1e-3, 100, loop_vars
)
self.assertEqual(cond_fun_result, False)

@chex.all_variants
@parameterized.parameters(
dict(implicit=True),
dict(implicit=False),
)
def test_power_iteration(
self, implicit, dim=6, tol=1e-3, num_iters=100
):
"""Test power_iteration by comparing to numpy.linalg.eigh."""

if implicit:
# test the function when the matrix is given in implicit form by a
# matrix-vector product.
def power_iteration(matrix, *, v0):
return linear_algebra.power_iteration(
lambda x: matrix @ x,
v0=v0,
error_tolerance=tol,
num_iters=num_iters,
)
else:
power_iteration = linear_algebra.power_iteration

# test this function with/without jax.jit and on different devices
power_iteration = self.variant(power_iteration)

# create a random PSD matrix
matrix = jax.random.normal(jax.random.PRNGKey(0), (dim, dim))
matrix = matrix @ matrix.T
v0 = jnp.ones((dim,))

eigval_power, eigvec_power = power_iteration(matrix, v0=v0)
all_eigenval, all_eigenvec = jax.numpy.linalg.eigh(matrix)

self.assertAlmostEqual(eigval_power, all_eigenval[-1], delta=10 * tol)
np.testing.assert_array_almost_equal(
all_eigenvec[:, -1] * jnp.sign(all_eigenvec[:, -1][0]),
eigvec_power * jnp.sign(eigvec_power[0]),
decimal=3,
)

@chex.all_variants
def test_power_iteration_pytree(
self, dim=6, tol=1e-3, num_iters=100
):
"""Test power_iteration for matrix-vector products acting on pytrees."""

def matrix_vector_product(x):
# implements a block-diagonal matrix where each block is a scaled
# identity matrix. The scaling factor is 2 and 1 for the first and second
# block, respectively.
return {'a': 2 * x['a'], 'b': x['b']}

@self.variant
def power_iteration(*, v0):
return linear_algebra.power_iteration(
matrix_vector_product,
v0=v0,
error_tolerance=tol,
num_iters=num_iters,
)

v0 = {'a': jnp.ones((dim,)), 'b': jnp.ones((dim,))}

eigval_power, _ = power_iteration(v0=v0)

# From the block-diagonal structure of the matrix, the largest eigenvalue is 2.
self.assertAlmostEqual(eigval_power, 2., delta=10 * tol)

@chex.all_variants
def test_power_iteration_mlp_hessian(
self, input_dim=16, output_dim=4, tol=1e-3
):
"""Test power_iteration on the Hessian of an MLP."""
mlp = MLP(num_outputs=output_dim, hidden_sizes=[input_dim, 8, output_dim])
key = jax.random.PRNGKey(0)
key_params, key_input, key_output = jax.random.split(key, 3)
# initialize the mlp
params = mlp.init(key_params, jnp.ones(input_dim))
x = jax.random.normal(key_input, (input_dim,))
y = jax.random.normal(key_output, (output_dim,))

@self.variant
def train_obj(params_):
z = mlp.apply(params_, x)
return jnp.sum((z - y) ** 2)

def hessian_vector_product(tangents_):
return jax.jvp(jax.grad(train_obj), (params,), (tangents_,))[1]

eigval_power, eigvec_power = linear_algebra.power_iteration(
hessian_vector_product, v0=tree_utils.tree_ones_like(params)
)

params_flat, unravel = jax.flatten_util.ravel_pytree(params)
eigvec_power_flat, _ = jax.flatten_util.ravel_pytree(eigvec_power)

def train_obj_flat(params_flat_):
params_ = unravel(params_flat_)
return train_obj(params_)

hessian = jax.hessian(train_obj_flat)(params_flat)
all_eigenval, all_eigenvec = jax.numpy.linalg.eigh(hessian)

self.assertAlmostEqual(all_eigenval[-1], eigval_power, delta=10 * tol)
np.testing.assert_array_almost_equal(
all_eigenvec[:, -1] * jnp.sign(all_eigenvec[:, -1][0]),
eigvec_power_flat * jnp.sign(eigvec_power_flat[0]),
decimal=3,
)

def test_matrix_inverse_pth_root(self):
"""Test for matrix inverse pth root."""

def _gen_symmetric_matrix(dim, condition_number):
u = scipy.stats.ortho_group.rvs(dim=dim).astype(np.float64)
v = u.T
diag = np.diag([condition_number ** (-i/(dim-1)) for i in range(dim)])
diag = np.diag([condition_number ** (-i / (dim - 1)) for i in range(dim)])
return u @ diag @ v

# Fails after it reaches a particular condition number.
for e in range(2, 12):
condition_number = 10 ** e
condition_number = 10**e
ms = _gen_symmetric_matrix(16, condition_number)
self.assertLess(
np.abs(np.linalg.cond(ms) - condition_number),
condition_number * 0.01)
np.abs(np.linalg.cond(ms) - condition_number), condition_number * 0.01
)
error = linear_algebra.matrix_inverse_pth_root(
ms.astype(np.float32), 4, ridge_epsilon=1e-12)[1]
ms.astype(np.float32), 4, ridge_epsilon=1e-12
)[1]
if e < 7:
self.assertLess(error, 0.1)
else: