Skip to content

Commit

Permalink
Remove automatic normalization in Categorical and Multinomial distrib…
Browse files Browse the repository at this point in the history
…utions

Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
  • Loading branch information
Luke LB and ricardoV94 committed Jan 26, 2022
1 parent eb5177a commit d295f3b
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 20 deletions.
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ This includes API changes we did not warn about since at least `3.11.0` (2021-01
- `math.log1mexp` and `math.log1mexp_numpy` will expect negative inputs in the future. A `FutureWarning` is now raised unless `negative_input=True` is set (see [#4860](https://github.com/pymc-devs/pymc/pull/4860)).
- Changed name of `Lognormal` distribution to `LogNormal` to harmonize CamelCase usage for distribution names.
- Attempt to iterate over MultiTrace will raise NotImplementedError.
- Removed silent normalisation of `p` parameters in Categorical and Multinomial distributions (see [#5370](https://github.com/pymc-devs/pymc/pull/5370)).
- ...


Expand Down
14 changes: 12 additions & 2 deletions pymc/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings

import aesara.tensor as at
import numpy as np

Expand Down Expand Up @@ -1233,7 +1235,16 @@ class Categorical(Discrete):

@classmethod
def dist(cls, p, **kwargs):
    """Create a Categorical distribution over event probabilities `p`.

    Parameters
    ----------
    p : list, np.ndarray, or symbolic tensor
        Probabilities of each category along the last axis. Concrete
        (list/ndarray) inputs are validated eagerly: negative entries raise
        a ValueError, and values that do not sum to 1 are rescaled after a
        UserWarning. Symbolic inputs are not touched here; they are checked
        when `logp` is evaluated.

    Raises
    ------
    ValueError
        If a concrete `p` contains negative entries.
    """
    # Only validate/normalize concrete inputs; symbolic p is deferred to logp.
    if isinstance(p, (np.ndarray, list)):
        if (np.asarray(p) < 0).any():
            raise ValueError(f"Negative `p` parameters are not valid, got: {p}")
        # Sum over the last axis; the extra [] keeps the reported sum array-shaped
        # even for a flat p, so batched and unbatched inputs warn uniformly.
        p_sum = np.sum([p], axis=-1)
        if not np.all(np.isclose(p_sum, 1.0)):
            warnings.warn(
                f"`p` parameters sum to {p_sum}, instead of 1.0. They will be automatically rescaled. You can rescale them directly to get rid of this warning.",
                UserWarning,
            )
            # Rescale only when the sum is off, so already-normalized input
            # passes through unchanged.
            p = p / at.sum(p, axis=-1, keepdims=True)
    p = at.as_tensor_variable(floatX(p))
    return super().dist([p], **kwargs)

Expand All @@ -1256,7 +1267,6 @@ def logp(value, p):
"""
k = at.shape(p)[-1]
p_ = p
p = p_ / at.sum(p_, axis=-1, keepdims=True)
value_clip = at.clip(value, 0, k - 1)

if p.ndim > 1:
Expand Down
14 changes: 11 additions & 3 deletions pymc/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,10 +537,18 @@ class Multinomial(Discrete):

@classmethod
def dist(cls, n, p, *args, **kwargs):
    """Create a Multinomial distribution with `n` trials and probabilities `p`.

    Parameters
    ----------
    n : int or symbolic tensor
        Number of trials (must be non-negative; enforced in `logp`).
    p : list, np.ndarray, or symbolic tensor
        Probabilities of each category along the last axis. Concrete
        (list/ndarray) inputs are validated eagerly: negative entries raise
        a ValueError, and values that do not sum to 1 are rescaled after a
        UserWarning. Symbolic inputs are checked when `logp` is evaluated.

    Raises
    ------
    ValueError
        If a concrete `p` contains negative entries.
    """
    # NOTE(review): this branch uses `warnings.warn`; confirm this module
    # imports `warnings` at the top of the file.
    if isinstance(p, (np.ndarray, list)):
        if (np.asarray(p) < 0).any():
            raise ValueError(f"Negative `p` parameters are not valid, got: {p}")
        # Sum over the last axis; the extra [] keeps the reported sum array-shaped
        # even for a flat p, so batched and unbatched inputs warn uniformly.
        p_sum = np.sum([p], axis=-1)
        if not np.all(np.isclose(p_sum, 1.0)):
            warnings.warn(
                f"`p` parameters sum up to {p_sum}, instead of 1.0. They will be automatically rescaled. You can rescale them directly to get rid of this warning.",
                UserWarning,
            )
            # Rescale only when the sum is off, so already-normalized input
            # passes through unchanged.
            p = p / at.sum(p, axis=-1, keepdims=True)
    n = at.as_tensor_variable(n)
    p = at.as_tensor_variable(p)

    return super().dist([n, p], *args, **kwargs)

def get_moment(rv, size, n, p):
Expand Down Expand Up @@ -582,7 +590,7 @@ def logp(value, n, p):
return check_parameters(
res,
p <= 1,
at.eq(at.sum(p, axis=-1), 1),
at.isclose(at.sum(p, axis=-1), 1),
at.ge(n, 0),
msg="p <= 1, sum(p) = 1, n >= 0",
)
Expand Down
72 changes: 57 additions & 15 deletions pymc/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2189,19 +2189,43 @@ def test_multinomial(self, n):
lambda value, n, p: scipy.stats.multinomial.logpmf(value, n, p),
)

def test_multinomial_invalid(self):
# Test non-scalar invalid parameters/values
value = np.array([[1, 2, 2], [4, 0, 1]])
def test_multinomial_invalid_value(self):
# Test passing non-scalar invalid parameters/values to an otherwise valid Multinomial,
# evaluates to -inf
value = np.array([[1, 2, 2], [3, -1, 0]])
valid_dist = Multinomial.dist(n=5, p=np.ones(3) / 3)
assert np.all(np.isfinite(pm.logp(valid_dist, value).eval()) == np.array([True, False]))

invalid_dist = Multinomial.dist(n=5, p=[-1, 1, 1], size=2)
# TODO: Multinomial normalizes p, so it is impossible to trigger p checks
# with pytest.raises(ParameterValueError):
with does_not_raise():
def test_multinomial_negative_p(self):
# test passing a list/numpy with negative p raises an immediate error
with pytest.raises(ValueError, match="[-1, 1, 1]"):
with Model() as model:
x = Multinomial("x", n=5, p=[-1, 1, 1])

def test_multinomial_p_not_normalized(self):
# test UserWarning is raised for p vals that sum to more than 1
# and normalisation is triggered
with pytest.warns(UserWarning, match="[5]"):
with pm.Model() as m:
x = pm.Multinomial("x", n=5, p=[1, 1, 1, 1, 1])
# test stored p-vals have been normalised
assert np.isclose(m.x.owner.inputs[4].sum().eval(), 1.0)

def test_multinomial_negative_p_symbolic(self):
# Passing symbolic negative p does not raise an immediate error, but evaluating
# logp raises a ParameterValueError
with pytest.raises(ParameterValueError):
value = np.array([[1, 1, 1]])
invalid_dist = pm.Multinomial.dist(n=1, p=at.as_tensor_variable([-1, 0.5, 0.5]))
pm.logp(invalid_dist, value).eval()

value[1] -= 1
valid_dist = Multinomial.dist(n=5, p=np.ones(3) / 3)
assert np.all(np.isfinite(pm.logp(valid_dist, value).eval()) == np.array([True, False]))
def test_multinomial_p_not_normalized_symbolic(self):
# Passing symbolic p values that do not add up to one does not raise any warning, but evaluating
# logp raises a ParameterValueError
with pytest.raises(ParameterValueError):
value = np.array([[1, 1, 1]])
invalid_dist = pm.Multinomial.dist(n=1, p=at.as_tensor_variable([1, 0.5, 0.5]))
pm.logp(invalid_dist, value).eval()

@pytest.mark.parametrize("n", [(10), ([10, 11]), ([[5, 6], [10, 11]])])
@pytest.mark.parametrize(
Expand Down Expand Up @@ -2317,12 +2341,22 @@ def test_categorical_bounds(self):
np.array([-1, -1, 0, 0]),
],
)
def test_categorical_valid_p(self, p):
with Model():
x = Categorical("x", p=p)
def test_categorical_negative_p(self, p):
with pytest.raises(ValueError, match=f"{p}"):
with Model():
x = Categorical("x", p=p)

with pytest.raises(ParameterValueError):
logp(x, 2).eval()
def test_categorical_negative_p_symbolic(self):
with pytest.raises(ParameterValueError):
value = np.array([[1, 1, 1]])
invalid_dist = pm.Categorical.dist(p=at.as_tensor_variable([-1, 0.5, 0.5]))
pm.logp(invalid_dist, value).eval()

def test_categorical_p_not_normalized_symbolic(self):
with pytest.raises(ParameterValueError):
value = np.array([[1, 1, 1]])
invalid_dist = pm.Categorical.dist(p=at.as_tensor_variable([2, 2, 2]))
pm.logp(invalid_dist, value).eval()

@pytest.mark.parametrize("n", [2, 3, 4])
def test_categorical(self, n):
Expand All @@ -2333,6 +2367,14 @@ def test_categorical(self, n):
lambda value, p: categorical_logpdf(value, p),
)

def test_categorical_p_not_normalized(self):
# test UserWarning is raised for p vals that sum to more than 1
# and normalisation is triggered
with pytest.warns(UserWarning, match="[5]"):
with pm.Model() as m:
x = pm.Categorical("x", p=[1, 1, 1, 1, 1])
assert np.isclose(m.x.owner.inputs[3].sum().eval(), 1.0)

@pytest.mark.parametrize("n", [2, 3, 4])
def test_orderedlogistic(self, n):
self.check_logp(
Expand Down
1 change: 1 addition & 0 deletions pymc/tests/test_idata_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,7 @@ def test_multivariate_observations(self):
data = np.random.multinomial(20, [0.2, 0.3, 0.5], size=20)
with pm.Model(coords=coords):
p = pm.Beta("p", 1, 1, size=(3,))
p = p / p.sum()
pm.Multinomial("y", 20, p, dims=("experiment", "direction"), observed=data)
idata = pm.sample(draws=50, chains=2, tune=100, return_inferencedata=True)
test_dict = {
Expand Down

0 comments on commit d295f3b

Please sign in to comment.