
lax.custom_linear_solve primitive #1402

Merged · Oct 21, 2019 · 29 commits

Changes from 13 commits

Commits:
6888e4f  WIP: linear solvers (shoyer, Sep 25, 2019)
385e282  Merge branch 'solvers' into linear-solvers (shoyer, Sep 26, 2019)
189ca8c  Draft of lax.linear_solve (shoyer, Sep 26, 2019)
2479314  Merge branch 'master' into linear-solvers (shoyer, Sep 28, 2019)
5042dde  Refactor pytree munging inside lax.root. (shoyer, Sep 28, 2019)
42e09e5  Merge branch 'refactor-root' into linear-solvers (shoyer, Sep 28, 2019)
d3211a2  Fixup linear_solve (shoyer, Sep 29, 2019)
2281105  Linearize multiple times in _root_jvp to avoid zeros (shoyer, Sep 29, 2019)
c9cfdf0  Merge branch 'refactor-root' into linear-solvers (shoyer, Sep 29, 2019)
8a2ec33  fix deftraced (shoyer, Sep 29, 2019)
b1806c4  add a symmetric argument (shoyer, Sep 29, 2019)
5c9c8cc  Fixup float64; add a test for symmetric/non-symmetric (shoyer, Oct 2, 2019)
3e39769  test zeros in linear_solve_jvp (shoyer, Oct 2, 2019)
d5aaaad  Revisions per review (shoyer, Oct 3, 2019)
39c3f5c  Merge branch 'refactor-root' into linear-solvers (shoyer, Oct 3, 2019)
f5f3007  Adjust signature of linear_solve (shoyer, Oct 3, 2019)
709899b  Merge branch 'master' into linear-solvers (shoyer, Oct 3, 2019)
f98c1b3  restore botched test (shoyer, Oct 3, 2019)
57cc313  variable names (shoyer, Oct 3, 2019)
fab4f45  jaxprize (shoyer, Oct 10, 2019)
1de6200  spelling (shoyer, Oct 10, 2019)
8fa4ac6  Use np.dot instead of @ (shoyer, Oct 10, 2019)
2e0ccf7  linear_solve docstring, more tests (shoyer, Oct 10, 2019)
0e06627  Disable batching for root and linear_solve (shoyer, Oct 10, 2019)
eadb731  Fix linear_solve tests (shoyer, Oct 10, 2019)
edea7ce  remove unused imports (shoyer, Oct 10, 2019)
a276348  Rename to custom_linear_solve (shoyer, Oct 10, 2019)
30c2dc2  WIP: refactor (shoyer, Oct 10, 2019)
20169dc  fixup test for lazy transpose_solve error (shoyer, Oct 10, 2019)
6 changes: 6 additions & 0 deletions jax/interpreters/batching.py
@@ -26,6 +26,7 @@
 from six.moves import reduce

 from .. import core
+from .. import linear_util as lu
 from ..core import Trace, Tracer, new_master
 from ..abstract_arrays import ShapedArray, make_shaped_array, array_types, raise_to_shaped
 from ..ad_util import add_jaxvals, add_jaxvals_p, zeros_like_jaxval, zeros_like_p
@@ -172,6 +173,11 @@ def vectorized_batcher(prim, batched_args, batch_dims, **params):
   assert all(batch_dims[0] == bd for bd in batch_dims[1:]), batch_dims
   return prim.bind(*batched_args, **params), batch_dims[0]

+def deftraced(prim):
+  def batcher(args, dims, **params):
+    return batch_fun(lu.wrap_init(prim.impl, params), args, dims)
+  primitive_batchers[prim] = batcher
+
 def defbroadcasting(prim):
   primitive_batchers[prim] = partial(broadcast_batcher, prim)
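Note: deftraced installs a fallback batching rule that simply re-traces the primitive's Python implementation with batch_fun, rather than requiring a hand-written per-primitive rule. A rough standalone analogy (not JAX internals; impl here is an illustrative stand-in), with the caveat that a later commit in this PR (0e06627, "Disable batching for root and linear_solve") removes batching for these primitives again:

import jax
import jax.numpy as jnp

# Stand-in for a primitive's Python impl: any traceable function works.
def impl(x, y):
  return jnp.dot(x, y) + 1.0

# Batching by tracing: vmap traces impl and batches the primitives inside
# it, the same fallback strategy deftraced installs.
xs = jnp.ones((4, 3))
ys = jnp.arange(12.0).reshape(4, 3)
print(jax.vmap(impl)(xs, ys))  # shape (4,), one result per batch element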
230 changes: 177 additions & 53 deletions jax/lax/lax_control_flow.py
@@ -20,6 +20,7 @@
 from __future__ import division
 from __future__ import print_function

+import collections
 import itertools
 import operator
 import threading
@@ -850,6 +851,29 @@ def body(i, dst):
 masking.masking_rules[lax.concatenate_p] = _concat_masking_rule


+def _flatten_high_level_func(
+    fun, tree, error_template="Expected {}, got {}",
+):
+  """Flatten a higher level function ``f`` of the form ``f(g, x)``.
+
+  ``x``, ``g(x)`` and ``f(g, x)`` all must have the same pytree structure.
+  """
+  def flat_fun(flat_fun2, *args_flat):
+    args = tree_unflatten(tree, args_flat)
+    fun2 = partial(apply_flat_fun_nokwargs, flat_fun2, (tree, tree))
+    out = fun(fun2, args)
+    out_flat, out_tree = tree_flatten(out)
+    if out_tree != tree:
+      raise TypeError(error_template.format(tree, out_tree))
+    return out_flat
+  return flat_fun
+
+
+def _tree_error_template(func_name, input_name):
+  return (func_name + "() output pytree structure must match " + input_name
+          + ", got {} and {}.")
+
+
 def root(f, initial_guess, solve, tangent_solve):
   """Differentiably solve for a root of a function.

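Note: _flatten_high_level_func bridges between user-facing functions that take and return pytrees and the flat list-of-leaves calling convention used inside primitives. A minimal standalone sketch of that shuffle using public jax.tree_util helpers (flatten_caller is a hypothetical name, not part of the PR):

from jax import tree_util

def flatten_caller(fun, tree):
  """Wrap fun (pytree -> pytree) as a function on flat leaf lists."""
  def flat_fun(*leaves):
    out = fun(tree_util.tree_unflatten(tree, leaves))
    out_leaves, out_tree = tree_util.tree_flatten(out)
    if out_tree != tree:  # the same check the PR enforces via error_template
      raise TypeError("output pytree structure must match input")
    return out_leaves
  return flat_fun

x = {'a': 1.0, 'b': (2.0, 3.0)}
leaves, tree = tree_util.tree_flatten(x)
double = lambda t: tree_util.tree_map(lambda v: 2 * v, t)
print(flatten_caller(double, tree)(*leaves))  # [2.0, 4.0, 6.0]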
@@ -888,15 +912,21 @@ def root(f, initial_guess, solve, tangent_solve):
   guess_flat, in_args_tree = tree_flatten((initial_guess,))
   guess_avals = tuple(_map(_abstractify, guess_flat))
   jaxpr, consts, out_tree = _initial_style_jaxpr(f, in_args_tree, guess_avals)
+
   in_tree, = treedef_children(in_args_tree)
   if in_tree != out_tree:
     raise TypeError(
-        "f() output pytree structure must match initial_guess, got {} and {}."
-        .format(out_tree, in_tree)
+        _tree_error_template("f", "initial_guess").format(out_tree, in_tree)
     )
+
+  solve_flat = _flatten_high_level_func(
+      solve, in_tree, _tree_error_template("solve", "initial_guess"))
+  tangent_solve_flat = _flatten_high_level_func(
+      tangent_solve, in_tree, _tree_error_template("tangent_solve", "initial_guess"))
+
   out_flat = root_p.bind(*itertools.chain(consts, guess_flat),
-                         tree=out_tree, num_consts=len(consts),
-                         jaxpr=jaxpr, solve=solve, tangent_solve=tangent_solve)
+                         num_consts=len(consts), jaxpr=jaxpr, solve=solve_flat,
+                         tangent_solve=tangent_solve_flat)
   return tree_unflatten(out_tree, out_flat)


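Note: for orientation, a hedged usage sketch of the root() API at this commit (in later JAX releases this functionality became lax.custom_root). The Newton solver and the scalar tangent_solve below are illustrative choices, not part of the PR:

import jax
from jax import lax

def sqrt_via_root(c):
  f = lambda x: x ** 2 - c            # root at x = sqrt(c)

  def solve(f, x0):                   # forward solve: a few Newton steps
    x = x0
    for _ in range(10):
      fx, dfx = jax.value_and_grad(f)(x)
      x = x - fx / dfx
    return x

  def tangent_solve(g, y):            # g is linear and scalar: g(x) = g(1) * x
    return y / g(1.0)

  return lax.root(f, 1.0, solve, tangent_solve)

print(sqrt_via_root(2.0))                        # ≈ 1.41421
print(jax.jvp(sqrt_via_root, (2.0,), (1.0,)))    # tangent ≈ 0.35355, via the JVP rule below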
@@ -905,34 +935,17 @@ def _root_abstract_eval(*args, **kwargs):


 def _root_impl(*args, **kwargs):
-  tree, num_consts, jaxpr, solve, _ = split_dict(
-      kwargs, ['tree', 'num_consts', 'jaxpr', 'solve', 'tangent_solve'])
-
-  f = partial(
-      apply_flat_fun_nokwargs,
-      partial(core.jaxpr_as_fun(jaxpr), *args[:num_consts]),
-      (tree, tree),
-  )
-  initial_guess = tree_unflatten(tree, args[num_consts:])
-  out = solve(f, initial_guess)
-
-  out_flat, out_tree = tree_flatten(out)
-  if out_tree != tree:
-    raise TypeError(
-        "solve() output pytree structure must match initial_guess, got {} and {}"
-        .format(out_tree, tree))
-
-  return out_flat
+  num_consts, jaxpr, solve, _ = split_dict(
+      kwargs, ['num_consts', 'jaxpr', 'solve', 'tangent_solve'])
+  params, initial_guess = split_list(args, [num_consts])
+  f = partial(core.jaxpr_as_fun(jaxpr), *params)
+  return solve(f, *initial_guess)


-def _root_jvp(
-    primals, tangents, tree, num_consts, jaxpr, solve, tangent_solve):
+def _root_jvp(primals, tangents, num_consts, jaxpr, solve, tangent_solve):
   params = primals[:num_consts]
-  solution = tuple(
-      root_p.bind(*primals, tree=tree, num_consts=num_consts,
-                  jaxpr=jaxpr, solve=solve, tangent_solve=tangent_solve)
-  )
-
+  solution = tuple(root_p.bind(*primals, num_consts=num_consts, jaxpr=jaxpr,
+                               solve=solve, tangent_solve=tangent_solve))
   params_dot = tangents[:num_consts]

   # F(m, u) = 0      # system of equations in u, parameterized by m
@@ -944,31 +957,16 @@ def _root_jvp(
   #
   # ∂ u*(m)[v] = - (∂_1 F(m, u*(m)))^{-1} [∂_0 F(m, u*(m))[v]]  # jvp

-  unchecked_zeros, f_jvp = api.linearize(
-      core.jaxpr_as_fun(jaxpr), *(params + solution)
-  )
-
-  params_zeros = tuple(_map(ad_util.zeros_like_jaxval, params))
-  solution_zeros = tuple(_map(ad_util.zeros_like_jaxval, solution))
-
-  f_linearized_at_solution = partial(
-      apply_flat_fun_nokwargs, partial(f_jvp, *params_zeros), (tree, tree),
-  )
-  rhs = tree_unflatten(tree, f_jvp(*(params_dot + solution_zeros)))
-  solution_dot = tree_map(
-      operator.neg, tangent_solve(f_linearized_at_solution, rhs)
-  )
-
-  solution_dot_flat, out_tree = tree_flatten(solution_dot)
-  if out_tree != tree:
-    raise TypeError(
-        "tangent_solve() output pytree structure must match initial_guess, "
-        "got {} and {}".format(out_tree, tree))
+  f = core.jaxpr_as_fun(jaxpr)
+  f_fixed_params = lambda *solution: f(*(params + solution))
+  f_fixed_solution = lambda *params: f(*(params + solution))

-  return solution, solution_dot_flat
+  _, rhs = ad.jvp(lu.wrap_init(f_fixed_solution)).call_wrapped(params, params_dot)
+  # TODO(shoyer): use the lower-level ad.linearize here?
+  _, f_jvp_wrt_solution = api.linearize(f_fixed_params, *solution)
+  solution_dot = [-x for x in tangent_solve(f_jvp_wrt_solution, *rhs)]

-def _root_batch(args, dims, **params):
-  return batching.batch_fun(lu.wrap_init(_root_impl, params), args, dims)
+  return solution, solution_dot


 root_p = core.Primitive('root')
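Note: the comment block in _root_jvp is the implicit function theorem. A quick standalone check of its final formula with plain JAX, for F(m, u) = u² − m whose root is u*(m) = √m:

import jax
import jax.numpy as jnp

F = lambda m, u: u ** 2 - m
m = 2.0
u_star = jnp.sqrt(m)                       # what solve() would return

dF_dm = jax.grad(F, argnums=0)(m, u_star)  # ∂_0 F at (m, u*(m)) = -1
dF_du = jax.grad(F, argnums=1)(m, u_star)  # ∂_1 F at (m, u*(m)) = 2*sqrt(m)
u_dot = -dF_dm / dF_du                     # -(∂_1 F)^{-1} ∂_0 F
print(u_dot, 0.5 / jnp.sqrt(m))            # both ≈ 0.35355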
@@ -977,4 +975,130 @@ def _root_batch(args, dims, **params):
 root_p.def_abstract_eval(_root_abstract_eval)
 ad.primitive_jvps[root_p] = _root_jvp
 xla.initial_style_translations[root_p] = xla.lower_fun(_root_impl, initial_style=True)
-batching.primitive_batchers[root_p] = _root_batch
+batching.deftraced(root_p)
+
+
+_Solves = collections.namedtuple('_Solves', 'forward tangent cotangent')
+
+def linear_solve(matvec, b, forward_solve, tangent_solve=None,
+                 cotangent_solve=None, symmetric=False):
+  """Differentiably solve the linear map matvec(x)=b for x.
+
+  Required invariant:
+      x = solve(matvec, b)
+      error = matvec(x) - b
+      assert all(error == 0)
+  """
+  if tangent_solve is None:
+    tangent_solve = forward_solve
+  if cotangent_solve is None:
+    cotangent_solve = tangent_solve
+
+  b_flat, in_args_tree = tree_flatten((b,))
+  b_avals = tuple(_map(_abstractify, b_flat))
+  jaxpr, consts, out_tree = _initial_style_jaxpr(matvec, in_args_tree, b_avals)
+
+  in_tree, = treedef_children(in_args_tree)
+  if in_tree != out_tree:
+    raise TypeError(
+        _tree_error_template("matvec", "b").format(out_tree, in_tree)
+    )
+
+  solve = _Solves(
+      _flatten_high_level_func(
+          forward_solve, in_tree, _tree_error_template("forward_solve", "b")),
+      _flatten_high_level_func(
+          tangent_solve, in_tree, _tree_error_template("tangent_solve", "b")),
+      _flatten_high_level_func(
+          cotangent_solve, in_tree, _tree_error_template("cotangent_solve", "b")),
+  )
+
+  out_flat = linear_solve_p.bind(
+      *itertools.chain(consts, b_flat), num_consts=len(consts), jaxpr=jaxpr,
+      solve=solve, symmetric=symmetric)
+  return tree_unflatten(out_tree, out_flat)
+
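Note: a hedged usage sketch of linear_solve as defined at this commit (a later commit in this PR, a276348, renames it custom_linear_solve). The dense solver that closes over the matrix is an illustrative choice, not part of the PR:

import jax.numpy as jnp
from jax import lax

a = jnp.array([[2.0, 1.0], [1.0, 3.0]])
b = jnp.array([1.0, -1.0])

matvec = lambda x: jnp.dot(a, x)                   # the linear map A @ x
dense_solve = lambda f, b: jnp.linalg.solve(a, b)  # closes over a; ignores f

x = lax.linear_solve(matvec, b, dense_solve)
print(jnp.dot(a, x) - b)  # the docstring's invariant: ≈ [0., 0.]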
+
+def _linear_solve_abstract_eval(*args, **kwargs):
+  return args[kwargs['num_consts']:]
+
+
+def _linear_solve_impl(*args, **kwargs):
+  num_consts, jaxpr, solve, _ = split_dict(
+      kwargs, ['num_consts', 'jaxpr', 'solve', 'symmetric'])
+  params, b = split_list(args, [num_consts])
+  matvec = partial(core.jaxpr_as_fun(jaxpr), *params)
+  return solve.forward(matvec, *b)
+
+
+def _tangent_linear_map(func, params, params_dot, *x):
+  """Compute the tangent of a linear map.
+
+  Assuming ``func(*params, *x)`` is linear in ``x`` and computes ``A @ x``,
+  this function computes ``∂A @ x``.
+  """
+  zeros = [ad_util.zero] * len(x)
+  _, out_tangent = ad.jvp(lu.wrap_init(func)).call_wrapped(
+      params + list(x), params_dot + zeros)
+  return out_tangent
+
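Note: _tangent_linear_map relies on a basic JVP fact — for f(a, x) = a @ x, pushing a tangent through a with a zero tangent on x yields ∂a @ x. A standalone check:

import jax
import jax.numpy as jnp

f = lambda a, x: jnp.dot(a, x)              # linear in x

a = jnp.eye(2)
a_dot = jnp.array([[0.0, 1.0], [0.0, 0.0]])
x = jnp.array([1.0, 2.0])

_, out_tangent = jax.jvp(f, (a, x), (a_dot, jnp.zeros_like(x)))
print(out_tangent, jnp.dot(a_dot, x))       # identical: [2., 0.]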
+
+def _linear_solve_jvp(primals, tangents, num_consts, jaxpr, solve, symmetric):
+  # A x - b = 0
+  # ∂A x + A ∂x - ∂b = 0
+  # ∂x = A^{-1} (∂b - ∂A x)
+
+  x = linear_solve_p.bind(
+      *primals, num_consts=num_consts, jaxpr=jaxpr, solve=solve,
+      symmetric=symmetric)
+
+  params, _ = split_list(primals, [num_consts])
+  params_dot, b_dot = split_list(tangents, [num_consts])
+
+  matvec = partial(core.jaxpr_as_fun(jaxpr), *params)
+  matvec_dot = partial(
+      _tangent_linear_map, core.jaxpr_as_fun(jaxpr), params, params_dot
+  )
+
+  if all(tangent is ad_util.zero for tangent in params_dot):
+    rhs = b_dot
+  elif all(tangent is ad_util.zero for tangent in b_dot):
+    rhs = [-u for u in matvec_dot(*x)]
+  else:
+    rhs = [u - v for u, v in zip(b_dot, matvec_dot(*x))]
+  x_dot = solve.tangent(matvec, *rhs)
+  return x, x_dot
+
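Note: the three-line derivation in the comment above can be checked numerically against a dense solve standing in for the primitive; A and the sample values here are arbitrary:

import jax
import jax.numpy as jnp

A = lambda m: jnp.array([[m, 1.0], [0.0, 2.0]])
solve = lambda m, b: jnp.linalg.solve(A(m), b)    # x(m, b) = A(m)^{-1} b

m, b = 3.0, jnp.array([1.0, 1.0])
m_dot, b_dot = 1.0, jnp.array([0.0, 0.5])

x, x_dot = jax.jvp(solve, (m, b), (m_dot, b_dot))
A_dot = jax.jacfwd(A)(m) * m_dot                  # ∂A
expected = jnp.linalg.solve(A(m), b_dot - jnp.dot(A_dot, x))  # A^{-1}(∂b - ∂A x)
print(x_dot, expected)                            # agree up to float error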
+
+def _transpose_function(linear_fun, xs):
+  """Transpose a linear function."""
+  # TODO(shoyer): can we use something more direct than the vjp machinery?
+  # It's particularly awkward that we need the second argument to give
+  # particular values of the primals, which are entirely arbitrary.
+  _, transposed_fun = ad.vjp(lu.wrap_init(linear_fun), xs)
+  return transposed_fun
+
+
+def _linear_solve_transpose_rule(cotangent, *primals, **kwargs):
+  num_consts, jaxpr, solve, symmetric = split_dict(
+      kwargs, ['num_consts', 'jaxpr', 'solve', 'symmetric'])
+  params, b = split_list(primals, [num_consts])
+  assert b == [ad.undefined_primal] * len(b)
+  if symmetric:
+    vecmat = partial(core.jaxpr_as_fun(jaxpr), *params)
+  else:
+    vecmat = _transpose_function(
+        partial(core.jaxpr_as_fun(jaxpr), *params), cotangent)
+  cotangent_b = solve.cotangent(vecmat, *cotangent)
+  return [None] * num_consts + cotangent_b
+
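Note: why vjp serves as a transpose here — for a linear f(x) = A @ x, the VJP at any primal point maps y to Aᵀ y, which is exactly the transposed map (the "entirely arbitrary" primal the TODO mentions is the zeros vector below). A standalone check:

import jax
import jax.numpy as jnp

A = jnp.array([[1.0, 2.0], [3.0, 4.0]])
f = lambda x: jnp.dot(A, x)                # linear map

_, f_transpose = jax.vjp(f, jnp.zeros(2))  # primal value is irrelevant for linear f
y = jnp.array([1.0, -1.0])
print(f_transpose(y)[0], jnp.dot(A.T, y))  # identical: [-2., -2.]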
+
+linear_solve_p = core.Primitive('linear_solve')
+linear_solve_p.multiple_results = True
+linear_solve_p.def_impl(_linear_solve_impl)
+linear_solve_p.def_abstract_eval(_linear_solve_abstract_eval)
+ad.primitive_jvps[linear_solve_p] = _linear_solve_jvp
+xla.initial_style_translations[linear_solve_p] = xla.lower_fun(
+    _linear_solve_impl, initial_style=True)
+ad.primitive_transposes[linear_solve_p] = _linear_solve_transpose_rule
+batching.deftraced(linear_solve_p)