Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

lax.custom_linear_solve primitive #1402

Merged
merged 29 commits into from
Oct 21, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
6888e4f
WIP: linear solvers
shoyer Sep 25, 2019
385e282
Merge branch 'solvers' into linear-solvers
shoyer Sep 26, 2019
189ca8c
Draft of lax.linear_solve
shoyer Sep 26, 2019
2479314
Merge branch 'master' into linear-solvers
shoyer Sep 28, 2019
5042dde
Refactor pytree munging inside lax.root.
shoyer Sep 28, 2019
42e09e5
Merge branch 'refactor-root' into linear-solvers
shoyer Sep 28, 2019
d3211a2
Fixup linear_solve
shoyer Sep 29, 2019
2281105
Linearize multiple times in _root_jvp to avoid zeros
shoyer Sep 29, 2019
c9cfdf0
Merge branch 'refactor-root' into linear-solvers
shoyer Sep 29, 2019
8a2ec33
fix deftraced
shoyer Sep 29, 2019
b1806c4
add a symmetric argument
shoyer Sep 29, 2019
5c9c8cc
Fixup float64; add a test for symmetric/non-symmetric
shoyer Oct 2, 2019
3e39769
test zeros in linear_solve_jvp
shoyer Oct 2, 2019
d5aaaad
Revisions per review
shoyer Oct 3, 2019
39c3f5c
Merge branch 'refactor-root' into linear-solvers
shoyer Oct 3, 2019
f5f3007
Adjust signature of linear_solve
shoyer Oct 3, 2019
709899b
Merge branch 'master' into linear-solvers
shoyer Oct 3, 2019
f98c1b3
restore botched test
shoyer Oct 3, 2019
57cc313
variable names
shoyer Oct 3, 2019
fab4f45
jaxprize
shoyer Oct 10, 2019
1de6200
spelling
shoyer Oct 10, 2019
8fa4ac6
Use np.dot instead of @
shoyer Oct 10, 2019
2e0ccf7
linear_solve docstring, more tests
shoyer Oct 10, 2019
0e06627
Disable batching for root and linear_solve
shoyer Oct 10, 2019
eadb731
Fix linear_solve tests
shoyer Oct 10, 2019
edea7ce
remove unused imports
shoyer Oct 10, 2019
a276348
Rename to custom_linear_solve
shoyer Oct 10, 2019
30c2dc2
WIP: refactor
shoyer Oct 10, 2019
20169dc
fixup test for lazy transpose_solve error
shoyer Oct 10, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
227 changes: 212 additions & 15 deletions jax/lax/lax_control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from __future__ import division
from __future__ import print_function

import collections
import itertools
import operator
import threading
Expand Down Expand Up @@ -851,7 +852,7 @@ def body(i, dst):


