v0.0.5
What's Changed
import finitediffx as fdx
import jax
import jax.numpy as jnp
import numpy as np
def func(x):
    """Return the sum of ``sin`` over every element of ``x``, using NumPy.

    Because the computation goes through ``np`` rather than ``jnp``, this
    version is the one used below to show what ``jax.grad`` cannot handle.
    """
    return np.sum(np.sin(x))
x = np.ones([5, 1])

# Differentiate the NumPy-based `func` with `fdx.fgrad`.
print(fdx.fgrad(func)(x))
# [[0.54003906]
#  [0.54003906]
#  [0.54003906]
#  [0.54003906]
#  [0.54003906]]

# The same NumPy-based `func` passed to `jax.grad` raises
# `TracerArrayConversionError`; catch it and report the failure.
try:
    jax.grad(func)(x)
except jax.errors.TracerArrayConversionError:
    print("Fail: `jax.grad` cannot be used with numpy arrays")
def func(x):
    """Return the sum of ``sin`` over every element of ``x``, using ``jax.numpy``.

    This redefinition replaces the NumPy-based ``func`` so that the function
    can be differentiated with ``jax.grad``.
    """
    return jnp.sum(jnp.sin(x))
# Differentiate the `jnp`-based `func` with `jax.grad`.
print(jax.grad(func)(x))
# Recorded output: first the message printed by the `except` branch above,
# then the gradient printed by this call.
# Fail: `jax.grad` cannot be used with numpy arrays
# [[0.54030234]
# [0.54030234]
# [0.54030234]
# [0.54030234]
# [0.54030234]]
Full Changelog: v0.0.4...v0.0.5