Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update chex.assert_type to check concrete types instead of just asserting that the type is a floating/integer sub-type. #270

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions distrax/_src/distributions/deterministic_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import chex
from distrax._src.distributions import deterministic
from distrax._src.utils import equivalence
import jax.experimental
import jax.numpy as jnp
import numpy as np

Expand Down Expand Up @@ -107,10 +108,11 @@ def test_sample_shape(self, loc, sample_shape):
('float32', jnp.float32),
('float64', jnp.float64))
def test_sample_dtype(self, dtype):
dist = self.distrax_cls(loc=jnp.zeros((), dtype=dtype))
samples = self.variant(dist.sample)(seed=self.key)
self.assertEqual(samples.dtype, dist.dtype)
chex.assert_type(samples, dtype)
with jax.experimental.enable_x64(dtype.dtype.itemsize == 8):
dist = self.distrax_cls(loc=jnp.zeros((), dtype=dtype))
samples = self.variant(dist.sample)(seed=self.key)
self.assertEqual(samples.dtype, dist.dtype)
chex.assert_type(samples, dtype)

@chex.all_variants
@parameterized.named_parameters(
Expand Down
12 changes: 7 additions & 5 deletions distrax/_src/distributions/epsilon_greedy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import chex
from distrax._src.distributions import epsilon_greedy
from distrax._src.utils import equivalence
import jax.experimental
import jax.numpy as jnp
import numpy as np

Expand Down Expand Up @@ -51,11 +52,12 @@ def test_num_categories(self):
('float32', jnp.float32),
('float64', jnp.float64))
def test_sample_dtype(self, dtype):
dist = self.distrax_cls(
preferences=self.preferences, epsilon=self.epsilon, dtype=dtype)
samples = self.variant(dist.sample)(seed=self.key)
self.assertEqual(samples.dtype, dist.dtype)
chex.assert_type(samples, dtype)
with jax.experimental.enable_x64(dtype.dtype.itemsize == 8):
dist = self.distrax_cls(
preferences=self.preferences, epsilon=self.epsilon, dtype=dtype)
samples = self.variant(dist.sample)(seed=self.key)
self.assertEqual(samples.dtype, dist.dtype)
chex.assert_type(samples, dtype)

def test_jittable(self):
super()._test_jittable(
Expand Down
12 changes: 7 additions & 5 deletions distrax/_src/distributions/gamma_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import chex
from distrax._src.distributions import gamma
from distrax._src.utils import equivalence
import jax.experimental
import jax.numpy as jnp
import numpy as np

Expand Down Expand Up @@ -73,11 +74,12 @@ def test_sample_shape(self, distr_params, sample_shape):
('float32', jnp.float32),
('float64', jnp.float64))
def test_sample_dtype(self, dtype):
dist = self.distrax_cls(
concentration=jnp.ones((), dtype), rate=jnp.ones((), dtype))
samples = self.variant(dist.sample)(seed=self.key)
self.assertEqual(samples.dtype, dist.dtype)
chex.assert_type(samples, dtype)
with jax.experimental.enable_x64(dtype.dtype.itemsize == 8):
dist = self.distrax_cls(
concentration=jnp.ones((), dtype), rate=jnp.ones((), dtype))
samples = self.variant(dist.sample)(seed=self.key)
self.assertEqual(samples.dtype, dist.dtype)
chex.assert_type(samples, dtype)

@chex.all_variants
@parameterized.named_parameters(
Expand Down
10 changes: 6 additions & 4 deletions distrax/_src/distributions/greedy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import chex
from distrax._src.distributions import greedy
from distrax._src.utils import equivalence
import jax.experimental
import jax.numpy as jnp
import numpy as np

Expand Down Expand Up @@ -48,10 +49,11 @@ def test_num_categories(self):
('float32', jnp.float32),
('float64', jnp.float64))
def test_sample_dtype(self, dtype):
dist = self.distrax_cls(preferences=self.preferences, dtype=dtype)
samples = self.variant(dist.sample)(seed=self.key)
self.assertEqual(samples.dtype, dist.dtype)
chex.assert_type(samples, dtype)
with jax.experimental.enable_x64(dtype.dtype.itemsize == 8):
dist = self.distrax_cls(preferences=self.preferences, dtype=dtype)
samples = self.variant(dist.sample)(seed=self.key)
self.assertEqual(samples.dtype, dist.dtype)
chex.assert_type(samples, dtype)

