From 3aa27532e00212662012cacd57993576b071e453 Mon Sep 17 00:00:00 2001 From: Jacob Burnim Date: Tue, 13 Jul 2021 09:35:01 -0700 Subject: [PATCH] Add BUILD tag for TFP-compatibility tests. PiperOrigin-RevId: 384483208 --- distrax/_src/bijectors/bijector.py | 1 + distrax/_src/distributions/distribution.py | 1 + 2 files changed, 2 insertions(+) diff --git a/distrax/_src/bijectors/bijector.py b/distrax/_src/bijectors/bijector.py index 05b76a2..57e1230 100644 --- a/distrax/_src/bijectors/bijector.py +++ b/distrax/_src/bijectors/bijector.py @@ -94,6 +94,7 @@ def __init__(self, Only set to True if you're absoltely sure the Jacobian determinant is constant; if you're not sure, set to None. """ + if event_ndims_out is None: event_ndims_out = event_ndims_in if event_ndims_in < 0: diff --git a/distrax/_src/distributions/distribution.py b/distrax/_src/distributions/distribution.py index f7d9fd6..6f0b15d 100644 --- a/distrax/_src/distributions/distribution.py +++ b/distrax/_src/distributions/distribution.py @@ -125,6 +125,7 @@ def sample(self, Returns: A sample of shape `sample_shape` + `batch_shape` + `event_shape`. """ + rng, sample_shape = convert_seed_and_sample_shape(seed, sample_shape) num_samples = functools.reduce(operator.mul, sample_shape, 1) # product