
v0.0.5

ASEM000 released this on 08 Jun 12:45

What's Changed

  • Enable pytree step_size, offsets input by @ASEM000 in #5
  • `fgrad` can handle array inputs
```python
import numpy as np
import jax
import jax.numpy as jnp
import finitediffx as fdx

# A function written with plain NumPy, which JAX cannot trace
def func(x):
    return np.sum(np.sin(x))

x = np.ones([5, 1])

# fgrad uses finite differences, so it only needs to *evaluate* `func`
print(fdx.fgrad(func)(x))
# [[0.54003906]
#  [0.54003906]
#  [0.54003906]
#  [0.54003906]
#  [0.54003906]]


try:
    jax.grad(func)(x)
except jax.errors.TracerArrayConversionError:
    print("Fail: `jax.grad` cannot trace through NumPy functions")

    # Rewriting the function with jax.numpy makes it differentiable by jax.grad
    def jax_func(x):
        return jnp.sum(jnp.sin(x))

    print(jax.grad(jax_func)(x))
# Fail: `jax.grad` cannot trace through NumPy functions
# [[0.54030234]
#  [0.54030234]
#  [0.54030234]
#  [0.54030234]
#  [0.54030234]]
```
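Why `fgrad` succeeds where `jax.grad` fails: a finite-difference gradient only calls the function at perturbed inputs and never traces it, so any NumPy code works. The sketch below is a minimal illustration of that idea using a plain second-order central difference. The helper `central_diff_grad` and its `step_size=1e-3` default are hypothetical and are not finitediffx's implementation; the library's own stencil (`offsets`, `step_size` defaults) may differ, so its numbers will not exactly match the `0.54003906` shown above.

```python
import numpy as np

def central_diff_grad(func, x, step_size=1e-3):
    """Hypothetical helper (not part of finitediffx): elementwise
    central-difference gradient of a scalar-valued function."""
    x = np.asarray(x, dtype=float)
    grad = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        # Perturb one element at a time in both directions
        x_plus = x.copy()
        x_plus[idx] += step_size
        x_minus = x.copy()
        x_minus[idx] -= step_size
        # (f(x + h) - f(x - h)) / 2h approximates the partial derivative
        grad[idx] = (func(x_plus) - func(x_minus)) / (2 * step_size)
    return grad

# Same plain-NumPy function as in the release example
def func(x):
    return np.sum(np.sin(x))

g = central_diff_grad(func, np.ones([5, 1]))
print(g.shape)  # (5, 1)
```

Each entry of `g` approximates cos(1) ≈ 0.5403, with an error on the order of `step_size**2`; the gradient keeps the input's shape, as `fgrad` does in the example.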

Full Changelog: v0.0.4...v0.0.5