Update pf.py
yiwang12 authored Oct 9, 2024
1 parent b901d4f commit e167e95
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions mNSF/NSF/pf.py
@@ -25,6 +25,12 @@
 dtp = "float32"
 rng = np.random.default_rng()
 
+def checkpoint_grad(x):
+  y = tf.identity(x)
+  def grad(dy):
+    return tf.gradients(y, x, dy)[0]
+  return y, grad
+
 class ProcessFactorization(tf.Module):
   def __init__(self, J, L, Z, lik="poi", chol = True, X=None, psd_kernel=tfk.MaternThreeHalves,
                nugget=1e-5, length_scale=0.1, disp="default",
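
As committed, checkpoint_grad returns a (value, gradient-function) pair, which is TensorFlow's tf.custom_gradient convention; without that decorator every caller below would receive a tuple rather than a tensor. A minimal self-contained sketch of the presumably intended wrapper (the decorator and the comments are editorial assumptions, not part of the commit):

import tensorflow as tf

@tf.custom_gradient  # assumed: required for the (y, grad) return pattern to work
def checkpoint_grad(x):
  y = tf.identity(x)  # forward pass: the wrapped value is returned unchanged
  def grad(dy):
    # backward pass: route the incoming cotangent dy through the identity node.
    # tf.gradients is graph-mode only, so this helper is usable inside
    # tf.function-compiled code, not in eager execution.
    return tf.gradients(y, x, dy)[0]
  return y, grad

Because the forward pass is an identity, wrapping an expression changes none of the values computed below; it only changes how the backward graph is stitched together at this point.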
@@ -238,7 +244,8 @@ def sample_latent_GP_funcs(self, X, S=1, kernel=None, mu_z=None, Kuu_chol=None,
     Kuf = kernel.matrix(self.Z, X) #LxMxN
     Kff_diag = kernel.apply(X, X, example_ndims=1)+self.nugget #LxN
     alpha_x = tfl.cholesky_solve(Kuu_chol, Kuf) #LxMxN
-    mu_tilde = mu_x + tfl.matvec(alpha_x, self.delta-mu_z, transpose_a=True) #LxN
+    #mu_tilde = mu_x + tfl.matvec(alpha_x, self.delta-mu_z, transpose_a=True) #LxN
+    mu_tilde = checkpoint_grad(mu_x + tfl.matvec(alpha_x, self.delta-mu_z, transpose_a=True)) #LxN
     #compute the alpha(x_i)'(K_uu-Omega)alpha(x_i) term
     a_t_Kchol = tfl.matmul(alpha_x, Kuu_chol, transpose_a=True) #LxNxM
     aKa = tf.reduce_sum(tf.square(a_t_Kchol), axis=2) #LxN
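
Reading off this hunk: alpha_x solves K_uu a = K_uf against the Cholesky factor, so the wrapped expression is the usual inducing-point posterior mean of each latent GP (per factor l, at the N prediction points),

  \tilde{\mu}(X) = \mu(X) + K_{uf}^{\top} K_{uu}^{-1} (\delta - \mu_z)

where delta holds the variational means at the inducing points Z. Since checkpoint_grad is an identity in the forward pass, mu_tilde itself is unchanged by this commit; only its gradient path differs. (The notation is editorial, inferred from the code and its shape comments.)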
@@ -254,8 +261,8 @@ def sample_predictive_mean(self, X, sz=1, S=1, kernel=None, mu_z=None, Kuu_chol=
     sz is a tensor of shape (N,1) of size factors.
     Typically sz would be the rowSums or rowMeans of the outcome matrix Y.
     """
-    F = self.sample_latent_GP_funcs(X, S=S, kernel=kernel, mu_z=mu_z,
-                                    Kuu_chol=Kuu_chol, chol=chol) #SxLxN
+    F = checkpoint_grad(self.sample_latent_GP_funcs(X, S=S, kernel=kernel, mu_z=mu_z,
+                                                    Kuu_chol=Kuu_chol, chol=chol)) #SxLxN
     if self.nonneg:
       Lam = tfl.matrix_transpose(tfl.matmul(self.W, tf.exp(F))) #SxNxJ
     if self.lik=="gau":
@@ -320,7 +327,7 @@ def elbo_avg(self, X, Y, sz=1, S=1, Ntot=None, chol=True):
     #kl_terms is not affected by minibatching so use reduce_sum
     #print(1111)
     kl_term = tf.reduce_sum(self.eval_kl_term(mu_z, Kuu_chol))
-    Mu = self.sample_predictive_mean(X, sz=sz, S=S, kernel=ker, mu_z=mu_z, Kuu_chol=Kuu_chol)
+    Mu = checkpoint_grad(self.sample_predictive_mean(X, sz=sz, S=S, kernel=ker, mu_z=mu_z, Kuu_chol=Kuu_chol))
     eloglik = likelihoods.lik_to_distr(self.lik, Mu, self.disp).log_prob(Y)
     return J*tf.reduce_mean(eloglik) - kl_term/Ntot
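
In the return line, J*tf.reduce_mean(eloglik) is the per-observation expected log-likelihood summed over the J genes, and dividing the KL term by Ntot spreads the inducing-point penalty over all observations, so evaluating on a random minibatch of X and Y yields an unbiased estimate of ELBO/Ntot. A hypothetical end-to-end training step, sketched only from the signatures visible in this diff (the toy data, shapes, and the assumption that ProcessFactorization is usable with its constructor defaults are all editorial):

import numpy as np
import tensorflow as tf
from mNSF.NSF import pf

N, J, L, M, D = 500, 200, 4, 50, 2  # spots, genes, factors, inducing points, spatial dims
X = np.random.rand(N, D).astype("float32")            # spatial coordinates
Y = np.random.poisson(1.0, (N, J)).astype("float32")  # spot-by-gene count matrix
Z = X[np.random.choice(N, M, replace=False)]          # inducing point locations
sz = Y.sum(axis=1, keepdims=True)                     # size factors: rowSums of Y, shape (N,1)

model = pf.ProcessFactorization(J, L, Z)  # Poisson likelihood by default
opt = tf.keras.optimizers.Adam(1e-2)

@tf.function  # graph mode, so tf.gradients inside checkpoint_grad is legal
def train_step(X, Y, sz):
  with tf.GradientTape() as tape:
    loss = -model.elbo_avg(X, Y, sz=sz, S=1, Ntot=N)
  grads = tape.gradient(loss, model.trainable_variables)
  # skip variables with no gradient (e.g. unused dispersion parameters)
  opt.apply_gradients((g, v) for g, v in zip(grads, model.trainable_variables)
                      if g is not None)
  return loss

The tf.function wrapper is not optional here: checkpoint_grad's backward pass calls tf.gradients, which raises in eager mode, so the ELBO must be differentiated inside a compiled graph.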

