Skip to content

Commit

Permalink
removing jax.Array type to avoid the need of using the new jax version
Browse files Browse the repository at this point in the history
  • Loading branch information
JakobRobnik committed Nov 27, 2023
1 parent 6fdbd35 commit f300409
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
6 changes: 3 additions & 3 deletions mclmc/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
class State(NamedTuple):
"""Dynamical state"""

x: jax.Array
u: jax.Array
x: any#jax.Array
u: any#jax.Array
l: float
g: jax.Array
g: any#jax.Array
key: tuple


Expand Down
2 changes: 1 addition & 1 deletion mclmc/tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class Hyperparameters(NamedTuple):

L: float
eps: float
sigma: jax.Array
sigma: any


# all tuning functions are wrappers, recieving some parameters and returning a function
Expand Down

0 comments on commit f300409

Please sign in to comment.