Skip to content

Commit

Permalink
only treat () as scalar shapes
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelosthege committed Nov 9, 2020
1 parent 8092eed commit fbb6fe2
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions pymc3/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -943,11 +943,9 @@ def _draw_value(param, point=None, givens=None, size=None):


def _is_one_d(dist_shape):
if hasattr(dist_shape, "dshape") and dist_shape.dshape in ((), (0,), (1,)):
if hasattr(dist_shape, "dshape") and dist_shape.dshape in { (), }:
return True
elif hasattr(dist_shape, "shape") and dist_shape.shape in ((), (0,), (1,)):
return True
elif to_tuple(dist_shape) == ():
elif hasattr(dist_shape, "shape") and dist_shape.shape in { (), }:
return True
return False

Expand Down Expand Up @@ -1069,6 +1067,7 @@ def generate_samples(generator, *args, **kwargs):
len(samples.shape) > len(dist_shape)
and samples.shape[-len(dist_shape) :] == dist_shape[-len(dist_shape) :]
):
raise ValueError(f"This SHOULD be unreachable code. DON'T MERGE UNTIL THIS ENTIRE BLOCK WAS REMOVED. {samples.shape}, {size_tup}")
samples = samples.reshape(samples.shape[1:])

if (
Expand All @@ -1077,5 +1076,6 @@ def generate_samples(generator, *args, **kwargs):
and samples.shape[-1] == 1
and (samples.shape != size_tup or size_tup == tuple() or size_tup == (1,))
):
raise ValueError(f"This SHOULD be unreachable code. DON'T MERGE UNTIL THIS ENTIRE BLOCK WAS REMOVED. {samples.shape}, {size_tup}")
samples = samples.reshape(samples.shape[:-1])
return np.asarray(samples)

0 comments on commit fbb6fe2

Please sign in to comment.