Skip to content

Commit

Permalink
added SV to master branch
Browse files Browse the repository at this point in the history
  • Loading branch information
JakobRobnik committed Oct 12, 2023
1 parent 7a59391 commit 31c5020
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 83 deletions.
32 changes: 21 additions & 11 deletions benchmarks/benchmarks_mchmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,9 @@
import jax
import jax.numpy as jnp
import os

from numpyro.examples.datasets import SP500, load_dataset
from numpyro.distributions import StudentT
from numpyro.distributions import Exponential

dirr = os.path.dirname(os.path.realpath(__file__))


### Benchmark targets ###


class StandardNormal():
"""Standard Normal distribution in d dimensions"""
Expand All @@ -22,6 +15,9 @@ def __init__(self, d):
self.variance = jnp.ones(d)
self.grad_nlogp = jax.value_and_grad(self.nlogp)

self.second_moments = jnp.ones(d)
self.variance_second_moments = 2 * self.second_moments


def nlogp(self, x):
"""- log p of the target distribution"""
Expand All @@ -36,6 +32,7 @@ def prior_draw(self, key):




class IllConditionedGaussian():
"""Gaussian distribution. Covariance matrix has eigenvalues equally spaced in log-space, going from 1/condition_bnumber^1/2 to condition_number^1/2."""

Expand Down Expand Up @@ -126,7 +123,9 @@ def __init__(self, prior = 'prior'):
eigs = np.sort(rng.gamma(shape=0.5, scale=1., size=self.d)) #eigenvalues of the Hessian
eigs *= jnp.average(1.0/eigs)
self.entropy = 0.5 * self.d
self.maxmin = (1./jnp.sqrt(eigs[0]), 1./jnp.sqrt(eigs[-1]))
R, _ = np.linalg.qr(rng.randn(self.d, self.d)) #random rotation
self.map_to_worst = (R.T)[[0, -1], :]
self.Hessian = R @ np.diag(eigs) @ R.T

# analytic ground truth moments
Expand All @@ -149,7 +148,7 @@ def __init__(self, prior = 'prior'):

else: # N(0, sigma_true_max)
self.prior_draw = lambda key: jax.random.normal(key, shape=(self.d,)) * jnp.max(1.0/jnp.sqrt(eigs))

def nlogp(self, x):
"""- log p of the target distribution"""
return 0.5 * x.T @ self.Hessian @ x
Expand Down Expand Up @@ -488,8 +487,7 @@ class StochasticVolatility():
"""Example from https://num.pyro.ai/en/latest/examples/stochastic_volatility.html"""

def __init__(self):
_, fetch = load_dataset(SP500, shuffle=False)
SP500_dates, self.SP500_returns = fetch()
self.SP500_returns = np.load(dirr + '/SP500.npy')

self.name = 'SV'
self.d = 2429
Expand All @@ -511,7 +509,7 @@ def nlogp(self, x):

l1= (jnp.exp(x[-2]) - x[-2]) + (jnp.exp(x[-1]) - x[-1])
l2 = (self.d - 2) * jnp.log(sigma) + 0.5 * (jnp.square(x[0]) + jnp.sum(jnp.square(x[1:-2] - x[:-3]))) / jnp.square(sigma)
l3 = -jnp.sum(StudentT(df=nu, scale= jnp.exp(x[:-2])).log_prob(self.SP500_returns))
l3 = jnp.sum(nlogp_StudentT(self.SP500_returns, nu, jnp.exp(x[:-2])))

return l1 + l2 + l3

Expand All @@ -538,6 +536,18 @@ def prior_draw(self, key):
walk = random_walk(key_walk, self.d - 2) * params[0]
return jnp.concatenate((walk, jnp.log(params/scales)))


def nlogp_StudentT(x, df, scale):
y = x / scale
z = (
jnp.log(scale)
+ 0.5 * jnp.log(df)
+ 0.5 * jnp.log(jnp.pi)
+ jax.scipy.special.gammaln(0.5 * df)
- jax.scipy.special.gammaln(0.5 * (df + 1.0))
)
return 0.5 * (df + 1.0) * jnp.log1p(y**2.0 / df) + z


