Skip to content

Commit

Permalink
Backend tensorflow.compat.v1 supports forward-mode AD via double back…
Browse files Browse the repository at this point in the history
…wards trick (lululxvi#1614)
  • Loading branch information
ZongrenZou authored Jan 3, 2024
1 parent aa69952 commit d342b4d
Showing 1 changed file with 20 additions and 10 deletions.
30 changes: 20 additions & 10 deletions deepxde/gradients/gradients_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,22 @@ def __call__(self, i=None, j=None):

# Compute J[:, j]
if j not in self.J:
if backend_name in [
"tensorflow.compat.v1",
"paddle",
]:
# TODO: Other backends
raise NotImplementedError(
"Backend f{backend_name} doesn't support forward-mode autodiff."
if backend_name == "tensorflow.compat.v1":
# We use the double backwards trick to compute the jvp of a function in
# backend tensorflow.compat.v1, because autodiff.ForwardAccumulator is
# not supported. We note that this is not the exact jvp.
tangent = tf.one_hot([j], depth=self.xs.shape[1]) * tf.ones_like(
self.xs
)
u = tf.ones_like(self.ys)
g = tf.gradients(self.ys, self.xs, grad_ys=u)
self.J[j] = tf.gradients(g, u, grad_ys=tangent)[0]
elif backend_name == "tensorflow":
# We use tensorflow.autodiff.ForwardAccumulator to compute the jvp of
# a function.
# TODO: create the tangent in a smarter way
tangent = tf.one_hot(self.xs.shape[0] * [j], depth=self.xs.shape[1])
tangent = tf.one_hot([j], depth=self.xs.shape[1]) * tf.ones_like(
self.xs
)

def grad_fn(x):
with tf.autodiff.ForwardAccumulator(
Expand Down Expand Up @@ -72,13 +75,20 @@ def grad_fn(x):
tangent = jax.numpy.zeros(self.dim_x).at[j].set(1)
grad_fn = lambda x: jax.jvp(self.ys[1], (x,), (tangent,))[1]
self.J[j] = (jax.vmap(grad_fn)(self.xs), grad_fn)
elif backend_name == "paddle":
# TODO: Other backends
raise NotImplementedError(
"Backend f{backend_name} doesn't support forward-mode autodiff."
)

if i is None or self.dim_y == 1:
return self.J[j]

# Compute J[i, j]
if (i, j) not in self.J:
if backend_name in ["tensorflow", "pytorch", "jax"]:
if backend_name == "tensorflow.compat.v1":
self.J[i, j] = self.J[j][:, i : i + 1]
elif backend_name in ["tensorflow", "pytorch", "jax"]:
# In backend tensorflow/pytorch/jax, a tuple of a tensor/tensor/array
# and a callable is returned, so that it is consistent with the argument,
# which is also a tuple. This is useful for further computation, e.g.,
Expand Down

0 comments on commit d342b4d

Please sign in to comment.