diff --git a/tensorflow_probability/python/distributions/gaussian_process_test.py b/tensorflow_probability/python/distributions/gaussian_process_test.py
index e9c9e5805a..e43585f95f 100644
--- a/tensorflow_probability/python/distributions/gaussian_process_test.py
+++ b/tensorflow_probability/python/distributions/gaussian_process_test.py
@@ -25,6 +25,7 @@
 from tensorflow_probability.python.distributions import normal
 from tensorflow_probability.python.internal import tensorshape_util
 from tensorflow_probability.python.internal import test_util
+from tensorflow_probability.python.math import gradient
 from tensorflow_probability.python.math import psd_kernels
 
 
@@ -450,6 +451,31 @@ def call_log_prob(d):
 class GaussianProcessStaticTest(_GaussianProcessTest, test_util.TestCase):
   is_static = True
 
+  @test_util.numpy_disable_gradient_test
+  def test_gradient(self):
+    x_obs = normal.Normal(0., 1.).sample((10, 6), seed=test_util.test_seed())
+    y_obs = tf.reduce_sum(x_obs, axis=-1)
+
+    def loss(length_scales):
+      kernel = psd_kernels.MaternFiveHalves(amplitude=tf.math.sqrt(1e-2))
+      kernel = psd_kernels.FeatureScaled(
+          kernel, scale_diag=tf.math.sqrt(length_scales))
+      return gaussian_process.GaussianProcess(
+          kernel,
+          index_points=x_obs,
+          observation_noise_variance=1.,
+      ).log_prob(y_obs)
+
+    lscales = tf.convert_to_tensor([11.67385626, 0.21246016, 0.0215677,
+                                    0.08823962, 0.22416186, 0.06885594])
+
+    def _grad(lscales):
+      return gradient.value_and_gradient(loss, lscales)[1]
+
+    self.assertAllClose(_grad(lscales),
+                        tf.function(_grad, jit_compile=True)(lscales),
+                        atol=0.01)
+
 
 @test_util.test_all_tf_execution_regimes
 class GaussianProcessDynamicTest(_GaussianProcessTest, test_util.TestCase):
diff --git a/tensorflow_probability/python/math/psd_kernels/internal/util.py b/tensorflow_probability/python/math/psd_kernels/internal/util.py
index 039be12489..ec4bd2809c 100644
--- a/tensorflow_probability/python/math/psd_kernels/internal/util.py
+++ b/tensorflow_probability/python/math/psd_kernels/internal/util.py
@@ -224,15 +224,31 @@ def pairwise_square_distance_matrix(x1, x2, feature_ndims):
   row_norm_x2 = sum_rightmost_ndims_preserving_shape(
       tf.square(x2), feature_ndims)[..., tf.newaxis, :]
 
-  x1 = tf.reshape(x1, ps.concat(
+  reshaped_x1 = tf.reshape(x1, ps.concat(
       [ps.shape(x1)[:-feature_ndims], [
           ps.reduce_prod(ps.shape(x1)[-feature_ndims:])]], axis=0))
-  x2 = tf.reshape(x2, ps.concat(
+  reshaped_x2 = tf.reshape(x2, ps.concat(
       [ps.shape(x2)[:-feature_ndims], [
           ps.reduce_prod(ps.shape(x2)[-feature_ndims:])]], axis=0))
   pairwise_sq = row_norm_x1 + row_norm_x2 - 2 * tf.linalg.matmul(
-      x1, x2, transpose_b=True)
+      reshaped_x1, reshaped_x2, transpose_b=True)
   pairwise_sq = tf.clip_by_value(pairwise_sq, 0., np.inf)
+
+  # If we statically know that `x1` and `x2` have the same number of examples,
+  # then we check if they are equal so that we can ensure that the diagonal
+  # distances are zero in this case.
+  num_examples1 = tf.compat.dimension_value(x1.shape[-feature_ndims - 1])
+  num_examples2 = tf.compat.dimension_value(x2.shape[-feature_ndims - 1])
+  if num_examples1 is not None and num_examples2 is not None:
+    if num_examples1 == num_examples2:
+      all_equal = tf.reduce_all(
+          tf.equal(x1, x2), axis=range(-1, -feature_ndims - 2, -1))
+      eye = tf.eye(num_examples1, dtype=pairwise_sq.dtype)
+      pairwise_sq = tf.where(
+          all_equal[..., tf.newaxis, tf.newaxis] & (eye == 1.),
+          tf.zeros([], dtype=pairwise_sq.dtype),
+          pairwise_sq)
+
   return pairwise_sq
 
 