diff --git a/mclmc/dynamics.py b/mclmc/dynamics.py index 30b1568..63ae570 100644 --- a/mclmc/dynamics.py +++ b/mclmc/dynamics.py @@ -8,10 +8,10 @@ class State(NamedTuple): """Dynamical state""" - x: jax.Array - u: jax.Array + x: any#jax.Array + u: any#jax.Array l: float - g: jax.Array + g: any#jax.Array key: tuple diff --git a/mclmc/tune.py b/mclmc/tune.py index aabb66a..54bc478 100644 --- a/mclmc/tune.py +++ b/mclmc/tune.py @@ -12,7 +12,7 @@ class Hyperparameters(NamedTuple): L: float eps: float - sigma: jax.Array + sigma: any # all tuning functions are wrappers, recieving some parameters and returning a function