Skip to content

Commit

Permalink
simplify test_shape_inputs for _OrderedProbit
Browse files Browse the repository at this point in the history
  • Loading branch information
danhphan committed Feb 4, 2022
1 parent 8d1d9d9 commit 32f6c89
Showing 1 changed file with 23 additions and 23 deletions.
46 changes: 23 additions & 23 deletions pymc/tests/test_distributions_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -1698,33 +1698,33 @@ class TestOrderedProbit(BaseTestDistributionRandom):
"check_rv_size",
]

def test_vector_inputs(self):
@pytest.mark.parametrize(
"eta, cutpoints, sigma, expected",
[
(0, [-2.0, 0, 2.0], 1.0, (4,)),
([-1], [-2.0, 0, 2.0], [2.0], (1, 4)),
([1.0, -2.0], [-1.0, 0, 1.0], 1.0, (2, 4)),
([1.0, -2.0, 3.0], [-2.0, 0, 2.0], [-1.0, -2.0, 5.0], (3, 4)),
([[1.0, -1.0, 0.0], [-1.0, 3.0, 5.0]], [-2.0, 0, 1.0], [-1.0, -2.0, 5.0], (2, 3, 4)),
(
[[1.0, -2.0, 3.0], [1.0, 2.0, -4.0]],
[-2.0, 0, 1.0],
[[0.0, 2.0, -4.0], [-1.0, 1.0, 3.0]],
(2, 3, 4),
),
],
)
def test_shape_inputs(self, eta, cutpoints, sigma, expected):
"""
This test checks when providing vector inputs for `eta` and `sigma` parameters using advanced indexing.
This test checks when providing different shapes for `eta` and `sigma` parameters.
"""
categorical = pm.OrderedProbit.dist(
eta=0,
cutpoints=np.array([-2.0, 0, 2.0]),
sigma=1.0,
)
p = categorical.owner.inputs[3].eval()
assert p.shape == (4,)

categorical = pm.OrderedProbit.dist(
eta=np.array([1.0, 2.0, 3.0, 4.0, 5.0]),
cutpoints=np.array([-2.0, 0, 2.0]),
sigma=1,
)
p = categorical.owner.inputs[3].eval()
assert p.shape == (5, 4)

categorical = pm.OrderedProbit.dist(
eta=np.array([1.0, 2.0, 3.0, 4.0, 5.0]),
cutpoints=np.array([-2.0, 0, 2.0]),
sigma=np.array([1.0, 2.0, 3.0, 4.0, 5.0]),
categorical = _OrderedProbit.dist(
eta=eta,
cutpoints=cutpoints,
sigma=sigma,
)
p = categorical.owner.inputs[3].eval()
assert p.shape == (5, 4)
assert p.shape == expected


class TestOrderedMultinomial(BaseTestDistributionRandom):
Expand Down

0 comments on commit 32f6c89

Please sign in to comment.