Skip to content

Commit

Permalink
Add test for overrides string keys
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Oct 13, 2021
1 parent 735c29e commit 4e40756
Showing 1 changed file with 22 additions and 1 deletion.
23 changes: 22 additions & 1 deletion pymc/tests/test_initvals.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def test_respects_overrides(self):
return_transformed=True,
overrides={
A: at.as_tensor(2, dtype=int),
"B": 3,
B: 3,
C: 5,
},
)
Expand All @@ -154,6 +154,27 @@ def test_respects_overrides(self):
assert np.isclose(iv["B_log__"], np.log(3))
assert iv["C"] == 5

def test_string_overrides_work(self):
with pm.Model() as pmodel:
A = pm.Flat("A", initval=10)
B = pm.HalfFlat("B", initval=10)
C = pm.HalfFlat("C", initval=10)

fn = make_initial_point_fn(
model=pmodel,
jitter_rvs={},
return_transformed=True,
overrides={
"A": 1,
"B": 1,
"C_log__": 0,
},
)
iv = fn(0)
assert iv["A"] == 1
assert np.isclose(iv["B_log__"], 0)
assert iv["C_log__"] == 0


class TestMoment:
def test_basic(self):
Expand Down

0 comments on commit 4e40756

Please sign in to comment.