def _flatten_higher_order_func(
f, tree, error_template="Expected {}, got {}",
f, tree, func_name, input_name,
):
"""Flatten a higher order function ``f`` of the form ``f(g, x)``.

Expand All @@ -871,15 +872,17 @@ def flat_fun(flat_g, *args_flat):
g = partial(apply_flat_fun_nokwargs, flat_g, (tree, tree))
out = f(g, args)
out_flat, out_tree = tree_flatten(out)
if out_tree != tree:
raise TypeError(error_template.format(tree, out_tree))
_check_tree(func_name, input_name, out_tree, tree)
return out_flat
return flat_fun


def _root_tree_error_template(func_name):
return (func_name + "() output pytree structure must match initial_guess, "
+ "got {} and {}.")
def _check_tree(func_name, expected_name, actual_tree, expected_tree):
if actual_tree != expected_tree:
raise TypeError(
"{}() output pytree structure must match {}, got {} and {}."
.format(func_name, expected_name, actual_tree, expected_tree))



def root(f, initial_guess, solve, tangent_solve):
Expand Down Expand Up @@ -922,13 +925,12 @@ def root(f, initial_guess, solve, tangent_solve):
jaxpr, consts, out_tree = _initial_style_jaxpr(f, in_args_tree, guess_avals)

in_tree, = treedef_children(in_args_tree)
if in_tree != out_tree:
raise TypeError(_root_tree_error_template("f").format(out_tree, in_tree))
_check_tree("f", "initial_guess", out_tree, in_tree)

solve_flat = _flatten_higher_order_func(
solve, in_tree, _root_tree_error_template("solve"))
solve, in_tree, "solve", "initial_guess")
tangent_solve_flat = _flatten_higher_order_func(
tangent_solve, in_tree, _root_tree_error_template("tangent_solve"))
tangent_solve, in_tree, "tangent_solve", "initial_guess")

out_flat = root_p.bind(*itertools.chain(consts, guess_flat),
num_consts=len(consts), jaxpr=jaxpr, solve=solve_flat,
Expand Down Expand Up @@ -974,14 +976,209 @@ def _root_jvp(primals, tangents, num_consts, jaxpr, solve, tangent_solve):
return solution, solution_dot


def _root_batch(args, dims, **params):
return batching.batch_fun(lu.wrap_init(_root_impl, params), args, dims)


root_p = core.Primitive('root')
root_p.multiple_results = True
root_p.def_impl(_root_impl)
root_p.def_abstract_eval(_root_abstract_eval)
ad.primitive_jvps[root_p] = _root_jvp
xla.initial_style_translations[root_p] = xla.lower_fun(_root_impl, initial_style=True)
batching.primitive_batchers[root_p] = _root_batch
# TODO(shoyer): write batching rule


class _LinearSolveTuple(collections.namedtuple(
'_LinearSolveTuple', 'matvec, vecmat, solve, transpose_solve')):

def transpose(self):
return type(self)(self.vecmat, self.matvec, self.transpose_solve, self.solve)


def _split_linear_solve_args(args, const_lengths):
params_list = split_list(args, list(const_lengths))
return _LinearSolveTuple(*params_list[:-1]), params_list[-1]


def _transpose_function(linear_fun, primals):
"""Transpose a linear function."""
# TODO(shoyer): can we use something more direct than the vjp machinery?
# It's particularly awkward that we need the second argument to give
# particular values of the primals, which are entirely arbitrary.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea. It might be a new transformation, but one that shares a lot of machinery with the existing jvp/transpose.

_, vjp_fun = api.vjp(linear_fun, primals)

def transposed_fun(x):
(y,) = vjp_fun(x)
return y

return transposed_fun


def _flatten(args):
return [x for arg in args for x in arg]


def _check_shapes(func_name, expected_name, actual, expected, tree):
actual_shapes = _map(onp.shape, actual)
expected_shapes = _map(onp.shape, expected)
if actual_shapes != expected_shapes:
actual_shape_tree = tree_unflatten(tree, actual_shapes)
act_shape_tree = tree_unflatten(tree, actual_shapes)
raise ValueError('{}() output shapes must match {}, got {} and {}'
.format(func_name, expected_name,
tree_unflatten(tree, actual_shapes),
tree_unflatten(tree, expected_shapes)))


def custom_linear_solve(
matvec, b, solve, transpose_solve=None, symmetric=False):
"""Perform a matrix-free linear solve with implicitly defined gradients.

This function allows for overriding or defining gradients for a linear
solve directly via implicit differentiation at the solution, rather than by
differenting *through* the solve operation. This can sometimes be much faster
or more numerically stable, or differentiating through the solve operation
may not even be implemented (e.g., if ``solve`` using ``lax.while_loop``).

Required invariant:
x = solve(matvec, b) # solve the linear equation
assert matvec(x) == b # not checked

Args:
matvec: linear function to invert. Must be differentiable.
b: constant right handle side of the equation. May be any nested structure
of arrays.
solve: higher level function that solves for solution to the linear
equation, i.e., ``matvec(solve(matvec, x)) == x`` for all ``x`` of the
same form as ``b``. This function need not be differenatiable.
transpose_solve: higher level function for solving the transpose linear
equation, i.e., ``vecmat(transpose_solve(vecmat, x)) == x``, where
``vecmat`` is the transpose of the linear map ``matvec`` (computed
automatically with autodiff). Required for backwards mode automatic
differentiation, unless ``symmetric=True``, in which case ``solve``
provides the default value.
symmetric: bool indicating if it is safe to assume the linear map
corresponds to a symmetric matrix, i.e., ``matvec == vecmat``.

