From 0dbbc27bb1d55261002cfbe9785425814a14c43a Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 28 Apr 2020 11:58:51 -0400 Subject: [PATCH] Clarify that `grad` requires arguments to be differentiated to be of inexact type. (#2712) --- jax/api.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/jax/api.py b/jax/api.py index 494f091263a3..d9f157369375 100644 --- a/jax/api.py +++ b/jax/api.py @@ -341,7 +341,9 @@ def grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0, Args: fun: Function to be differentiated. Its arguments at positions specified by - ``argnums`` should be arrays, scalars, or standard Python containers. It + ``argnums`` should be arrays, scalars, or standard Python containers. + Argument arrays in the positions specified by ``argnums`` must be of + inexact (i.e., floating-point or complex) type. It should return a scalar (which includes arrays with shape ``()`` but not arrays with shape ``(1,)`` etc.) argnums: Optional, integer or sequence of integers. Specifies which