def test_jittable(self):
super()._test_jittable((np.array([0., 4., -1., 4.]),))
Expand Down
12 changes: 7 additions & 5 deletions distrax/_src/distributions/gumbel_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import chex
from distrax._src.distributions import gumbel
from distrax._src.utils import equivalence
import jax.experimental
import jax.numpy as jnp
import numpy as np

Expand Down Expand Up @@ -67,11 +68,12 @@ def test_sample_shape(self, distr_params, sample_shape):
('float32', jnp.float32),
('float64', jnp.float64))
def test_sample_dtype(self, dtype):
dist = self.distrax_cls(
loc=jnp.zeros((), dtype), scale=jnp.ones((), dtype))
samples = self.variant(dist.sample)(seed=self.key)
self.assertEqual(samples.dtype, dist.dtype)
chex.assert_type(samples, dtype)
with jax.experimental.enable_x64(dtype.dtype.itemsize == 8):
dist = self.distrax_cls(
loc=jnp.zeros((), dtype), scale=jnp.ones((), dtype))
samples = self.variant(dist.sample)(seed=self.key)
self.assertEqual(samples.dtype, dist.dtype)
chex.assert_type(samples, dtype)

@chex.all_variants
@parameterized.named_parameters(
Expand Down
12 changes: 7 additions & 5 deletions distrax/_src/distributions/laplace_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import chex
from distrax._src.distributions import laplace
from distrax._src.utils import equivalence
import jax.experimental
import jax.numpy as jnp
import numpy as np

Expand Down Expand Up @@ -65,11 +66,12 @@ def test_sample_shape(self, distr_params, sample_shape):
('float32', jnp.float32),
('float64', jnp.float64))
def test_sample_dtype(self, dtype):
dist = self.distrax_cls(
loc=jnp.zeros((), dtype), scale=jnp.ones((), dtype))
samples = self.variant(dist.sample)(seed=self.key)
self.assertEqual(samples.dtype, dist.dtype)
chex.assert_type(samples, dtype)
with jax.experimental.enable_x64(dtype.dtype.itemsize == 8):
dist = self.distrax_cls(
loc=jnp.zeros((), dtype), scale=jnp.ones((), dtype))
samples = self.variant(dist.sample)(seed=self.key)
self.assertEqual(samples.dtype, dist.dtype)
chex.assert_type(samples, dtype)

@chex.all_variants
@parameterized.named_parameters(
Expand Down
12 changes: 7 additions & 5 deletions distrax/_src/distributions/log_stddev_normal_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from distrax._src.distributions import log_stddev_normal as lsn
from distrax._src.distributions import normal
import jax
import jax.experimental
import jax.numpy as jnp
import mock
import numpy as np
Expand Down Expand Up @@ -105,11 +106,12 @@ def test_sampling_batched_custom_dim(self):
('float32', jnp.float32),
('float64', jnp.float64))
def test_sample_dtype(self, dtype):
dist = lsn.LogStddevNormal(
loc=jnp.zeros((), dtype), log_scale=jnp.zeros((), dtype))
samples = self.variant(dist.sample)(seed=jax.random.PRNGKey(0))
self.assertEqual(samples.dtype, dist.dtype)
chex.assert_type(samples, dtype)
with jax.experimental.enable_x64(dtype.dtype.itemsize == 8):
dist = lsn.LogStddevNormal(
loc=jnp.zeros((), dtype), log_scale=jnp.zeros((), dtype))
samples = self.variant(dist.sample)(seed=jax.random.PRNGKey(0))
self.assertEqual(samples.dtype, dist.dtype)
chex.assert_type(samples, dtype)

def test_kl_versus_normal(self):
loc, scale = jnp.array([2.0]), jnp.array([2.0])
Expand Down
12 changes: 7 additions & 5 deletions distrax/_src/distributions/logistic_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import chex
from distrax._src.distributions import logistic
from distrax._src.utils import equivalence
import jax.experimental
import jax.numpy as jnp
import numpy as np

Expand Down Expand Up @@ -66,11 +67,12 @@ def test_sample_shape(self, distr_params, sample_shape):
('float32', jnp.float32),
('float64', jnp.float64))
def test_sample_dtype(self, dtype):
dist = self.distrax_cls(
loc=jnp.zeros((), dtype), scale=jnp.ones((), dtype))
samples = self.variant(dist.sample)(seed=self.key)
self.assertEqual(samples.dtype, dist.dtype)
chex.assert_type(samples, dtype)
with jax.experimental.enable_x64(dtype.dtype.itemsize == 8):
dist = self.distrax_cls(
loc=jnp.zeros((), dtype), scale=jnp.ones((), dtype))
samples = self.variant(dist.sample)(seed=self.key)
self.assertEqual(samples.dtype, dist.dtype)
chex.assert_type(samples, dtype)