def random_walk(key, num):
""" Genereting process for the standard normal walk:
Expand Down
Binary file modified sampling/__pycache__/correlation_length.cpython-38.pyc
Binary file not shown.
Binary file modified sampling/__pycache__/sampler.cpython-38.pyc
Binary file not shown.
30 changes: 19 additions & 11 deletions sampling/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@ def update(eps, x, u):



def minimal_norm(d, T, V, sigma):
def minimal_norm(d, T, V):

def step(x, u, g, eps):
def step(x, u, g, eps, sigma):
"""Integrator from https://arxiv.org/pdf/hep-lat/0505020.pdf, see Equation 20."""

# V T V T V
Expand All @@ -75,9 +75,9 @@ def step(x, u, g, eps):



def leapfrog(d, T, V, sigma):
def leapfrog(d, T, V):

def step(x, u, g, eps):
def step(x, u, g, eps, sigma):

# V T V
uu, r1 = V(eps * 0.5, u, g * sigma)
Expand All @@ -93,31 +93,32 @@ def step(x, u, g, eps):



def hamiltonian(integrator, sigma, grad_nlogp, d, sequential = True):
def hamiltonian(integrator, grad_nlogp, d, sequential = True):

T = update_position(grad_nlogp)
V = update_momentum(d, sequential)

if integrator == "LF": #leapfrog (first updates the velocity)
return leapfrog(d, T, V, sigma)
return leapfrog(d, T, V)

elif integrator== 'MN': #minimal norm integrator (first updates the velocity)
return minimal_norm(d, T, V, sigma)
return minimal_norm(d, T, V)

else:
raise Exception("Integrator must be either MN (minimal_norm) or LF (leapfrog)")


def mclmc(hamiltonian_dynamics, partially_refresh_momentum):

def mclmc(hamiltonian_dynamics, partially_refresh_momentum, d):

def step(self, x, u, g, random_key, L, eps):
def step(x, u, g, random_key, L, eps, sigma):
"""One step of the generalized dynamics."""

# Hamiltonian step
xx, uu, ll, gg, kinetic_change = hamiltonian_dynamics(x=x, u=u, g=g, eps=eps)
xx, uu, ll, gg, kinetic_change = hamiltonian_dynamics(x=x, u=u, g=g, eps=eps, sigma = sigma)

# Langevin-like noise
nu = jnp.sqrt((jnp.exp(2 * eps / L) - 1.) / self.Target.d)
nu = jnp.sqrt((jnp.exp(2 * eps / L) - 1.) / d)
uu, key = partially_refresh_momentum(u= uu, random_key= random_key, nu= nu)

return xx, uu, ll, gg, kinetic_change, key
Expand All @@ -129,35 +130,42 @@ def step(self, x, u, g, random_key, L, eps):
def random_unit_vector(d, sequential= True):
"""Generates a random (isotropic) unit vector."""


def rng_sequential(random_key):
key, subkey = jax.random.split(random_key)
u = jax.random.normal(subkey, shape = (d, ))
u /= jnp.sqrt(jnp.sum(jnp.square(u)))
return u, key


def rng_parallel(random_key, num_chains):
key, subkey = jax.random.split(random_key)
u = jax.random.normal(subkey, shape = (num_chains, d))
normed_u = u / jnp.sqrt(jnp.sum(jnp.square(u), axis = 1))[:, None]
return normed_u, key


return rng_sequential if sequential else rng_parallel




def partially_refresh_momentum(d, sequential= True):
"""Adds a small noise to u and normalizes."""


def rng_sequential(u, random_key, nu):
key, subkey = jax.random.split(random_key)
z = nu * jax.random.normal(subkey, shape = (d, ))

return (u + z) / jnp.sqrt(jnp.sum(jnp.square(u + z))), key


def rng_parallel(u, random_key, nu):
key, subkey = jax.random.split(random_key)
noise = nu * jax.random.normal(subkey, shape= u.shape, dtype=u.dtype)

return (u + noise) / jnp.sqrt(jnp.sum(jnp.square(u + noise), axis = 1))[:, None], key


return rng_sequential if sequential else rng_parallel
Loading

0 comments on commit 31c5020

Please sign in to comment.