Skip to content

Commit

Permalink
Update pairwise_square_distance_matrix(x, x) to always have zero di…
Browse files Browse the repository at this point in the history
…agonals.

PiperOrigin-RevId: 471394269
  • Loading branch information
jburnim authored and tensorflower-gardener committed Sep 1, 2022
1 parent 3d94661 commit de75e2d
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from tensorflow_probability.python.distributions import normal
from tensorflow_probability.python.internal import tensorshape_util
from tensorflow_probability.python.internal import test_util
from tensorflow_probability.python.math import gradient
from tensorflow_probability.python.math import psd_kernels


Expand Down Expand Up @@ -450,6 +451,31 @@ def call_log_prob(d):
class GaussianProcessStaticTest(_GaussianProcessTest, test_util.TestCase):
is_static = True

@test_util.numpy_disable_gradient_test
def test_gradient(self):
x_obs = normal.Normal(0., 1.).sample((10, 6), seed=test_util.test_seed())
y_obs = tf.reduce_sum(x_obs, axis=-1)

def loss(length_scales):
kernel = psd_kernels.MaternFiveHalves(amplitude=tf.math.sqrt(1e-2))
kernel = psd_kernels.FeatureScaled(
kernel, scale_diag=tf.math.sqrt(length_scales))
return gaussian_process.GaussianProcess(
kernel,
index_points=x_obs,
observation_noise_variance=1.,
).log_prob(y_obs)

lscales = tf.convert_to_tensor([11.67385626, 0.21246016, 0.0215677,
0.08823962, 0.22416186, 0.06885594])

def _grad(lscales):
return gradient.value_and_gradient(loss, lscales)[1]

self.assertAllClose(_grad(lscales),
tf.function(_grad, jit_compile=True)(lscales),
atol=0.01)


@test_util.test_all_tf_execution_regimes
class GaussianProcessDynamicTest(_GaussianProcessTest, test_util.TestCase):
Expand Down
22 changes: 19 additions & 3 deletions tensorflow_probability/python/math/psd_kernels/internal/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,15 +224,31 @@ def pairwise_square_distance_matrix(x1, x2, feature_ndims):
row_norm_x2 = sum_rightmost_ndims_preserving_shape(
tf.square(x2), feature_ndims)[..., tf.newaxis, :]

x1 = tf.reshape(x1, ps.concat(
reshaped_x1 = tf.reshape(x1, ps.concat(
[ps.shape(x1)[:-feature_ndims], [
ps.reduce_prod(ps.shape(x1)[-feature_ndims:])]], axis=0))
x2 = tf.reshape(x2, ps.concat(
reshaped_x2 = tf.reshape(x2, ps.concat(
[ps.shape(x2)[:-feature_ndims], [
ps.reduce_prod(ps.shape(x2)[-feature_ndims:])]], axis=0))
pairwise_sq = row_norm_x1 + row_norm_x2 - 2 * tf.linalg.matmul(
x1, x2, transpose_b=True)
reshaped_x1, reshaped_x2, transpose_b=True)
pairwise_sq = tf.clip_by_value(pairwise_sq, 0., np.inf)

# If we statically know that `x1` and `x2` have the same number of examples,
# then we check if they are equal so that we can ensure that the diagonal
# distances are zero in this case.
num_examples1 = tf.compat.dimension_value(x1.shape[-feature_ndims - 1])
num_examples2 = tf.compat.dimension_value(x2.shape[-feature_ndims - 1])
if num_examples1 is not None and num_examples2 is not None:
if num_examples1 == num_examples2:
all_equal = tf.reduce_all(
tf.equal(x1, x2), axis=range(-1, -feature_ndims - 2, -1))
eye = tf.eye(num_examples1, dtype=pairwise_sq.dtype)
pairwise_sq = tf.where(
all_equal[..., tf.newaxis, tf.newaxis] & (eye == 1.),
tf.zeros([], dtype=pairwise_sq.dtype),
pairwise_sq)

return pairwise_sq


Expand Down

0 comments on commit de75e2d

Please sign in to comment.