Hessians of network functions #2682

Closed
g-benton opened this issue Apr 11, 2020 · 2 comments

@g-benton

I'm trying to get the Hessian of the function produced by a neural network, but I can't figure out a way to actually get access to the Hessian matrix.

from jax import random, jacfwd, jacrev, tree_flatten
from jax.experimental import stax
import jax.numpy as np

key = random.PRNGKey(10)
key, x_key, y_key = random.split(key, 3)

target_fn = lambda x: np.sin(x) + 0.2 * np.sin(4 * x)

train_points = 10
train_xs = random.uniform(x_key, (train_points, 1), minval=-np.pi, maxval=np.pi)
train_ys = target_fn(train_xs)


init_fn, apply_fn = stax.serial(
    stax.Dense(35), stax.Relu,
    stax.Dense(35), stax.Relu,
    stax.Dense(1)
)

_, params = init_fn(key, (-1, 1))

def get_hessian(fun):
    return jacfwd(jacrev(fun))

# Hessian of the network output with respect to params at a single data point
# (can do this by datapoint or in batches)
hessian = get_hessian(apply_fn)(params, train_xs[0, :])

# how does one make sense of this list?
print([p.shape for p in tree_flatten(hessian)[0]]) 

Is there a way to manipulate the hessian so that I can view it directly as a matrix? Ideally one would be able to batch this and call hessian = get_hessian(apply_fn)(params, train_xs) using the full training data and get an n x p x p array back containing the Hessians evaluated at each of the n data points.

I'm not sure if I'm missing something that's already built in or if this is something I would need to build. Any help is appreciated!

@hawkinsp
Collaborator

Note: get_hessian is available as jax.hessian, so you don't need to define it yourself here.
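As a quick sanity check on a toy scalar function (a hypothetical g, not part of the original thread), the two agree:

import jax, jax.numpy as jnp

g = lambda x: jnp.sum(jnp.sin(x))   # hypothetical toy function
z = jnp.arange(3.0)

# jax.hessian composes forward- over reverse-mode differentiation,
# just like jacfwd(jacrev(g))
print(jax.hessian(g)(z))
print(jax.jacfwd(jax.jacrev(g))(z))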

hessian, jacfwd, and jacrev accept Python trees of arrays, not just arrays. One way to think of what you get back from hessian is as a block-sparse matrix. I think the structure is easiest to understand with a smaller example that uses a dictionary as the argument:

import jax, jax.numpy as jnp
import numpy as np

def f(v):
  x = v['x']
  y = v['y']
  return jnp.sum(jnp.sin(x * jnp.cos(y)))

x = np.random.randn(3)
y = np.random.randn(3)

h = jax.hessian(f)({'x': x, 'y': y})

print(h)
print(h['x']['x'])

prints:

{'x': {'x': DeviceArray([[0.9596267 , 0.        , 0.        ],
             [0.        , 0.31985453, 0.        ],
             [0.        , 0.        , 0.0016721 ]], dtype=float32), 'y': DeviceArray([[-0.27333596,  0.        ,  0.        ],
             [ 0.        ,  0.35417664,  0.        ],
             [ 0.        ,  0.        ,  0.94545966]], dtype=float32)}, 'y': {'x': DeviceArray([[-0.27333596, -0.        , -0.        ],
             [ 0.        ,  0.3541766 ,  0.        ],
             [ 0.        ,  0.        ,  0.94545966]], dtype=float32), 'y': DeviceArray([[-0.29960722, -0.        , -0.        ],
             [ 0.        ,  2.2319145 ,  0.        ],
             [ 0.        ,  0.        ,  0.01589215]], dtype=float32)}}

and

[[-0.27333596  0.          0.        ]
 [ 0.          0.35417664  0.        ]
 [ 0.          0.          0.94545966]]

In other words, the output of hessian contains two nested copies of the input Python tree structure {"x": ..., "y": ...}, just as you'd expect two copies of each input array dimension in the Hessian.
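As a small sketch of that structure (reusing h from the example above; the loop is only illustrative), each outer key pairs with each inner key to give one block:

# Walk the nested dict: each (outer, inner) pair is one 3 x 3 block of
# second derivatives of f with respect to that pair of inputs.
for outer in ('x', 'y'):
  for inner in ('x', 'y'):
    print(outer, inner, h[outer][inner].shape)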

If you don't want JAX to maintain the sparse structure for you, you can simply flatten the input into a 1D dense array before calling hessian, e.g. something like this:

def flatten(v):
  def concat_leaves(v):
    leaves, _ = jax.tree_util.tree_flatten(v)
    return jnp.concatenate([x.ravel() for x in leaves])
  # The VJP of the flattening map gives back an "unflatten" that places a 1D
  # vector back into the original pytree structure.
  out, pullback = jax.vjp(concat_leaves, v)
  return out, lambda x: pullback(x)[0]

flat_xy, unflatten = flatten({'x': x, 'y': y})
h = jax.hessian(lambda t: f(unflatten(t)))(flat_xy)
print(h)

which prints:

[[ 9.5962667e-01  0.0000000e+00  0.0000000e+00 -2.7333596e-01
   0.0000000e+00  0.0000000e+00]
 [ 0.0000000e+00  3.1985453e-01  0.0000000e+00  0.0000000e+00
   3.5417664e-01  0.0000000e+00]
 [ 0.0000000e+00  0.0000000e+00  1.6720986e-03  0.0000000e+00
   0.0000000e+00  9.4545966e-01]
 [-2.7333596e-01  0.0000000e+00  0.0000000e+00 -2.9960722e-01
   0.0000000e+00  0.0000000e+00]
 [ 0.0000000e+00  3.5417661e-01  0.0000000e+00  0.0000000e+00
   2.2319145e+00  0.0000000e+00]
 [ 0.0000000e+00  0.0000000e+00  9.4545966e-01  0.0000000e+00
   0.0000000e+00  1.5892152e-02]]
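
To connect this back to the network in the question, here is a rough sketch (not verified; scalar_output, unflatten_params, and per_point_hessian are names chosen for illustration) of combining the flatten trick with jax.vmap to get the n x p x p batch of Hessians asked about above:

# Sketch only: assumes apply_fn, params, train_xs from the question and
# flatten() as defined above.
import jax

flat_params, unflatten_params = flatten(params)

def scalar_output(flat_w, x):
  # apply_fn returns shape (1,); index into it so the Hessian is (p, p)
  return apply_fn(unflatten_params(flat_w), x)[0]

# Hessian w.r.t. the flattened parameters at one data point: shape (p, p)
per_point_hessian = jax.hessian(scalar_output)
# Map over the n training points, holding the parameters fixed: (n, p, p)
hessians = jax.vmap(per_point_hessian, in_axes=(None, 0))(flat_params, train_xs)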

Does that help answer your question?

@g-benton
Author

This is perfect, thank you!
