diff --git a/mNSF/pf_multiSample.py b/mNSF/pf_multiSample.py index efe8e3a..f506c78 100644 --- a/mNSF/pf_multiSample.py +++ b/mNSF/pf_multiSample.py @@ -250,19 +250,15 @@ def sample_latent_GP_funcs(self, X, S=1, kernel=None, mu_z=None, Kuu_chol=None, mu_z = self.get_mu_z() if Kuu_chol is None: Kuu_chol = self.get_Kuu_chol(kernel=kernel, from_cache=(not chol)) - if chol: - alpha_x = self.alpha_x + if (not chol): N = X.shape[0] L = self.W.shape[1] mu_x = self.beta0+tfl.matmul(self.beta, X, transpose_b=True) #LxN - Kuf = self.Kuf - Kff_diag = self.Kff_diag - mu_tilde = mu_x + tfl.matvec(alpha_x, self.delta-mu_z, transpose_a=True) #LxN - a_t_Kchol = self.a_t_Kchol - aKa = tf.reduce_sum(tf.square(a_t_Kchol), axis=2) #LxN - aOmega_a=self.aOmega_a - Sigma_tilde = Kff_diag - aKa + aOmega_a #LxN - if (not chol): + mu_tilde = mu_x + tfl.matvec(self.alpha_x, self.delta-mu_z, transpose_a=True) #LxN + #a_t_Kchol = self.a_t_Kchol + #aKa = tf.reduce_sum(tf.square(a_t_Kchol), axis=2) #LxN + Sigma_tilde = self.Sigma_tilde #LxN + if chol: alpha_x = tfl.cholesky_solve(Kuu_chol, Kuf) #LxMxN N = X.shape[0] L = self.W.shape[1]