From 723cd03a56317dc2cf51a7dc49adbd68cde6b9dd Mon Sep 17 00:00:00 2001
From: Fabian Pedregosa
Date: Fri, 8 Mar 2024 07:10:49 -0800
Subject: [PATCH] ENH: extend power_iteration to accept a matrix in implicit
 form

* This also reverses the return order of power_iteration to (eigenvalue,
  eigenvector), which is the convention in numpy/scipy.

PiperOrigin-RevId: 613924237
---
 optax/_src/linear_algebra.py      | 136 ++++++++++++++++-------
 optax/_src/linear_algebra_test.py | 177 ++++++++++++++++++++++++++++--
 2 files changed, 260 insertions(+), 53 deletions(-)

diff --git a/optax/_src/linear_algebra.py b/optax/_src/linear_algebra.py
index 7bd64300a..8152eb1db 100644
--- a/optax/_src/linear_algebra.py
+++ b/optax/_src/linear_algebra.py
@@ -14,70 +14,122 @@
 # ==============================================================================
 """Linear algebra utilities used in optimisation."""
 
+import functools
+from typing import Callable, Optional, Union
+
 import chex
 import jax
 from jax import lax
 import jax.numpy as jnp
-import numpy as np
-
+from optax import tree_utils as otu
 from optax._src import base
 from optax._src import numerics
 
 
+def _normalize_tree(x):
+  # Divide by the L2 norm of the tree's leaves.
+  return otu.tree_scalar_mul(1.0 / otu.tree_l2_norm(x), x)
+
+
 def global_norm(updates: base.PyTree) -> chex.Array:
   """Compute the global norm across a nested structure of tensors."""
   return jnp.sqrt(sum(
       jnp.sum(numerics.abs_sq(x)) for x in jax.tree_util.tree_leaves(updates)))
 
 
-def power_iteration(matrix: chex.Array,
-                    num_iters: int = 100,
-                    error_tolerance: float = 1e-6,
-                    precision: lax.Precision = lax.Precision.HIGHEST):
+def _power_iteration_cond_fun(error_tolerance, num_iters, loop_vars):
+  normalized_eigvec, unnormalized_eigvec, eig, iter_num = loop_vars
+  residual = otu.tree_sub(
+      unnormalized_eigvec, otu.tree_scalar_mul(eig, normalized_eigvec)
+  )
+  residual_norm = otu.tree_l2_norm(residual)
+  converged = jnp.abs(residual_norm / eig) < error_tolerance
+  return ~converged & (iter_num < num_iters)
+
+
+def power_iteration(
+    matrix: Union[chex.Array, Callable[[chex.ArrayTree], chex.ArrayTree]],
+    *,
+    v0: Optional[chex.ArrayTree] = None,
+    num_iters: int = 100,
+    error_tolerance: float = 1e-6,
+    precision: lax.Precision = lax.Precision.HIGHEST,
+    key: Optional[chex.PRNGKey] = None,
+) -> tuple[chex.Numeric, chex.ArrayTree]:
   r"""Power iteration algorithm.
 
-  The power iteration algorithm takes a symmetric PSD matrix `A`, and produces
-  a scalar `\lambda` , which is the greatest (in absolute value) eigenvalue
-  of `A`, and a vector v, which is the corresponding eigenvector of `A`.
+  This algorithm computes the dominant eigenvalue and its associated
+  eigenvector of a diagonalizable matrix. The matrix can be given either as
+  an array or as a callable implementing a matrix-vector product.
 
   References:
-    [Wikipedia, 2021](https://en.wikipedia.org/wiki/Power_iteration)
+    Wikipedia contributors. `Power iteration
+    <https://en.wikipedia.org/wiki/Power_iteration>`_.
 
   Args:
-    matrix: the symmetric PSD matrix.
-    num_iters: Number of iterations.
-    error_tolerance: Iterative exit condition.
-    precision: precision XLA related flag, the available options are:
-      a) lax.Precision.DEFAULT (better step time, but not precise);
-      b) lax.Precision.HIGH (increased precision, slower);
-      c) lax.Precision.HIGHEST (best possible precision, slowest).
+    matrix: a square matrix, either as an array or a callable implementing a
+      matrix-vector product.
+    v0: initial vector approximating the dominant eigenvector.
+      If ``matrix`` is an array of size (n, n), v0 must be a vector of size
+      (n,). If instead ``matrix`` is a callable, then v0 must be a tree with
+      the same structure as the input of this callable. If this argument is
+      None and ``matrix`` is an array, then a random vector sampled from a
+      uniform distribution in [-1, 1] is used as initial vector.
+    num_iters: Number of power iterations.
+    error_tolerance: Iterative exit condition. The procedure stops when the
+      relative error of the estimate of the dominant eigenvalue is below this
+      threshold.
+    precision: XLA precision flag for the matrix-vector products; the
+      available options are:
+      a) lax.Precision.DEFAULT (better step time, but not precise);
+      b) lax.Precision.HIGH (increased precision, slower);
+      c) lax.Precision.HIGHEST (best possible precision, slowest).
+    key: random key for the initialization of ``v0`` when not given
+      explicitly. When this argument is None, `jax.random.PRNGKey(0)` is used.
 
   Returns:
-    eigen vector, eigen value
+    A pair (eigenvalue, eigenvector), where eigenvalue is the dominant
+    eigenvalue of ``matrix`` and eigenvector is its associated eigenvector.
   """
-  matrix_size = matrix.shape[-1]
-  def _iter_condition(state):
-    i, unused_v, unused_s, unused_s_v, run_step = state
-    return jnp.logical_and(i < num_iters, run_step)
-
-  def _iter_body(state):
-    """One step of power iteration."""
-    i, new_v, s, s_v, unused_run_step = state
-    new_v = new_v / jnp.linalg.norm(new_v)
-
-    s_v = jnp.einsum('ij,j->i', matrix, new_v, precision=precision)
-    s_new = jnp.einsum('i,i->', new_v, s_v, precision=precision)
-    return (i + 1, s_v, s_new, s_v,
-            jnp.greater(jnp.abs(s_new - s), error_tolerance))
-
-  # Figure out how to use step as seed for random.
-  v_0 = np.random.uniform(-1.0, 1.0, matrix_size).astype(matrix.dtype)
-
-  init_state = tuple([0, v_0, jnp.zeros([], dtype=matrix.dtype), v_0, True])
-  _, v_out, s_out, _, _ = lax.while_loop(
-      _iter_condition, _iter_body, init_state)
-  v_out = v_out / jnp.linalg.norm(v_out)
-  return v_out, s_out
+  if callable(matrix):
+    mvp = matrix
+    if v0 is None:
+      # v0 must be given as we don't know the underlying pytree structure.
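+      # For illustration (hypothetical caller code): for a Hessian-vector
+      # product acting on a parameter pytree ``params``, a caller might pass
+      #   power_iteration(hvp, v0=otu.tree_ones_like(params))
+      # or a random tree with the same structure as the operator's input.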
+      raise ValueError('v0 must be provided when `matrix` is a callable.')
+  else:
+    mvp = lambda v: jnp.matmul(matrix, v, precision=precision)
+    if v0 is None:
+      if key is None:
+        key = jax.random.PRNGKey(0)
+      # v0 is uniformly distributed in [-1, 1].
+      v0 = jax.random.uniform(
+          key,
+          shape=matrix.shape[-1:],
+          dtype=matrix.dtype,
+          minval=-1.0,
+          maxval=1.0,
+      )
+
+  v0 = _normalize_tree(v0)
+
+  cond_fun = functools.partial(
+      _power_iteration_cond_fun,
+      error_tolerance,
+      num_iters,
+  )
+
+  def _body_fun(loop_vars):
+    _, z, _, iter_num = loop_vars
+    eigvec = _normalize_tree(z)
+    z = mvp(eigvec)
+    eig = otu.tree_vdot(eigvec, z)
+    return eigvec, z, eig, iter_num + 1
+
+  init_vars = (v0, mvp(v0), jnp.asarray(0.0), jnp.asarray(0))
+  _, unnormalized_eigenvector, dominant_eigenvalue, _ = (
+      jax.lax.while_loop(cond_fun, _body_fun, init_vars)
+  )
+  normalized_eigenvector = _normalize_tree(unnormalized_eigenvector)
+  return dominant_eigenvalue, normalized_eigenvector
 
 
 def matrix_inverse_pth_root(matrix: chex.Array,
@@ -117,7 +169,7 @@ def matrix_inverse_pth_root(matrix: chex.Array,
   matrix_size = matrix.shape[0]
   alpha = jnp.asarray(-1.0 / p, jnp.float32)
   identity = jnp.eye(matrix_size, dtype=jnp.float32)
-  _, max_ev = power_iteration(
+  max_ev, _ = power_iteration(
       matrix=matrix, num_iters=100, error_tolerance=1e-6, precision=precision)
   ridge_epsilon = ridge_epsilon * jnp.maximum(max_ev, 1e-16)
 
diff --git a/optax/_src/linear_algebra_test.py b/optax/_src/linear_algebra_test.py
index 386a76ef3..8296057b1 100644
--- a/optax/_src/linear_algebra_test.py
+++ b/optax/_src/linear_algebra_test.py
@@ -12,26 +12,180 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+
 """Tests for optax._src.linear_algebra."""
 
-from absl.testing import absltest
+from typing import Iterable
 
+from absl.testing import absltest
+from absl.testing import parameterized
+import chex
+import flax.linen as nn
+import jax
 import jax.numpy as jnp
 import numpy as np
+from optax import tree_utils
 from optax._src import linear_algebra
 import scipy.stats
 
 
-class LinearAlgebraTest(absltest.TestCase):
+class MLP(nn.Module):
+  """Multi-layer perceptron (MLP)."""
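+  # Used by test_power_iteration_mlp_hessian below to build a non-trivial
+  # Hessian for power_iteration acting on a parameter pytree.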
+  num_outputs: int
+  hidden_sizes: Iterable[int]
+
+  @nn.compact
+  def __call__(self, x):
+    for num_hidden in self.hidden_sizes:
+      x = nn.Dense(num_hidden)(x)
+      x = nn.gelu(x)
+    return nn.Dense(self.num_outputs)(x)
+
+
+class LinearAlgebraTest(chex.TestCase):
 
   def test_global_norm(self):
-    flat_updates = jnp.array([2., 4., 3., 5.], dtype=jnp.float32)
+    flat_updates = jnp.array([2.0, 4.0, 3.0, 5.0], dtype=jnp.float32)
     nested_updates = dict(
-        a=jnp.array([2., 4.], dtype=jnp.float32),
-        b=jnp.array([3., 5.], dtype=jnp.float32))
+        a=jnp.array([2.0, 4.0], dtype=jnp.float32),
+        b=jnp.array([3.0, 5.0], dtype=jnp.float32),
+    )
     np.testing.assert_array_equal(
         jnp.sqrt(jnp.sum(flat_updates**2)),
-        linear_algebra.global_norm(nested_updates))
+        linear_algebra.global_norm(nested_updates),
+    )
+
+  def test_power_iteration_cond_fun(self, dim=6):
+    """Test the condition function for power iteration."""
+    matrix = jax.random.normal(jax.random.PRNGKey(0), (dim, dim))
+    matrix = matrix @ matrix.T
+    all_eigenval, all_eigenvec = jax.numpy.linalg.eigh(matrix)
+    dominant_eigenval = all_eigenval[-1]
+    dominant_eigenvec = all_eigenvec[:, -1] * jnp.sign(all_eigenvec[:, -1][0])
+    # loop variables for _power_iteration_cond_fun
+    loop_vars = (
+        dominant_eigenvec,
+        dominant_eigenval * dominant_eigenvec,
+        dominant_eigenval,
+        10,
+    )
+    # when given the correct dominant eigenvector, the condition function
+    # should stop and return False.
+    cond_fun_result = linear_algebra._power_iteration_cond_fun(
+        error_tolerance=1e-3, num_iters=100, loop_vars=loop_vars
+    )
+    self.assertEqual(cond_fun_result, False)
+
+  @chex.all_variants
+  @parameterized.parameters(
+      dict(implicit=True),
+      dict(implicit=False),
+  )
+  def test_power_iteration(
+      self, implicit, dim=6, tol=1e-3, num_iters=100
+  ):
+    """Test power_iteration by comparing to numpy.linalg.eigh."""
+
+    if implicit:
+      # test the function when the matrix is given in implicit form by a
+      # matrix-vector product.
+      def power_iteration(matrix, *, v0):
+        return linear_algebra.power_iteration(
+            lambda x: matrix @ x,
+            v0=v0,
+            error_tolerance=tol,
+            num_iters=num_iters,
+        )
+    else:
+      power_iteration = linear_algebra.power_iteration
+
+    # test this function with/without jax.jit and on different devices
+    power_iteration = self.variant(power_iteration)
+
+    # create a random PSD matrix
+    matrix = jax.random.normal(jax.random.PRNGKey(0), (dim, dim))
+    matrix = matrix @ matrix.T
+    v0 = jnp.ones((dim,))
+
+    eigval_power, eigvec_power = power_iteration(matrix, v0=v0)
+    all_eigenval, all_eigenvec = jax.numpy.linalg.eigh(matrix)
+
+    self.assertAlmostEqual(eigval_power, all_eigenval[-1], delta=10 * tol)
+    np.testing.assert_array_almost_equal(
+        all_eigenvec[:, -1] * jnp.sign(all_eigenvec[:, -1][0]),
+        eigvec_power * jnp.sign(eigvec_power[0]),
+        decimal=3,
+    )
+
+  @chex.all_variants
+  def test_power_iteration_pytree(
+      self, dim=6, tol=1e-3, num_iters=100
+  ):
+    """Test power_iteration for matrix-vector products acting on pytrees."""
+
+    def matrix_vector_product(x):
+      # Implements a block-diagonal matrix where each block is a scaled
+      # identity: the scaling factor is 2 for the first block and 1 for the
+      # second block.
+      return {'a': 2 * x['a'], 'b': x['b']}
+
+    @self.variant
+    def power_iteration(*, v0):
+      return linear_algebra.power_iteration(
+          matrix_vector_product,
+          v0=v0,
+          error_tolerance=tol,
+          num_iters=num_iters,
+      )
+
+    v0 = {'a': jnp.ones((dim,)), 'b': jnp.ones((dim,))}
+
+    eigval_power, _ = power_iteration(v0=v0)
+
+    # Given the block-diagonal structure of the operator, its largest
+    # eigenvalue is 2.
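+    # (It scales the 'a' block by 2 and leaves the 'b' block unchanged, so
+    # its eigenvalues are 2 and 1, each with multiplicity dim.)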
+    self.assertAlmostEqual(eigval_power, 2., delta=10 * tol)
+
+  @chex.all_variants
+  def test_power_iteration_mlp_hessian(
+      self, input_dim=16, output_dim=4, tol=1e-3
+  ):
+    """Test power_iteration on the Hessian of an MLP."""
+    mlp = MLP(num_outputs=output_dim, hidden_sizes=[input_dim, 8, output_dim])
+    key = jax.random.PRNGKey(0)
+    key_params, key_input, key_output = jax.random.split(key, 3)
+    # initialize the mlp
+    params = mlp.init(key_params, jnp.ones(input_dim))
+    x = jax.random.normal(key_input, (input_dim,))
+    y = jax.random.normal(key_output, (output_dim,))
+
+    @self.variant
+    def train_obj(params_):
+      z = mlp.apply(params_, x)
+      return jnp.sum((z - y) ** 2)
+
+    def hessian_vector_product(tangents_):
+      return jax.jvp(jax.grad(train_obj), (params,), (tangents_,))[1]
+
+    eigval_power, eigvec_power = linear_algebra.power_iteration(
+        hessian_vector_product, v0=tree_utils.tree_ones_like(params)
+    )
+
+    params_flat, unravel = jax.flatten_util.ravel_pytree(params)
+    eigvec_power_flat, _ = jax.flatten_util.ravel_pytree(eigvec_power)
+
+    def train_obj_flat(params_flat_):
+      params_ = unravel(params_flat_)
+      return train_obj(params_)
+
+    hessian = jax.hessian(train_obj_flat)(params_flat)
+    all_eigenval, all_eigenvec = jax.numpy.linalg.eigh(hessian)
+
+    self.assertAlmostEqual(all_eigenval[-1], eigval_power, delta=10 * tol)
+    np.testing.assert_array_almost_equal(
+        all_eigenvec[:, -1] * jnp.sign(all_eigenvec[:, -1][0]),
+        eigvec_power_flat * jnp.sign(eigvec_power_flat[0]),
+        decimal=3,
+    )
 
   def test_matrix_inverse_pth_root(self):
     """Test for matrix inverse pth root."""
@@ -39,18 +193,19 @@ def _gen_symmetrix_matrix(dim, condition_number):
       u = scipy.stats.ortho_group.rvs(dim=dim).astype(np.float64)
       v = u.T
-      diag = np.diag([condition_number ** (-i/(dim-1)) for i in range(dim)])
+      diag = np.diag([condition_number ** (-i / (dim - 1)) for i in range(dim)])
       return u @ diag @ v
 
     # Fails after it reaches a particular condition number.
     for e in range(2, 12):
-      condition_number = 10 ** e
+      condition_number = 10**e
       ms = _gen_symmetrix_matrix(16, condition_number)
       self.assertLess(
-          np.abs(np.linalg.cond(ms) - condition_number),
-          condition_number * 0.01)
+          np.abs(np.linalg.cond(ms) - condition_number), condition_number * 0.01
+      )
       error = linear_algebra.matrix_inverse_pth_root(
-          ms.astype(np.float32), 4, ridge_epsilon=1e-12)[1]
+          ms.astype(np.float32), 4, ridge_epsilon=1e-12
+      )[1]
       if e < 7:
         self.assertLess(error, 0.1)
       else: