diff --git a/discussion/adaptive_malt/adaptive_malt.py b/discussion/adaptive_malt/adaptive_malt.py index a6edb0b0e9..3952b04d09 100644 --- a/discussion/adaptive_malt/adaptive_malt.py +++ b/discussion/adaptive_malt/adaptive_malt.py @@ -350,7 +350,7 @@ def adaptive_mcmc_step( target_log_prob_fn: fun_mc.PotentialFn, num_mala_steps: int, num_adaptation_steps: int, - seed: jax.random.KeyArray, + seed: jax.Array, method: str = 'hmc', damping: Optional[jnp.ndarray] = None, scalar_step_size: Optional[jnp.ndarray] = None, @@ -778,7 +778,7 @@ def adaptive_nuts_step( target_log_prob_fn: fun_mc.PotentialFn, num_mala_steps: int, num_adaptation_steps: int, - seed: jax.random.KeyArray, + seed: jax.Array, scalar_step_size: Optional[jnp.ndarray] = None, vector_step_size: Optional[jnp.ndarray] = None, rvar_factor: int = 8, @@ -1040,7 +1040,7 @@ class MeadsExtra(NamedTuple): def meads_init(state: jnp.ndarray, target_log_prob_fn: fun_mc.PotentialFn, - num_folds: int, seed: jax.random.KeyArray): + num_folds: int, seed: jax.Array): """Initializes MEADS.""" num_dimensions = state.shape[-1] num_chains = state.shape[0] @@ -1062,7 +1062,7 @@ def meads_init(state: jnp.ndarray, target_log_prob_fn: fun_mc.PotentialFn, def meads_step(meads_state: MeadsState, target_log_prob_fn: fun_mc.PotentialFn, - seed: jax.random.KeyArray, + seed: jax.Array, vector_step_size: Optional[jnp.ndarray] = None, damping: Optional[jnp.ndarray] = None, step_size_multiplier: float = 0.5, @@ -1221,7 +1221,7 @@ def run_adaptive_mcmc_on_target( init_step_size: jnp.ndarray, num_adaptation_steps: int, num_results: int, - seed: jax.random.KeyArray, + seed: jax.Array, num_mala_steps: int = 100, rvar_smoothing: int = 0, trajectory_opt_kwargs: Mapping[str, Any] = immutabledict.immutabledict({ @@ -1358,7 +1358,7 @@ def run_adaptive_nuts_on_target( init_step_size: jnp.ndarray, num_adaptation_steps: int, num_results: int, - seed: jax.random.KeyArray, + seed: jax.Array, num_mala_steps: int = 100, rvar_smoothing: int = 0, num_chains: Optional[int] = None, @@ -1478,7 +1478,7 @@ def run_meads_on_target( num_adaptation_steps: int, num_results: int, thinning: int, - seed: jax.random.KeyArray, + seed: jax.Array, num_folds: int, num_chains: Optional[int] = None, init_x: Optional[jnp.ndarray] = None, @@ -1596,7 +1596,7 @@ def run_fixed_mcmc_on_target( target: gym.targets.Model, init_x: jnp.ndarray, method: str, - seed: jax.random.KeyArray, + seed: jax.Array, num_warmup_steps: int, num_results: int, scalar_step_size: jnp.ndarray, @@ -1706,7 +1706,7 @@ def run_vi_on_target( init_x: jnp.ndarray, num_steps: int, learning_rate: float, - seed: jax.random.KeyArray, + seed: jax.Array, ): """Run VI on a target. diff --git a/spinoffs/fun_mc/fun_mc/dynamic/backend_jax/util.py b/spinoffs/fun_mc/fun_mc/dynamic/backend_jax/util.py index f017b19f86..bb0805c4d5 100644 --- a/spinoffs/fun_mc/fun_mc/dynamic/backend_jax/util.py +++ b/spinoffs/fun_mc/fun_mc/dynamic/backend_jax/util.py @@ -97,7 +97,9 @@ def make_tensor_seed(seed): """Converts a seed to a `Tensor` seed.""" if seed is None: raise ValueError('seed must not be None when using JAX') - if isinstance(seed, jax.random.PRNGKeyArray): + if hasattr(seed, 'dtype') and jax.dtypes.issubdtype( + seed.dtype, jax.dtypes.prng_key + ): return seed return jnp.asarray(seed, jnp.uint32) diff --git a/tensorflow_probability/python/internal/backend/numpy/ops.py b/tensorflow_probability/python/internal/backend/numpy/ops.py index 0cb8cc9ddb..2d396f5ef4 100644 --- a/tensorflow_probability/python/internal/backend/numpy/ops.py +++ b/tensorflow_probability/python/internal/backend/numpy/ops.py @@ -218,10 +218,14 @@ def _default_convert_to_tensor(value, dtype=None): """Default tensor conversion function for array, bool, int, float, and complex.""" if JAX_MODE: # TODO(b/223267515): We shouldn't need to specialize here. - if 'PRNGKeyArray' in str(type(value)): + if hasattr(value, 'dtype') and jax.dtypes.issubdtype( + value.dtype, jax.dtypes.prng_key + ): return value if isinstance(value, (list, tuple)) and value: - if 'PRNGKeyArray' in str(type(value[0])): + if hasattr(value[0], 'dtype') and jax.dtypes.issubdtype( + value[0].dtype, jax.dtypes.prng_key + ): return np.stack(value, axis=0) inferred_dtype = _infer_dtype(value, np.float32) diff --git a/tensorflow_probability/python/internal/loop_util.py b/tensorflow_probability/python/internal/loop_util.py index 4eaa41ae0b..f7272c62b1 100644 --- a/tensorflow_probability/python/internal/loop_util.py +++ b/tensorflow_probability/python/internal/loop_util.py @@ -52,8 +52,8 @@ def _convert_variables_to_tensors(values): def tensor_array_from_element(elem, size=None, **kwargs): """Construct a tf.TensorArray of elements with the dtype + shape of `elem`.""" - if JAX_MODE and isinstance(elem, jax.random.PRNGKeyArray): - # If `trace_elt` is a `PRNGKeyArray`, then then it is not possible to create + if JAX_MODE and jax.dtypes.issubdtype(elem.dtype, jax.dtypes.prng_key): + # If `trace_elt` is a typed prng key, then then it is not possible to create # a matching (i.e., with the same custom PRNG) instance/array inside # `TensorArray.__init__` given just a `dtype`, `size`, and `shape`. # diff --git a/tensorflow_probability/python/internal/test_util.py b/tensorflow_probability/python/internal/test_util.py index 9af934ed5f..1058376766 100644 --- a/tensorflow_probability/python/internal/test_util.py +++ b/tensorflow_probability/python/internal/test_util.py @@ -163,8 +163,12 @@ def evaluate(self, x): def _evaluate(x): if x is None: return x - # TODO(b/223267515): Improve handling of JAX PRNGKeyArray objects. - if JAX_MODE and isinstance(x, jax.random.PRNGKeyArray): + # TODO(b/223267515): Improve handling of JAX typed PRNG keys. + if ( + JAX_MODE + and hasattr(x, 'dtype') + and jax.dtypes.issubdtype(x.dtype, jax.dtypes.prng_key) + ): return x return np.array(x) return tf.nest.map_structure(_evaluate, x, expand_composites=True) @@ -177,11 +181,15 @@ def _GetNdArray(self, a): def _evaluateTensors(self, a, b): if JAX_MODE: import jax # pylint: disable=g-import-not-at-top - # HACK: In assertions (like self.assertAllClose), convert PRNGKeyArrays - # to "normal" arrays so they can be compared with our existing machinery. - if isinstance(a, jax.random.PRNGKeyArray): + # HACK: In assertions (like self.assertAllClose), convert typed PRNG keys + # to raw arrays so they can be compared with our existing machinery. + if hasattr(a, 'dtype') and jax.dtypes.issubdtype( + a.dtype, jax.dtypes.prng_key + ): a = jax.random.key_data(a) - if isinstance(b, jax.random.PRNGKeyArray): + if hasattr(b, 'dtype') and jax.dtypes.issubdtype( + b.dtype, jax.dtypes.prng_key + ): b = jax.random.key_data(b) if tf.is_tensor(a) and tf.is_tensor(b): (a, b) = self.evaluate([a, b])