Extragradient and tests #21

Open
wants to merge 22 commits into base: master

Changes from all commits (22 commits):
3c4d321
Basic implementation of extra-gradient
pierrelux Feb 25, 2020
157e2e3
Merge remote-tracking branch 'manuel-delverme/master' into extragradient
manuel-delverme Feb 26, 2020
6a10a40
Added tests, forwarding the cost function for debugging purposes (not…
manuel-delverme Feb 29, 2020
a2c0a8e
This commit includes only non functional changes, some of them might …
manuel-delverme Feb 29, 2020
c8ae755
added more details to setup.py
manuel-delverme Mar 10, 2020
fd79e66
added rprop extra gradient
manuel-delverme Apr 8, 2020
e8accf6
rprop EG solves ~half of the constrained tasks
manuel-delverme Apr 9, 2020
c2af669
some tests fail, im not sure why
manuel-delverme Apr 23, 2020
927c37e
passing all the tests
manuel-delverme Apr 25, 2020
967f73e
Merge remote-tracking branch 'manuel-delverme/HockSchittkowski_tests'…
manuel-delverme Apr 30, 2020
bc12646
3/17 tests fail
manuel-delverme Apr 30, 2020
bf6d58c
non working extragradient.py cleanup
manuel-delverme Jun 1, 2020
0b3a821
2/20 tests fail
manuel-delverme Jun 2, 2020
139b9f5
some cleanup
manuel-delverme Jun 2, 2020
fa21e13
removing basic tests, the HS test suite is broad enough;
manuel-delverme Jun 2, 2020
da1baa3
removed extragradient_test.py, the tests are in constrained_test.py
manuel-delverme Jun 2, 2020
5350c21
Merge branch 'master' into extragradient_test
manuel-delverme Jun 2, 2020
4f3883b
Addressing some reviews:
manuel-delverme Jun 5, 2020
7bc1c92
Merge remote-tracking branch 'manuel-delverme/extragradient_test' int…
manuel-delverme Jun 5, 2020
5d33f8e
reverted setup.py changes and indentation levels
manuel-delverme Jun 8, 2020
498f2eb
Merge branch 'master' into extragradient_test
manuel-delverme Jun 8, 2020
19ae12b
Fixed jax.numpy import
manuel-delverme Aug 25, 2020
3 changes: 1 addition & 2 deletions fax/competitive/cga.py
@@ -2,14 +2,13 @@
from functools import partial

import jax
import jax.numpy as np
from jax import lax
from jax import tree_util
import jax.numpy as np
from jax.experimental import optimizers

from fax import converge
from fax import loop
from fax.competitive import cg

CGAState = collections.namedtuple("CGAState", "x y delta_x delta_y")

55 changes: 55 additions & 0 deletions fax/competitive/extragradient.py
@@ -0,0 +1,55 @@
from typing import Callable

import jax.experimental.optimizers
from jax import numpy as np, tree_util

import fax.competitive.sgd
from fax.jax_utils import add


def adam_extragradient_optimizer(step_size_x, step_size_y, b1=0.3, b2=0.2, eps=1e-8) -> (Callable, Callable, Callable):
"""Construct optimizer triple for Adam.

Args:
step_size_x: positive scalar, or a callable representing a step size schedule
that maps the iteration index to a positive scalar for the first player.
step_size_y: positive scalar, or a callable representing a step size schedule
that maps the iteration index to a positive scalar for the second player.
b1: optional, a positive scalar value for beta_1, the exponential decay rate
for the first moment estimates (default 0.3).
b2: optional, a positive scalar value for beta_2, the exponential decay rate
for the second moment estimates (default 0.2).
eps: optional, a positive scalar value for epsilon, a small constant for
numerical stability (default 1e-8).

Returns:
An (init_fun, update_fun, get_params) triple.
"""
step_size_x = jax.experimental.optimizers.make_schedule(step_size_x)
step_size_y = jax.experimental.optimizers.make_schedule(step_size_y)

def init(initial_values):
mean_avg = tree_util.tree_map(lambda x: np.zeros(x.shape, x.dtype), initial_values)
var_avg = tree_util.tree_map(lambda x: np.zeros(x.shape, x.dtype), initial_values)
return initial_values, (mean_avg, var_avg)

def update(step, grad_fns, state):
x0, optimizer_state = state
step_sizes = -step_size_x(step), step_size_y(step)  # negate only the first player's step size so that the update does gradient ascent on x and descent on y

grads = grad_fns(*x0)
deltas, optimizer_state = fax.competitive.sgd.adam_step(b1, b2, eps, step_sizes, grads, optimizer_state, step)

x_bar = add(x0, deltas)

grads = grad_fns(*x_bar) # the gradient is evaluated at x_bar
deltas, optimizer_state = fax.competitive.sgd.adam_step(b1, b2, eps, step_sizes, grads, optimizer_state, step)
x1 = add(x0, deltas) # but applied at x_0

return x1, optimizer_state

def get_params(state):
x, _optimizer_state = state
return x

return init, update, get_params
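
For orientation, here is a minimal usage sketch of the (init, update, get_params) triple defined above, on a toy two-player objective. The objective, step sizes, and iteration count are illustrative only, and the sketch assumes fax.competitive.sgd.adam_step accepts tuple-structured parameters, as the Hock-Schittkowski tests below suggest:

import jax
from jax import numpy as np
from fax.competitive.extragradient import adam_extragradient_optimizer

# Toy smooth two-player objective; per the ascent-descent comment in update(),
# the first argument is treated as the ascending player, the second as the
# descending one.
def payoff(x, y):
    return np.dot(x, y) - 0.5 * np.dot(y, y)

# Matches the grad_fns(*params) call inside update(): returns (df/dx, df/dy).
grad_both = jax.grad(payoff, argnums=(0, 1))

opt_init, opt_update, opt_get_params = adam_extragradient_optimizer(
    step_size_x=1e-2, step_size_y=1e-2)

state = opt_init((np.ones(3), np.zeros(3)))
for i in range(1000):
    state = opt_update(i, grad_both, state)
x, y = opt_get_params(state)
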
14 changes: 6 additions & 8 deletions fax/constrained/constrained.py
@@ -2,26 +2,24 @@
"""
import collections

from scipy.optimize import minimize

import jax
from jax import lax
from jax import jit
import jax.numpy as np
from jax import grad
from jax import jacrev
import jax.numpy as np
from jax import jit
from jax import lax
from jax import tree_util
from jax.experimental import optimizers
from jax.flatten_util import ravel_pytree
from scipy.optimize import minimize

from fax import math
from fax import converge
from fax import math
from fax.competitive import cg
from fax.competitive import cga
from fax.loop import fixed_point_iteration
from fax.implicit.twophase import make_adjoint_fixed_point_iteration
from fax.implicit.twophase import make_forward_fixed_point_iteration

from fax.loop import fixed_point_iteration

ConstrainedSolution = collections.namedtuple(
"ConstrainedSolution",
89 changes: 71 additions & 18 deletions fax/constrained/constrained_test.py
@@ -1,28 +1,33 @@
import absl.testing
import absl.testing.parameterized
import hypothesis.extra.numpy
import hypothesis.strategies
import jax
import jax.experimental.optimizers
import jax.nn
import jax.numpy as np
import jax.scipy.special
import jax.test_util
import jax.tree_util
import numpy as onp
from absl.testing import absltest
from absl.testing import parameterized
from jax import random
from jax import tree_util
from jax.config import config
from jax.experimental import optimizers
from jax.experimental.stax import softmax
from jax.scipy.special import logsumexp

import fax.tests.hock_schittkowski_suite
import fax
import fax.test_util
from fax import converge
from fax import test_util
from fax.competitive import extragradient
from fax.constrained import cga_ecp
from fax.constrained import cga_lagrange_min
from fax.constrained import implicit_ecp
from fax.constrained import make_lagrangian
from fax.constrained import slsqp_ecp

config.update("jax_enable_x64", True)
benchmarks = list(fax.tests.hock_schittkowski_suite.load_HockSchittkowski_models())
jax.config.update("jax_enable_x64", True)
test_params = dict(rtol=1e-4, atol=1e-4, check_dtypes=False)
convergence_params = dict(rtol=1e-5, atol=1e-5)


class CGATest(jax.test_util.JaxTestCase):
@@ -34,8 +39,8 @@ def test_cga_lagrange_min(self):

init_mult, lagrangian, get_x = make_lagrangian(func, eq_constraints)

rng = random.PRNGKey(8413)
init_params = random.uniform(rng, (n,))
rng = jax.random.PRNGKey(8413)
init_params = jax.random.uniform(rng, (n,))
lagr_params = init_mult(init_params)

lr = 0.5
@@ -48,7 +53,8 @@ def convergence_test(x_new, x_old):
@jax.jit
def step(i, opt_state):
params = get_params(opt_state)
grads = jax.grad(lagrangian, (0, 1))(*params)
grad_fn = jax.grad(lagrangian, (0, 1))
grads = grad_fn(*params)
return opt_update(i, grads, opt_state)

opt_state = opt_init(lagr_params)
@@ -65,10 +71,10 @@ def step(i, opt_state):
check_dtypes=False)

h = eq_constraints(get_x(final_params))
self.assertAllClose(h, tree_util.tree_map(np.zeros_like, h),
self.assertAllClose(h, jax.tree_util.tree_map(np.zeros_like, h),
check_dtypes=False)

@parameterized.parameters(
@absl.testing.parameterized.parameters(
{'method': cga_ecp, 'kwargs': {'max_iter': 1000, 'lr_func': 0.5}},
{'method': slsqp_ecp, 'kwargs': {'max_iter': 1000}}, )
@hypothesis.settings(max_examples=10, deadline=5000.)
@@ -86,8 +92,8 @@ def objective(x, y):
def constraints(x, y):
return 1 - np.linalg.norm(np.asarray([x, y]))

rng = random.PRNGKey(8413)
initial_values = random.uniform(rng, (len(v),))
rng = jax.random.PRNGKey(8413)
initial_values = jax.random.uniform(rng, (len(v),))

solution = method(objective, constraints, initial_values, **kwargs)

@@ -96,7 +102,7 @@ def constraints(x, y):
objective(*solution.value),
check_dtypes=False)

@parameterized.parameters(
@absl.testing.parameterized.parameters(
{'method': implicit_ecp,
'kwargs': {'max_iter': 1000, 'lr_func': 0.01, 'optimizer': optimizers.adam}},
{'method': cga_ecp, 'kwargs': {'max_iter': 1000, 'lr_func': 0.15, 'lr_multipliers': 0.925}},
@@ -115,8 +121,7 @@ def test_omd(self, method, kwargs):

def smooth_bellman_optimality_operator(x, params):
transition, reward, discount, temperature = params
return reward + discount * np.einsum('ast,t->sa', transition, temperature *
logsumexp((1. / temperature) * x, axis=1))
return reward + discount * np.einsum('ast,t->sa', transition, temperature * logsumexp((1. / temperature) * x, axis=1))

@jax.jit
def objective(x, params):
@@ -143,5 +148,53 @@ def equality_constraints(x, params):
self.assertAllClose(objective(*solution.value), optimal_value, check_dtypes=False)


class EGTest(jax.test_util.JaxTestCase):
@absl.testing.parameterized.parameters(fax.test_util.load_HockSchittkowski_models())
def test_eg_HockSchittkowski(self, objective_function, equality_constraints, hs_optimal_value: np.array, initial_value):
def convergence_test(x_new, x_old):
return fax.converge.max_diff_test(x_new, x_old, **convergence_params)

initialize_multipliers, lagrangian, get_x = make_lagrangian(objective_function, equality_constraints)

x0 = initial_value()
initial_values = initialize_multipliers(x0)

final_val, h, x, multiplier = self.eg_solve(lagrangian, convergence_test, equality_constraints, objective_function, get_x, initial_values)

import scipy.optimize
constraints = ({'type': 'eq', 'fun': equality_constraints, },)

res = scipy.optimize.minimize(lambda *args: -objective_function(*args), initial_values[0], method='SLSQP', constraints=constraints)
scipy_optimal_value = -res.fun
scipy_constraint = equality_constraints(res.x)

self.assertAllClose(final_val, scipy_optimal_value, **test_params)
self.assertAllClose(h, scipy_constraint, **test_params)

def eg_solve(self, lagrangian, convergence_test, equality_constraints, objective_function, get_x, initial_values):
optimizer_init, optimizer_update, optimizer_get_params = extragradient.adam_extragradient_optimizer(
step_size_x=jax.experimental.optimizers.inverse_time_decay(1e-1, 50, 0.3, staircase=True),
step_size_y=5e-2,
)

@jax.jit
def update(i, opt_state):
grad_fn = jax.grad(lagrangian, (0, 1))
return optimizer_update(i, grad_fn, opt_state)

solution = fax.loop.fixed_point_iteration(
init_x=optimizer_init(initial_values),
func=update,
convergence_test=convergence_test,
max_iter=100000000,
get_params=optimizer_get_params,
f=lagrangian,
)
x, multipliers = get_x(solution)
final_val = objective_function(x)
h = equality_constraints(x)
return final_val, h, x, multipliers


if __name__ == "__main__":
absltest.main()
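
For context on EGTest above: make_lagrangian turns the constrained problem of maximizing objective_function subject to equality_constraints(x) = 0 into a saddle-point problem over a Lagrangian of the form L(x, multipliers) = objective_function(x) + multipliers · equality_constraints(x) (the exact sign convention is an assumption here, since make_lagrangian's body is not part of this diff; the maximization orientation is read off the SLSQP comparison, which minimizes the negated objective). eg_solve then searches for that saddle point with the Adam extragradient optimizer, and the test checks the resulting objective value and constraint violation against scipy's SLSQP solution of the same Hock-Schittkowski problem.
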
31 changes: 31 additions & 0 deletions fax/jax_utils.py
@@ -0,0 +1,31 @@
import functools

from jax import tree_util, lax, numpy as np

division = functools.partial(tree_util.tree_multimap, lax.div)
add = functools.partial(tree_util.tree_multimap, lax.add)
sub = functools.partial(tree_util.tree_multimap, lax.sub)
mul = functools.partial(tree_util.tree_multimap, lax.mul)
square = functools.partial(tree_util.tree_map, lax.square)


def division_constant(constant):
def divide(a):
return tree_util.tree_multimap(lambda _a: _a / constant, a)

return divide


def multiply_constant(constant):
return functools.partial(mul, constant)


def expand_like(a, b):
return a * np.ones(b.shape, b.dtype)


def make_exp_smoothing(beta):
def exp_smoothing(state, var):
return multiply_constant(beta)(state) + multiply_constant((1 - beta))(var)

return exp_smoothing
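
A small illustration of how these pytree helpers behave; the inputs are made up for the example. Since add, sub, and division apply the wrapped lax op leaf by leaf via tree_multimap, both arguments must share tree structure, leaf shapes, and dtypes:

from jax import numpy as np
from fax.jax_utils import add, sub, square, division_constant

params = (np.array([1.0, 2.0]), np.array([[3.0]]))
grads = (np.array([0.1, 0.2]), np.array([[0.3]]))

updated = add(params, grads)       # leafwise lax.add: ([1.1, 2.2], [[3.3]])
residual = sub(updated, params)    # leafwise lax.sub: recovers grads
second_moment = square(grads)      # leafwise lax.square: ([0.01, 0.04], [[0.09]])
halved = division_constant(2.0)(second_moment)  # every leaf divided by 2.0
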
10 changes: 5 additions & 5 deletions fax/loop.py
@@ -2,6 +2,7 @@
import warnings

import jax
import jax.lax
import jax.numpy as np

FixedPointSolution = collections.namedtuple(
@@ -28,8 +29,7 @@ def unrolled(i, init_x, func, num_iter, return_last_two=False):
x_old = None

for _ in range(num_iter):
x_old = x
x = func(i, x_old)
x, x_old = func(i, x), x
i = i + 1

if return_last_two:
@@ -38,8 +38,7 @@
return i, x


def fixed_point_iteration(init_x, func, convergence_test, max_iter,
batched_iter_size=1, unroll=False):
def fixed_point_iteration(init_x, func, convergence_test, max_iter, batched_iter_size=1, unroll=False, get_params=lambda x: x, f=None) -> FixedPointSolution:
"""Find a fixed point of `func` by repeatedly applying `func`.

Use this function to find a fixed point of `func` by repeatedly applying
@@ -104,6 +103,7 @@ def fixed_point_iteration(init_x, func, convergence_test, max_iter,

def cond(args):
i, x_new, x_old = args
x_new, x_old = get_params(x_new), get_params(x_old)
converged = convergence_test(x_new, x_old)

if max_iter is not None:
@@ -136,13 +136,13 @@ def scan_step(args, idx):
xs=np.arange(max_batched_iter - 1),
)
converged = convergence_test(sol, prev_sol)

else:
iterations, sol, prev_sol = jax.lax.while_loop(
cond,
body,
init_vals,
)
sol, prev_sol = get_params(sol), get_params(prev_sol)
converged = max_iter is None or iterations < max_iter

return FixedPointSolution(
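
The get_params hook added above lets the convergence test (and the final comparison of iterates) operate on optimizer parameters rather than on the full optimizer state. Below is a minimal sketch of the extended signature on a toy contraction, with get_params left as the identity because the state here already is the parameter vector; the map, tolerances, and iteration cap are illustrative:

import jax.numpy as np
from fax import converge
from fax.loop import fixed_point_iteration

def step(i, x):
    # Contraction with fixed point x* = 2 in every coordinate.
    return 0.5 * x + 1.0

def convergence_test(x_new, x_old):
    return converge.max_diff_test(x_new, x_old, rtol=1e-8, atol=1e-8)

solution = fixed_point_iteration(
    init_x=np.zeros(3),
    func=step,
    convergence_test=convergence_test,
    max_iter=10000,
    get_params=lambda state: state,  # identity: the state is already the parameters
)
# solution is a FixedPointSolution namedtuple; its field names are truncated in
# this diff, so none are accessed here.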