Skip to content

Commit

Permalink
Merge pull request #733 from fabianp:doctest
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 601143290
  • Loading branch information
OptaxDev committed Jan 24, 2024
2 parents fd3ac5c + 56ec0f3 commit d3ad31c
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 12 deletions.
7 changes: 7 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,13 @@ def new_process_docstring(app, what, name, obj, options, lines):
'sphinxcontrib.collections'
]

# so we don't have to do the canonical imports on every doctest
doctest_global_setup = '''
import optax
import jax
import jax.numpy as jnp
'''

# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']

Expand Down
16 changes: 8 additions & 8 deletions optax/tree_utils/_state_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,16 @@ def tree_map_params(
`transform_non_params` can be used to replace any remaining fields as
required, in this case, we replace those fields by None.
>>> params, specs = ... # Trees with the same shape
>>> params, specs = jnp.array(0.), jnp.array(0.) # Trees with the same shape
>>> opt = optax.sgd(1e-3)
>>> state = opt.init(params)
>>>
>>> opt_specs = optax.tree_map_params(
>>> opt,
>>> lambda _, spec: spec,
>>> state,
>>> specs,
>>> transform_non_params=lambda _: None,
>>> )
... opt,
... lambda _, spec: spec,
... state,
... specs,
... transform_non_params=lambda _: None,
... )
Args:
initable: A callable taking parameters and returning an optimizer state, or
Expand Down
8 changes: 4 additions & 4 deletions optax/tree_utils/_tree_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,10 @@ def tree_vdot(tree_x: Any, tree_y: Any) -> chex.Numeric:
inner product between ``tree_x`` and ``tree_y``, a scalar value.
>>> optax.tree_utils.tree_vdot(
>>> {a: jnp.array([1, 2]), b: jnp.array([1, 2])},
>>> {a: jnp.array([-1, -1]), b: jnp.array([1, 1])},
>>> )
0.0
... {'a': jnp.array([1, 2]), 'b': jnp.array([1, 2])},
... {'a': jnp.array([-1, -1]), 'b': jnp.array([1, 1])},
... )
Array(0, dtype=int32)
Implementation detail: we upcast the values to the highest precision to avoid
numerical issues.
Expand Down
2 changes: 2 additions & 0 deletions test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ cd ..
# Build Sphinx docs.
pip install -e ".[docs]"
cd docs && make html
# run doctests
make doctest
cd ..

set +u
Expand Down

0 comments on commit d3ad31c

Please sign in to comment.