Skip to content

Commit

Permalink
Separate gathering to a function
Browse files Browse the repository at this point in the history
  • Loading branch information
trossi committed Oct 11, 2024
1 parent bd19d68 commit f3aea77
Showing 1 changed file with 27 additions and 21 deletions.
48 changes: 27 additions & 21 deletions hmsc/updaters/updateEta.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,32 @@
from hmsc.utils.import_utils import calculate_idDW12, set_slice
tfla, tfm, tfr, tfs = tf.linalg, tf.math, tf.random, tf.sparse


def gather_idDW12st(rLPar, AlphaInd):
if "idDW12g" in rLPar:
idDW12st = tf.gather(rLPar["idDW12g"], AlphaInd)
else: # lowmem
var = rLPar["idDW12st_var"]
m = tf.shape(AlphaInd)[0]
var.assign(tf.zeros(shape=[m, *var.shape[1:]], dtype=var.dtype))

cond = lambda j: tf.less(j, m)
def body(j):
i = AlphaInd[j]
idDW12 = calculate_idDW12(rLPar["d12"], rLPar["alpha"][i], rLPar["idDg"][i])
set_slice(var, j, idDW12)
return [j + 1, ]

j = tf.constant(0)
tf.while_loop(cond, body, [j])

idDW12st = var.read_value_no_copy()

tf.print('idDW12st', tf.shape(idDW12st), tfm.reduce_min(idDW12st, axis=[1, 2]), tfm.reduce_max(idDW12st, axis=[1, 2]))
return idDW12st



@tf_named_func("eta")
def updateEta(params, modelDims, data, rLHyperparams, dtype=np.float64):
"""Update conditional updater(s):
Expand Down Expand Up @@ -78,27 +104,7 @@ def updateEta(params, modelDims, data, rLHyperparams, dtype=np.float64):
elif rLPar["spatialMethod"] == "GPP":
idDst = tf.gather(rLPar["idDg"], AlphaInd)
Fst = tf.gather(rLPar["Fg"], AlphaInd)
if "idDW12g" in rLPar:
idDW12st = tf.gather(rLPar["idDW12g"], AlphaInd)
else: # lowmem
var = rLPar["idDW12st_var"]
m = tf.shape(AlphaInd)[0]
var.assign(tf.zeros(shape=[m, *rLPar["d12"].shape], dtype=dtype))

cond = lambda j: tf.less(j, m)
def body(j):
i = AlphaInd[j]
idDW12 = calculate_idDW12(rLPar["d12"], rLPar["alpha"][i], idDg[i])
set_slice(var, j, idDW12)
return [j + 1, ]

j = tf.constant(0)
tf.while_loop(cond, body, [j])

idDW12st = var.read_value_no_copy()

tf.print('idDW12st', tf.shape(idDW12st), tfm.reduce_min(idDW12st, axis=[1, 2]), tfm.reduce_max(idDW12st, axis=[1, 2]))

idDW12st = gather_idDW12st(rLPar, AlphaInd)
EtaListNew[r] = modelSpatialGPP(LamInvSigLam, mu0, Fst, idDst, idDW12st, rLPar["nK"], npVec[r], nf, dtype)
elif rLPar["spatialMethod"] == "NNGP":
modelSpatialNNGP_local = lambda LamInvSigLam, mu0, Alpha, nf: modelSpatialNNGP_scipy(LamInvSigLam, mu0, Alpha, rLPar["iWList_csr"], npVec[r], nf, dtype)
Expand Down

0 comments on commit f3aea77

Please sign in to comment.