Skip to content

Commit

Permalink
Fix bug that does not correctly set the dtype of determinsitic variab… (
Browse files Browse the repository at this point in the history
#6425)

* Fix bug that does not correctly set the dtype of determinsitic variable after automatic imputation
* Change `at.zeros` to `at.empty` when creating combined observed/missing vector

Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
  • Loading branch information
jessegrabowski and ricardoV94 authored Jan 12, 2023
1 parent c3b8ff4 commit 6c4d4eb
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 1 deletion.
2 changes: 1 addition & 1 deletion pymc/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1456,7 +1456,7 @@ def make_obs_var(

# Create deterministic that combines observed and missing
# Note: This can widely increase memory consumption during sampling for large datasets
rv_var = at.zeros(data.shape)
rv_var = at.empty(data.shape, dtype=observed_rv_var.type.dtype)
rv_var = at.set_subtensor(rv_var[mask.nonzero()], missing_rv_var)
rv_var = at.set_subtensor(rv_var[antimask_idx], observed_rv_var)
rv_var = Deterministic(name, rv_var, self, dims)
Expand Down
3 changes: 3 additions & 0 deletions pymc/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,9 @@ def test_missing_data(self):

assert m["x2_missing"].type == gf._extra_vars_shared["x2_missing"].type

# The dtype of the merged observed/missing deterministic should match the RV dtype
assert m.deterministics[0].type.dtype == x2.type.dtype

pnt = m.initial_point(random_seed=None).copy()
del pnt["x2_missing"]

Expand Down

0 comments on commit 6c4d4eb

Please sign in to comment.