Skip to content

Commit

Permalink
Make type of value_and_grad slightly more precise. (jax-ml#2704)
Browse files Browse the repository at this point in the history
  • Loading branch information
hawkinsp authored and jacobjinkelly committed Apr 21, 2020
1 parent 2603d5f commit 5dac151
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion jax/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,8 @@ def grad_f_aux(*args, **kwargs):
return grad_f_aux if has_aux else grad_f

def value_and_grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
has_aux: bool = False, holomorphic: bool = False) -> Callable:
has_aux: bool = False, holomorphic: bool = False
) -> Callable[..., Tuple[Any, Any]]:
"""Create a function which evaluates both ``fun`` and the gradient of ``fun``.
Args:
Expand Down

0 comments on commit 5dac151

Please sign in to comment.