@chex.all_variants
@parameterized.named_parameters(
Expand Down
17 changes: 11 additions & 6 deletions distrax/_src/distributions/multinomial_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from distrax._src.utils import equivalence
from distrax._src.utils import math
import jax
import jax.experimental
import jax.numpy as jnp
import numpy as np
from scipy import stats
Expand Down Expand Up @@ -405,12 +406,16 @@ def test_sample_and_log_prob(self, dist_params, sample_shape):
('float32', jnp.float32),
('float64', jnp.float64))
def test_sample_dtype(self, dtype):
dist_params = {
'logits': self.logits, 'dtype': dtype, 'total_count': self.total_count}
dist = self.distrax_cls(**dist_params)
samples = self.variant(dist.sample)(seed=self.key)
self.assertEqual(samples.dtype, dist.dtype)
chex.assert_type(samples, dtype)
with jax.experimental.enable_x64(dtype.dtype.itemsize == 8):
dist_params = {
'logits': self.logits,
'dtype': dtype,
'total_count': self.total_count,
}
dist = self.distrax_cls(**dist_params)
samples = self.variant(dist.sample)(seed=self.key)
self.assertEqual(samples.dtype, dist.dtype)
chex.assert_type(samples, dtype)

@chex.all_variants
def test_sample_extreme_probs(self):
Expand Down
16 changes: 9 additions & 7 deletions distrax/_src/distributions/mvn_diag_plus_low_rank_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from distrax._src.utils import equivalence

import jax
import jax.experimental
import jax.numpy as jnp
import numpy as np
from tensorflow_probability.substrates import jax as tfp
Expand Down Expand Up @@ -180,13 +181,14 @@ def test_sample_shape(self, sample_shape, loc_shape, scale_diag_shape,
('float32', jnp.float32),
('float64', jnp.float64))
def test_sample_dtype(self, dtype):
dist_params = {
'loc': np.array([0., 0.], dtype),
'scale_diag': np.array([1., 1.], dtype)}
dist = MultivariateNormalDiagPlusLowRank(**dist_params)
samples = self.variant(dist.sample)(seed=jax.random.PRNGKey(0))
self.assertEqual(samples.dtype, dist.dtype)
chex.assert_type(samples, dtype)
with jax.experimental.enable_x64(dtype.dtype.itemsize == 8):
dist_params = {
'loc': np.array([0., 0.], dtype),
'scale_diag': np.array([1., 1.], dtype)}
dist = MultivariateNormalDiagPlusLowRank(**dist_params)
samples = self.variant(dist.sample)(seed=jax.random.PRNGKey(0))
self.assertEqual(samples.dtype, dist.dtype)
chex.assert_type(samples, dtype)

@chex.all_variants
@parameterized.named_parameters(
Expand Down
16 changes: 9 additions & 7 deletions distrax/_src/distributions/mvn_diag_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from distrax._src.distributions import normal
from distrax._src.utils import equivalence
import jax
import jax.experimental
import jax.numpy as jnp
import numpy as np

Expand Down Expand Up @@ -214,13 +215,14 @@ def test_sample_and_log_prob(self, distr_params, sample_shape):
('float32', jnp.float32),
('float64', jnp.float64))
def test_sample_dtype(self, dtype):
dist_params = {
'loc': np.array([0., 0.], dtype),
'scale_diag': np.array([1., 1.], dtype)}
dist = self.distrax_cls(**dist_params)
samples = self.variant(dist.sample)(seed=self.key)
self.assertEqual(samples.dtype, dist.dtype)
chex.assert_type(samples, dtype)
with jax.experimental.enable_x64(dtype.dtype.itemsize == 8):
dist_params = {
'loc': np.array([0., 0.], dtype),
'scale_diag': np.array([1., 1.], dtype)}
dist = self.distrax_cls(**dist_params)
samples = self.variant(dist.sample)(seed=self.key)
self.assertEqual(samples.dtype, dist.dtype)
chex.assert_type(samples, dtype)

@chex.all_variants
@parameterized.named_parameters(
Expand Down
16 changes: 9 additions & 7 deletions distrax/_src/distributions/mvn_full_covariance_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import chex
from distrax._src.distributions.mvn_full_covariance import MultivariateNormalFullCovariance
from distrax._src.utils import equivalence
import jax.experimental
import jax.numpy as jnp
import numpy as np

Expand Down Expand Up @@ -106,13 +107,14 @@ def test_sample_shape(self, sample_shape, loc_shape, covariance_matrix_shape):
('float32', jnp.float32),
('float64', jnp.float64))
def test_sample_dtype(self, dtype):
dist_params = {
'loc': np.array([0., 0.], dtype),
'covariance_matrix': np.array([[1., 0.], [0., 1.]], dtype)}
dist = self.distrax_cls(**dist_params)
samples = self.variant(dist.sample)(seed=self.key)
self.assertEqual(samples.dtype, dist.dtype)
chex.assert_type(samples, dtype)
with jax.experimental.enable_x64(dtype.dtype.itemsize == 8):
dist_params = {
'loc': np.array([0., 0.], dtype),
'covariance_matrix': np.array([[1., 0.], [0., 1.]], dtype)}
dist = self.distrax_cls(**dist_params)
samples = self.variant(dist.sample)(seed=self.key)
self.assertEqual(samples.dtype, dist.dtype)
chex.assert_type(samples, dtype)

@chex.all_variants
@parameterized.named_parameters(
Expand Down
16 changes: 9 additions & 7 deletions distrax/_src/distributions/mvn_tri_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import chex
from distrax._src.distributions.mvn_tri import MultivariateNormalTri
from distrax._src.utils import equivalence
import jax.experimental
import jax.numpy as jnp
import numpy as np

Expand Down Expand Up @@ -114,13 +115,14 @@ def test_sample_shape(self, sample_shape, loc_shape, scale_tri_shape):
('float32', jnp.float32),
('float64', jnp.float64))
def test_sample_dtype(self, dtype):
dist_params = {
'loc': np.array([0., 0.], dtype),
'scale_tri': np.array([[1., 0.], [0., 1.]], dtype)}
dist = self.distrax_cls(**dist_params)
samples = self.variant(dist.sample)(seed=self.key)
self.assertEqual(samples.dtype, dist.dtype)
chex.assert_type(samples, dtype)
with jax.experimental.enable_x64(dtype.dtype.itemsize == 8):
dist_params = {
'loc': np.array([0., 0.], dtype),
'scale_tri': np.array([[1., 0.], [0., 1.]], dtype)}
dist = self.distrax_cls(**dist_params)
samples = self.variant(dist.sample)(seed=self.key)
self.assertEqual(samples.dtype, dist.dtype)
chex.assert_type(samples, dtype)

@chex.all_variants
@parameterized.named_parameters(
Expand Down
12 changes: 7 additions & 5 deletions distrax/_src/distributions/one_hot_categorical_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from distrax._src.utils import equivalence
from distrax._src.utils import math
import jax
import jax.experimental
import jax.numpy as jnp
import numpy as np
import scipy
Expand Down Expand Up @@ -178,11 +179,12 @@ def test_sample_and_log_prob(self, distr_params, sample_shape):
('float32', jnp.float32),
('float64', jnp.float64))
def test_sample_dtype(self, dtype):
dist_params = {'logits': self.logits, 'dtype': dtype}
dist = self.distrax_cls(**dist_params)
samples = self.variant(dist.sample)(seed=self.key)
self.assertEqual(samples.dtype, dist.dtype)
chex.assert_type(samples, dtype)
with jax.experimental.enable_x64(dtype.dtype.itemsize == 8):
dist_params = {'logits': self.logits, 'dtype': dtype}
dist = self.distrax_cls(**dist_params)
samples = self.variant(dist.sample)(seed=self.key)
self.assertEqual(samples.dtype, dist.dtype)
chex.assert_type(samples, dtype)

@chex.all_variants
@parameterized.named_parameters(
Expand Down
11 changes: 6 additions & 5 deletions distrax/_src/distributions/softmax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,12 @@ def test_parameters(self):
('float32', jnp.float32),
('float64', jnp.float64))
def test_sample_dtype(self, dtype):
dist = self.distrax_cls(
logits=self.logits, temperature=self.temperature, dtype=dtype)
samples = self.variant(dist.sample)(seed=self.key)
self.assertEqual(samples.dtype, dist.dtype)
chex.assert_type(samples, dtype)
with jax.experimental.enable_x64(dtype.dtype.itemsize == 8):
dist = self.distrax_cls(
logits=self.logits, temperature=self.temperature, dtype=dtype)
samples = self.variant(dist.sample)(seed=self.key)
self.assertEqual(samples.dtype, dist.dtype)
chex.assert_type(samples, dtype)

def test_jittable(self):
super()._test_jittable((np.array([2., 4., 1., 3.]),))
Expand Down
Loading