Skip to content

Commit

Permalink
Fix tests for random.categorical with multi-dimensional logits (#2955)
Browse files Browse the repository at this point in the history
  • Loading branch information
juliuskunze authored May 5, 2020
1 parent 7116cc5 commit e4d8cac
Showing 1 changed file with 10 additions and 11 deletions.
21 changes: 10 additions & 11 deletions tests/random_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,34 +246,33 @@ def testBernoulli(self, p, dtype):
for (p, axis) in [
([.25] * 4, -1),
([.1, .2, .3, .4], -1),
([[.25, .25], [.1, .9]], 1),
([[.25, .1], [.25, .9]], 0),
([[.5, .5], [.1, .9]], 1),
([[.5, .1], [.5, .9]], 0),
]
for sample_shape in [(10000,), (5000, 2)]
for dtype in [onp.float32, onp.float64]))
def testCategorical(self, p, axis, dtype, sample_shape):
key = random.PRNGKey(0)
p = onp.array(p, dtype=dtype)
logits = onp.log(p) - 42 # test unnormalized
shape = sample_shape + tuple(onp.delete(logits.shape, axis))
out_shape = tuple(onp.delete(logits.shape, axis))
shape = sample_shape + out_shape
rand = lambda key, p: random.categorical(key, logits, shape=shape, axis=axis)
crand = api.jit(rand)

uncompiled_samples = rand(key, p)
compiled_samples = crand(key, p)

if p.ndim > 1:
self.skipTest("multi-dimensional categorical tests are currently broken!")
if axis < 0:
axis += len(logits.shape)

for samples in [uncompiled_samples, compiled_samples]:
if axis < 0:
axis += len(logits.shape)

assert samples.shape == shape

samples = np.reshape(samples, (10000,) + out_shape)
if len(p.shape[:-1]) > 0:
for cat_index, p_ in enumerate(p):
self._CheckChiSquared(samples[:, cat_index], pmf=lambda x: p_[x])
ps = onp.transpose(p, (1, 0)) if axis == 0 else p
for cat_samples, cat_p in zip(samples.transpose(), ps):
self._CheckChiSquared(cat_samples, pmf=lambda x: cat_p[x])
else:
self._CheckChiSquared(samples, pmf=lambda x: p[x])

Expand Down

0 comments on commit e4d8cac

Please sign in to comment.