From fbb6fe2a365bd4c0dfd25646ccd3f3422ae1b971 Mon Sep 17 00:00:00 2001 From: Michael Osthege Date: Mon, 9 Nov 2020 17:00:37 +0100 Subject: [PATCH] only treat () as scalar shapes closes #4206 --- pymc3/distributions/distribution.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pymc3/distributions/distribution.py b/pymc3/distributions/distribution.py index a32031ebb8d..2799049ea34 100644 --- a/pymc3/distributions/distribution.py +++ b/pymc3/distributions/distribution.py @@ -943,11 +943,9 @@ def _draw_value(param, point=None, givens=None, size=None): def _is_one_d(dist_shape): - if hasattr(dist_shape, "dshape") and dist_shape.dshape in ((), (0,), (1,)): + if hasattr(dist_shape, "dshape") and dist_shape.dshape in { (), }: return True - elif hasattr(dist_shape, "shape") and dist_shape.shape in ((), (0,), (1,)): - return True - elif to_tuple(dist_shape) == (): + elif hasattr(dist_shape, "shape") and dist_shape.shape in { (), }: return True return False @@ -1069,6 +1067,7 @@ def generate_samples(generator, *args, **kwargs): len(samples.shape) > len(dist_shape) and samples.shape[-len(dist_shape) :] == dist_shape[-len(dist_shape) :] ): + raise ValueError(f"This SHOULD be unreachable code. DON'T MERGE UNTIL THIS ENTIRE BLOCK WAS REMOVED. {samples.shape}, {size_tup}") samples = samples.reshape(samples.shape[1:]) if ( @@ -1077,5 +1076,6 @@ def generate_samples(generator, *args, **kwargs): and samples.shape[-1] == 1 and (samples.shape != size_tup or size_tup == tuple() or size_tup == (1,)) ): + raise ValueError(f"This SHOULD be unreachable code. DON'T MERGE UNTIL THIS ENTIRE BLOCK WAS REMOVED. {samples.shape}, {size_tup}") samples = samples.reshape(samples.shape[:-1]) return np.asarray(samples)