Returns:
Result of ``solve(matvec, b)``, with gradients defined assuming that the
solution ``x`` satisfies the linear equation ``matvec(x) == b``.
"""
if transpose_solve is None and symmetric:
transpose_solve = solve

b_flat, in_args_tree = tree_flatten((b,))
b_avals = tuple(_map(_abstractify, b_flat))
matvec_jaxpr, matvec_consts, out_tree = _initial_style_jaxpr(
matvec, in_args_tree, b_avals)

tree, = treedef_children(in_args_tree)
_check_tree("matvec", "b", out_tree, tree)

solve_jaxpr, solve_consts, out_tree = _initial_style_jaxpr(
partial(solve, matvec), in_args_tree, b_avals)
_check_tree("solve", "b", out_tree, tree)

if transpose_solve is None:
vecmat_jaxpr = tr_solve_jaxpr = None
vecmat_consts = tr_solve_consts = []
else:
if symmetric:
vecmat = matvec
vecmat_jaxpr = matvec_jaxpr
vecmat_consts = matvec_consts
else:
vecmat = _transpose_function(matvec, b)
vecmat_jaxpr, vecmat_consts, out_tree = _initial_style_jaxpr(
vecmat, in_args_tree, b_avals)
assert out_tree == tree

tr_solve_jaxpr, tr_solve_consts, out_tree = _initial_style_jaxpr(
partial(transpose_solve, vecmat), in_args_tree, b_avals)
_check_tree("transpose_solve", "b", out_tree, tree)

all_consts = [matvec_consts, vecmat_consts, solve_consts, tr_solve_consts]
const_lengths = _LinearSolveTuple(*_map(len, all_consts))
jaxprs = _LinearSolveTuple(
matvec_jaxpr, vecmat_jaxpr, solve_jaxpr, tr_solve_jaxpr)

out_flat = custom_linear_solve_p.bind(
*(_flatten(all_consts) + b_flat),
const_lengths=const_lengths, jaxprs=jaxprs, tree=tree)
return tree_unflatten(tree, out_flat)


def _custom_linear_solve_abstract_eval(*args, **kwargs):
return args[sum(kwargs['const_lengths']):]


def _custom_linear_solve_impl(*args, **kwargs):
const_lengths, jaxprs, tree = split_dict(
kwargs, ['const_lengths', 'jaxprs', 'tree'])
params, b = _split_linear_solve_args(args, const_lengths)
x = core.jaxpr_as_fun(jaxprs.solve)(*(params.solve + b))
_check_shapes('solve', 'b', x, b, tree)
return x


def _tangent_linear_map(func, params, params_dot, *x):
"""Compute the tangent of a linear map.

Assuming ``func(*params, *x)`` is linear in ``x`` and computes ``A @ x``,
this function computes ``∂A @ x``.
"""
assert any(p is not ad_util.zero for p in params_dot)
zeros = [ad_util.zero] * len(x)
_, out_tangent = ad.jvp(lu.wrap_init(func)).call_wrapped(
params + list(x), params_dot + zeros)
return out_tangent


def _custom_linear_solve_jvp(primals, tangents, const_lengths, jaxprs, tree):
# A x - b = 0
# ∂A x + A ∂x - ∂b = 0
# ∂x = A^{-1} (∂b - ∂A x)

kwargs = dict(const_lengths=const_lengths, jaxprs=jaxprs, tree=tree)
x = custom_linear_solve_p.bind(*primals, **kwargs)

params, _ = _split_linear_solve_args(primals, const_lengths)
params_dot, b_dot = _split_linear_solve_args(tangents, const_lengths)

if all(p is ad_util.zero for p in params_dot.matvec):
# no need to evaluate matvec_tangents
rhs = b_dot
else:
matvec_tangents = _tangent_linear_map(
core.jaxpr_as_fun(jaxprs.matvec), params.matvec, params_dot.matvec, *x)
_check_shapes("matvec", "b", matvec_tangents, x, tree)
rhs = _map(ad.add_tangents, b_dot, _map(operator.neg, matvec_tangents))

x_dot = custom_linear_solve_p.bind(*(_flatten(params) + rhs), **kwargs)

return x, x_dot


def _custom_linear_solve_transpose_rule(cotangent, *primals, **kwargs):
const_lengths, jaxprs, tree = split_dict(
kwargs, ['const_lengths', 'jaxprs', 'tree'])

if jaxprs.transpose_solve is None:
raise TypeError('transpose_solve required for backwards mode automatic '
'differentiation of custom_linear_solve')

params, b = _split_linear_solve_args(primals, const_lengths)
assert b == [ad.undefined_primal] * len(b)
cotangent_b = custom_linear_solve_p.bind(
*(_flatten(params.transpose()) + cotangent),
const_lengths=const_lengths.transpose(), jaxprs=jaxprs.transpose(),
tree=tree)
return [None] * sum(const_lengths) + cotangent_b


custom_linear_solve_p = core.Primitive('custom_linear_solve')
custom_linear_solve_p.multiple_results = True
custom_linear_solve_p.def_impl(_custom_linear_solve_impl)
custom_linear_solve_p.def_abstract_eval(_custom_linear_solve_abstract_eval)
ad.primitive_jvps[custom_linear_solve_p] = _custom_linear_solve_jvp
xla.initial_style_translations[custom_linear_solve_p] = xla.lower_fun(
_custom_linear_solve_impl, initial_style=True)
ad.primitive_transposes[custom_linear_solve_p] = _custom_linear_solve_transpose_rule
# TODO(shoyer): write batching rule
Loading