Skip to content

Commit

Permalink
add test_shape_inputs for _OrderedLogistic
Browse files Browse the repository at this point in the history
  • Loading branch information
danhphan committed Feb 4, 2022
1 parent 32f6c89 commit 466a941
Showing 1 changed file with 20 additions and 0 deletions.
20 changes: 20 additions & 0 deletions pymc/tests/test_distributions_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -1688,6 +1688,26 @@ class TestOrderedLogistic(BaseTestDistributionRandom):
"check_rv_size",
]

@pytest.mark.parametrize(
"eta, cutpoints, expected",
[
(0, [-2.0, 0, 2.0], (4,)),
([-1], [-2.0, 0, 2.0], (1, 4)),
([1.0, -2.0], [-1.0, 0, 1.0], (2, 4)),
([[1.0, -1.0, 0.0], [-1.0, 3.0, 5.0]], [-2.0, 0, 1.0], (2, 3, 4)),
],
)
def test_shape_inputs(self, eta, cutpoints, expected):
"""
This test checks when providing different shapes for `eta` parameters.
"""
categorical = _OrderedLogistic.dist(
eta=eta,
cutpoints=cutpoints,
)
p = categorical.owner.inputs[3].eval()
assert p.shape == expected


class TestOrderedProbit(BaseTestDistributionRandom):
pymc_dist = _OrderedProbit
Expand Down

0 comments on commit 466a941

Please sign in to comment.