Skip to content

Commit

Permalink
instantiate zeros
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobjinkelly committed May 1, 2020
1 parent 49a8901 commit 3b9f7e1
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 1 deletion.
3 changes: 2 additions & 1 deletion jax/experimental/jet.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from jax.util import unzip2
from jax import ad_util
from jax.tree_util import (register_pytree_node, tree_structure,
treedef_is_leaf, tree_flatten, tree_unflatten)
treedef_is_leaf, tree_flatten, tree_unflatten, tree_map)
import jax.linear_util as lu
from jax.interpreters import xla
from jax.lax import lax
Expand Down Expand Up @@ -59,6 +59,7 @@ def jet_fun(primals, series):
with core.new_master(JetTrace) as master:
out_primals, out_terms = yield (master, primals, series), {}
del master
out_terms = [tree_map(jax.numpy.zeros_like, series[0]) if s is zero_series else s for s in out_terms]
yield out_primals, out_terms

@lu.transformation
Expand Down
15 changes: 15 additions & 0 deletions tests/jet_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,21 @@ def test_select(self):
series_in = (terms_b, terms_x, terms_y)
self.check_jet(np.where, primals, series_in)

def test_inst_zero(self):
def f(x):
return 2.
def g(x):
return 2. + 0 * x
x = np.ones(1)
order = 3
f_out_primals, f_out_series = jet(f, (x, ), ([np.ones_like(x) for _ in range(order)], ))
assert f_out_series is not zero_series

g_out_primals, g_out_series = jet(g, (x, ), ([np.ones_like(x) for _ in range(order)], ))

assert g_out_primals == f_out_primals
assert g_out_series == f_out_series


if __name__ == '__main__':
absltest.main()

0 comments on commit 3b9f7e1

Please sign in to comment.