Skip to content

Commit

Permalink
Update pf_multiSample.py
Browse files Browse the repository at this point in the history
  • Loading branch information
yiwang12 authored Oct 9, 2024
1 parent 73fa25f commit 933de68
Showing 1 changed file with 6 additions and 10 deletions.
16 changes: 6 additions & 10 deletions mNSF/pf_multiSample.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,19 +250,15 @@ def sample_latent_GP_funcs(self, X, S=1, kernel=None, mu_z=None, Kuu_chol=None,
mu_z = self.get_mu_z()
if Kuu_chol is None:
Kuu_chol = self.get_Kuu_chol(kernel=kernel, from_cache=(not chol))
if chol:
alpha_x = self.alpha_x
if (not chol):
N = X.shape[0]
L = self.W.shape[1]
mu_x = self.beta0+tfl.matmul(self.beta, X, transpose_b=True) #LxN
Kuf = self.Kuf
Kff_diag = self.Kff_diag
mu_tilde = mu_x + tfl.matvec(alpha_x, self.delta-mu_z, transpose_a=True) #LxN
a_t_Kchol = self.a_t_Kchol
aKa = tf.reduce_sum(tf.square(a_t_Kchol), axis=2) #LxN
aOmega_a=self.aOmega_a
Sigma_tilde = Kff_diag - aKa + aOmega_a #LxN
if (not chol):
mu_tilde = mu_x + tfl.matvec(self.alpha_x, self.delta-mu_z, transpose_a=True) #LxN
#a_t_Kchol = self.a_t_Kchol
#aKa = tf.reduce_sum(tf.square(a_t_Kchol), axis=2) #LxN
Sigma_tilde = self.Sigma_tilde #LxN
if chol:
alpha_x = tfl.cholesky_solve(Kuu_chol, Kuf) #LxMxN
N = X.shape[0]
L = self.W.shape[1]
Expand Down

0 comments on commit 933de68

Please sign in to comment.