Skip to content

Commit

Permalink
Improve numerical stability of VonMises Sample / CDF gradients by cha…
Browse files Browse the repository at this point in the history
…nging (1 - cos(x)) calls to 2 * sin(x / 2)**2.

PiperOrigin-RevId: 471175247
  • Loading branch information
srvasude authored and tensorflower-gardener committed Aug 31, 2022
1 parent 6e57a53 commit 02e3297
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,6 @@
'SigmoidBeta': 5e-4,
'StudentT': 1e-5,
'TruncatedCauchy': 5e-5,
'VonMises': 2e-2, # TODO(b/160000258):
'VonMisesFisher': 5e-3,
'WishartTriL': 1e-5,
})
Expand Down
12 changes: 8 additions & 4 deletions tensorflow_probability/python/distributions/von_mises.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@
__all__ = ['VonMises']


def cosxm1(x):
return -2 * tf.math.square(tf.math.sin(x / 2.))


class VonMises(distribution.AutoCompositeTensorDistribution):
"""The von Mises distribution over angles.
Expand Down Expand Up @@ -182,7 +186,7 @@ def _log_prob(self, x):

def _log_unnormalized_prob(self, x, loc, concentration):
z = self._z(x, loc=loc)
return concentration * (tf.cos(z) - 1)
return concentration * cosxm1(z)

def _prob(self, x):
concentration = tf.convert_to_tensor(self.concentration)
Expand Down Expand Up @@ -401,7 +405,7 @@ def von_mises_cdf(x, concentration):
dcdf_dconcentration_normal)

def grad(dy):
prob = tf.exp(concentration * (tf.cos(x) - 1.)) / (
prob = tf.exp(concentration * cosxm1(x)) / (
(2. * np.pi) * tf.math.bessel_i0e(concentration))
return dy * prob, dy * dcdf_dconcentration

Expand Down Expand Up @@ -584,7 +588,7 @@ def _von_mises_sample_bwd(_, aux, dy):
broadcast_concentration = tf.broadcast_to(concentration, ps.shape(samples))
_, dcdf_dconcentration = value_and_gradient(
lambda conc: von_mises_cdf(samples, conc), broadcast_concentration)
inv_prob = tf.exp(-broadcast_concentration * (tf.cos(samples) - 1.)) * (
inv_prob = tf.exp(-broadcast_concentration * cosxm1(samples)) * (
(2. * np.pi) * tf.math.bessel_i0e(broadcast_concentration))
# Compute the implicit reparameterization gradient [2],
# dz/dconc = -(dF(z; conc) / dconc) / p(z; conc)
Expand All @@ -611,7 +615,7 @@ def _von_mises_sample_jvp(shape, primals, tangents):

_, dcdf_dconcentration = value_and_gradient(
lambda conc: von_mises_cdf(samples, conc), broadcast_concentration)
inv_prob = tf.exp(-concentration * (tf.cos(samples) - 1.)) * (
inv_prob = tf.exp(-concentration * cosxm1(samples)) * (
(2. * np.pi) * tf.math.bessel_i0e(concentration))
# Compute the implicit derivative,
# dz = dconc * -(dF(z; conc) / dconc) / p(z; conc)
Expand Down
16 changes: 10 additions & 6 deletions tensorflow_probability/python/distributions/von_mises_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ def testVonMisesPdfUniform(self):
expected_prob = np.array([1. / (2. * np.pi)] * 6)
self.assertAllClose(expected_prob, self.evaluate(prob))

@test_util.disable_test_for_backend(
disable_numpy=True, reason='CDF computation uses autograd')
def testVonMisesCdf(self):
locs_v = np.reshape(np.linspace(-10., 10., 20), [-1, 1, 1])
concentrations_v = np.reshape(np.logspace(-3., 3., 20), [1, -1, 1])
Expand All @@ -126,6 +128,8 @@ def testVonMisesCdf(self):
expected_cdf = sp_stats.vonmises.cdf(x, concentrations_v, loc=locs_v)
self.assertAllClose(expected_cdf, self.evaluate(cdf), atol=1e-4, rtol=1e-4)

@test_util.disable_test_for_backend(
disable_numpy=True, reason='CDF computation uses autograd')
def testVonMisesCdfUniform(self):
x = np.linspace(-np.pi, np.pi, 20)
dist = von_mises.VonMises(
Expand All @@ -151,13 +155,13 @@ def testVonMisesCdfGradient(self):
cdf = dist.cdf(x)

self.assertLess(
tf1.test.compute_gradient_error(x, x.shape, cdf, cdf.shape), 1e-3)
tf1.test.compute_gradient_error(x, x.shape, cdf, cdf.shape), 1e-4)
self.assertLess(
tf1.test.compute_gradient_error(locs, locs.shape, cdf, cdf.shape),
1e-3)
1e-4)
self.assertLess(
tf1.test.compute_gradient_error(concentrations, concentrations.shape,
cdf, cdf.shape), 1e-3)
cdf, cdf.shape), 1e-4)

@test_util.numpy_disable_gradient_test
def testVonMisesCdfGradientSimple(self):
Expand Down Expand Up @@ -185,10 +189,10 @@ def testVonMisesCdfGradientSimple(self):
- von_mises.VonMises(loc, concentration,
validate_args=True).cdf(x - eps)) / (2 * eps))

self.assertAlmostEqual(dcdf_dloc, dcdf_dloc_diff, places=3)
self.assertAlmostEqual(dcdf_dloc, dcdf_dloc_diff, places=4)
self.assertAlmostEqual(
dcdf_dconcentration, dcdf_dconcentration_diff, places=3)
self.assertAlmostEqual(dcdf_dx, dcdf_dx_diff, places=3)
dcdf_dconcentration, dcdf_dconcentration_diff, places=4)
self.assertAlmostEqual(dcdf_dx, dcdf_dx_diff, places=4)

def testVonMisesEntropy(self):
locs_v = np.array([-2., -1., 0.3, 3.2]).reshape([-1, 1])
Expand Down

0 comments on commit 02e3297

Please sign in to comment.