Skip to content

Commit

Permalink
simplify delay example
Browse files Browse the repository at this point in the history
  • Loading branch information
marmaduke woodman committed Feb 13, 2024
1 parent 2d0e871 commit 628d034
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 68 deletions.
55 changes: 19 additions & 36 deletions examples/delays-hcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,55 +28,37 @@ def load(fname):

# load one connectome
W, L = jp.array(npz['070-DesikanKilliany'][:,0])
n_to, n_from = W.shape

# setup aux vars for delays
# setup delays
dt = 0.1
v_c = 10.0 # m/s
lags = jp.floor(L / v_c / dt).astype('i')
ix_lag_from = jp.tile(jp.r_[:n_from], (n_to, 1))
max_lag = lags.max() + 1
print(max_lag)
Wt = jp.log(1+W.T[:,:,None]) # enable bcast for crv
dh = vb.make_delay_helper( jp.log(W+1), L, dt=dt)

# define parameters
from collections import namedtuple
Params = namedtuple('Params', 'dh theta k')

def dfun(buf, rv, t: int, p):
# we could close over the vars or pass in like so:
Wt, lags, ix_lag_from, mpr_theta, k = p
crv = (Wt * buf[t - lags, :, ix_lag_from]).sum(axis=1).T
return vb.mpr_dfun(rv, k*crv, mpr_theta)
# define our model
def dfun(buf, rv, t: int, p: Params):
crv = vb.delay_apply(p.dh, t, buf) # compute delay coupling
return vb.mpr_dfun(rv, p.k*crv, p.theta) # compute dynamics

def rgt0(rv, p):
def ensure_r_positive(rv, _):
r, v = rv
return jp.array([ r*(r>0), v ])

# buf should cover all delays + noise for time steps to take
chunk_len = int(10 / dt) # 10 ms
buf = jp.zeros((max_lag + chunk_len, 2, n_from))
buf = buf.at[:max_lag+1].add( jp.r_[0.1,-2.0].reshape(2,1) )
buf = jp.zeros((dh.max_lag + chunk_len, 2, dh.n_from))
buf = buf.at[:dh.max_lag+1].add( jp.r_[0.1,-2.0].reshape(2,1) )

# don't provide random numbers w/ make_continuation
# buf = buf.at[max_lag+1:].set( vb.randn(chunk_len-1, 2, n_from) )

# pack parameters (could/should be dict / dataclass)
k = 0.01
p = Wt, lags, ix_lag_from, vb.mpr_default_theta, k

# compile & run our loop & check outputs
_, run_chunk = vb.make_sdde(dt, max_lag, dfun, gfun=1e-3, unroll=10, adhoc=rgt0)
def _check(buf):
buf, rv = run_chunk(buf, p)
assert buf.shape[0] == (max_lag + chunk_len)
assert rv.shape == (chunk_len, 2, n_from)
_check(buf)

# jit the buffer updates
cont_chunk = vb.make_continuation(run_chunk, chunk_len, max_lag, n_from, n_svar=2, stochastic=True)
# compile model and enable continuations
_, run_chunk = vb.make_sdde(dt, dh.max_lag, dfun, gfun=1e-3, unroll=10, adhoc=ensure_r_positive)
cont_chunk = vb.make_continuation(run_chunk, chunk_len, dh.max_lag, dh.n_from, n_svar=2, stochastic=True)

# setup time avg and bold monitors
ta_buf, ta_step, ta_samp = vb.make_timeavg((2, n_from))
ta_buf, ta_step, ta_samp = vb.make_timeavg((2, dh.n_from))
ta_samp = vb.make_offline(ta_step, ta_samp)
bold_buf, bold_step, bold_samp = vb.make_bold((2, n_from), dt, vb.bold_default_theta)
bold_buf, bold_step, bold_samp = vb.make_bold((2, dh.n_from), dt, vb.bold_default_theta)
bold_samp = vb.make_offline(bold_step, bold_samp)

# run chunk w/ monitors
Expand All @@ -93,7 +75,8 @@ def run_one_second(bufs, key):
return jax.lax.scan(chunk_ta_bold, bufs, keys)

# pack buffers and run it one minute
bufs = p, buf, ta_buf, bold_buf
params = Params(dh, vb.mpr_default_theta, 0.01)
bufs = params, buf, ta_buf, bold_buf
ta, bold = [], []
keys = jax.random.split(jax.random.PRNGKey(42), 60)
for i, key in enumerate(tqdm.tqdm(keys)):
Expand Down
2 changes: 1 addition & 1 deletion vbjax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def _use_many_cores():
)
from .regmap import make_region_mapping
from .coupling import (
make_diff_cfun, make_linear_cfun, make_delayed_coupling
make_diff_cfun, make_linear_cfun, make_delay_helper, delay_apply,
)
from .connectome import make_conn_latent_mvnorm
from .sparse import make_spmv, csr_to_jax_bcoo, make_sg_spmv
Expand Down
49 changes: 18 additions & 31 deletions vbjax/coupling.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,40 +3,27 @@
"""

import jax.numpy as np
from collections import namedtuple
import jax.numpy as jp


def make_delayed_coupling(weights, delay_steps, pre, post, nh, isvar):
"""
Construct a dense delayed coupling function.
Parameters
==========
weights : array
Coupling weights.
delay_steps : array
Number of delay steps per connection i, j.
...
To be finished
...
DelayHelper = namedtuple('DelayHelper', 'Wt lags ix_lag_from max_lag n_to n_from')

Notes
=====
def make_delay_helper(W, L, dt=0.1, v_c=10.0) -> DelayHelper:
n_to, n_from = W.shape
lags = jp.floor(L / v_c / dt).astype('i')
ix_lag_from = jp.tile(jp.r_[:n_from], (n_to, 1))
max_lag = lags.max() + 1
Wt = W.T[:,:,None] # enable bcast for coupling vars
dh = DelayHelper(Wt, lags, ix_lag_from, max_lag, n_to, n_from)
return dh

- This construction assumes a particular layout for the history
buffer: xt.shape == (nh+1+nt, nsvar, nnode, ...).
def delay_apply(dh: DelayHelper, t, buf):
return (dh.Wt * buf[t - dh.lags, :, dh.ix_lag_from]).sum(axis=1).T

"""
nn = weights.shape[0]
nodes = np.tile(np.r_[:nn], (nn, 1))
def cfun(t, xt, x, params):
dx = xt[nh + t - delay_steps, isvar, nodes]
xij = pre(dx, x, params)
gx = (weights * xij).sum(axis=1)
return post(gx, params)
return cfun
# TODO impl sparse delay_apply

# TODO the following are not used: maybe drop them

def make_linear_cfun(SC, a=1.0, b=0.0):
"""Construct a linear coupling function with slope `a` and offset `b`.
Expand All @@ -46,15 +33,15 @@ def cfun(xj):
if xj.ndim == 1: # no delays
gx = SC @ xj
elif xj.ndim == 2: # delays
gx = np.sum(SC * xj, axis=1)
gx = jp.sum(SC * xj, axis=1)
return a*gx + b
return cfun


def make_diff_cfun(SC, a=1.0, b=0.0):
"""Construct a linear difference coupling."""
nn = np.r_[:SC.shape[0]]
diffdiag = np.diag(SC) - SC.sum(axis=1)
nn = jp.r_[:SC.shape[0]]
diffdiag = jp.diag(SC) - SC.sum(axis=1)
SC_ = SC.at[nn,nn].set(diffdiag)
# fix diagonal according to trick
return make_linear_cfun(SC_, a=a, b=b)

0 comments on commit 628d034

Please sign in to comment.