Skip to content

Commit

Permalink
Remove intX and floatX from distributions (#7114)
Browse files Browse the repository at this point in the history
  • Loading branch information
aerubanov authored Feb 17, 2024
1 parent 0d8ddba commit 3a304d6
Show file tree
Hide file tree
Showing 12 changed files with 144 additions and 146 deletions.
18 changes: 8 additions & 10 deletions pymc/distributions/bound.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from pymc.distributions.transforms import _default_transform
from pymc.logprob.basic import logp
from pymc.model import modelcontext
from pymc.pytensorf import floatX, intX
from pymc.util import check_dist_not_registered

__all__ = ["Bound"]
Expand Down Expand Up @@ -206,7 +205,7 @@ def __new__(
res = _ContinuousBounded(
name,
[dist, lower, upper],
initval=floatX(initval),
initval=initval.astype("float"),
size=size,
shape=shape,
**kwargs,
Expand All @@ -215,7 +214,7 @@ def __new__(
res = _DiscreteBounded(
name,
[dist, lower, upper],
initval=intX(initval),
initval=initval.astype("int"),
size=size,
shape=shape,
**kwargs,
Expand All @@ -241,15 +240,15 @@ def dist(
shape=shape,
**kwargs,
)
res.tag.test_value = floatX(initval)
res.tag.test_value = initval
else:
res = _DiscreteBounded.dist(
[dist, lower, upper],
size=size,
shape=shape,
**kwargs,
)
res.tag.test_value = intX(initval)
res.tag.test_value = initval.astype("int")
return res

@classmethod
Expand Down Expand Up @@ -286,9 +285,9 @@ def _set_values(cls, lower, upper, size, shape, initval):
size = shape

lower = np.asarray(lower)
lower = floatX(np.where(lower == None, -np.inf, lower)) # noqa E711
lower = np.where(lower == None, -np.inf, lower) # noqa E711
upper = np.asarray(upper)
upper = floatX(np.where(upper == None, np.inf, upper)) # noqa E711
upper = np.where(upper == None, np.inf, upper) # noqa E711

if initval is None:
_size = np.broadcast_shapes(to_tuple(size), np.shape(lower), np.shape(upper))
Expand All @@ -303,7 +302,6 @@ def _set_values(cls, lower, upper, size, shape, initval):
np.where(_upper == np.inf, _lower + 1, (_lower + _upper) / 2),
),
)

lower = as_tensor_variable(floatX(lower))
upper = as_tensor_variable(floatX(upper))
lower = as_tensor_variable(lower, dtype="floatX")
upper = as_tensor_variable(upper, dtype="floatX")
return lower, upper, initval
Loading

0 comments on commit 3a304d6

Please sign in to comment.