Skip to content

Commit 28d15f8

Browse files
committed
- Float masked arrays are filled with nan when
passed to pm.Data() and pm.Model().set_data() - Integer masked arrays trigger an error message and provide suggested alternatives
1 parent b7764dd commit 28d15f8

File tree

6 files changed

+114
-1
lines changed

6 files changed

+114
-1
lines changed

pymc/data.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636

3737
import pymc as pm
3838

39-
from pymc.pytensorf import convert_observed_data
39+
from pymc.pytensorf import convert_observed_data, unmask_masked_data
4040

4141
__all__ = [
4242
"get_data",
@@ -419,10 +419,20 @@ def Data(
419419
)
420420
name = model.name_for(name)
421421

422+
if isinstance(value, np.ma.MaskedArray):
423+
warnings.warn(
424+
"If possible, masked arrays will be converted to standard numpy arrays with np.nan values for compatibility with PyTensor."
425+
)
426+
422427
# `convert_observed_data` takes care of parameter `value` and
423428
# transforms it to something digestible for PyTensor.
424429
arr = convert_observed_data(value)
425430

431+
# because converted_observed_data() is also used outside pyTensor, we need an extra step to ensure that any masked arrays
432+
# produced by it are converted back to np.ndarray() with np.nan value.
433+
# This is not very efficient and will not be necessary once pyTensor implements MaskedArray support
434+
arr = unmask_masked_data(arr)
435+
426436
if mutable is None:
427437
warnings.warn(
428438
"The `mutable` kwarg was not specified. Before v4.1.0 it defaulted to `pm.Data(mutable=True)`,"

pymc/model.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@
7575
hessian,
7676
inputvars,
7777
replace_rvs_by_values,
78+
unmask_masked_data,
7879
)
7980
from pymc.util import (
8081
UNSET,
@@ -1184,7 +1185,19 @@ def set_data(
11841185

11851186
if isinstance(values, list):
11861187
values = np.array(values)
1188+
1189+
if isinstance(values, np.ma.MaskedArray):
1190+
warnings.warn(
1191+
"If possible, masked arrays will be converted to standard numpy arrays with np.nan values for compatibility with PyTensor."
1192+
)
1193+
11871194
values = convert_observed_data(values)
1195+
1196+
# because converted_observed_data() is also used outside pyTensor, we need an extra step to ensure that any masked arrays
1197+
# produced by it are converted back to np.ndarray() with np.nan value.
1198+
# This is not very efficient and will not be necessary once pyTensor implements MaskedArray support
1199+
values = unmask_masked_data(values)
1200+
11881201
dims = self.named_vars_to_dims.get(name, None) or ()
11891202
coords = coords or {}
11901203

pymc/pytensorf.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,11 +84,33 @@
8484
"make_shared_replacements",
8585
"generator",
8686
"convert_observed_data",
87+
"unmask_masked_data",
8788
"compile_pymc",
8889
"constant_fold",
8990
]
9091

9192

93+
def unmask_masked_data(data):
94+
"""Unmask masked numpy arrays for usage within PyTensor"""
95+
96+
# PyTensor currently does not support masked arrays
97+
# If a masked array is passed, we convert it to a standard numpy array with np.nans for float type arrays
98+
# In case of integer type arrays, we throw an error as np.nan is a float concept.
99+
100+
if isinstance(data, np.ma.MaskedArray):
101+
if "int" in str(data.dtype):
102+
raise TypeError(
103+
"Masked integer arrays (integer type datasets with missing values) are not supported by pm.Data() / pm.Model.set_data() at this time. \n"
104+
"Consider if using a float type fits your use case. \n"
105+
"Alternatively, if you want to benefit from automatic imputation in pyMC, pass a masked array directly to `observed=` parameter when defining a distribution."
106+
)
107+
else:
108+
ret = data.filled(fill_value=np.nan)
109+
else:
110+
ret = data
111+
return ret
112+
113+
92114
def convert_observed_data(data):
93115
"""Convert user provided dataset to accepted formats."""
94116

tests/test_data.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,33 @@ def test_get_data():
454454
assert type(data) == io.BytesIO
455455

456456

457+
def test_masked_data_mutable():
458+
with pm.Model():
459+
data = np.ma.MaskedArray([1.0, 2.0, 3], [0, 0, 1])
460+
expected = np.array([1, 2, np.nan])
461+
with pytest.warns(UserWarning, match="masked arrays"):
462+
result = pm.MutableData("test", data).get_value()
463+
np.testing.assert_array_equal(result, expected)
464+
465+
466+
def test_masked_data_constant():
467+
with pm.Model():
468+
data = np.ma.MaskedArray([1.0, 2.0, 3], [0, 0, 1])
469+
expected = np.array([1, 2, np.nan])
470+
with pytest.warns(UserWarning, match="masked arrays"):
471+
result = pm.ConstantData("test", data).data
472+
np.testing.assert_array_equal(result, expected)
473+
474+
475+
def test_masked_integer_data():
476+
with pm.Model():
477+
data = np.ma.MaskedArray([1, 2, 3], [0, 0, 1])
478+
with pytest.raises(TypeError, match="Masked integer"):
479+
pm.ConstantData("test", data)
480+
with pytest.raises(TypeError, match="Masked integer"):
481+
pm.MutableData("test", data)
482+
483+
457484
class _DataSampler:
458485
"""
459486
Not for users

tests/test_model.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -967,6 +967,27 @@ def test_set_data_constant_shape_error():
967967
pmodel.set_data("y", np.arange(10))
968968

969969

970+
def test_set_data_masked_array():
971+
data = np.ma.MaskedArray([1.0, 2.0, 3], [0, 0, 1])
972+
973+
with pm.Model() as pmodel:
974+
D = pm.MutableData("test", np.zeros(4))
975+
976+
with pytest.warns(UserWarning, match="masked arrays"):
977+
pmodel.set_data("test", data)
978+
result = D.get_value()
979+
expected = np.array([1.0, 2.0, np.nan])
980+
np.testing.assert_array_equal(result, expected)
981+
982+
983+
def test_set_data_masked_integer_array():
984+
with pm.Model() as pmodel:
985+
D = pm.MutableData("test", np.zeros(4))
986+
with pytest.warns(UserWarning, match="masked arrays"):
987+
with pytest.raises(TypeError, match="Masked integer"):
988+
pmodel.set_data("test", np.ma.MaskedArray([1, 2, 3], [0, 0, 1]))
989+
990+
970991
def test_model_deprecation_warning():
971992
with pm.Model() as m:
972993
x = pm.Normal("x", 0, 1, size=2)

tests/test_pytensorf.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
replace_rvs_by_values,
5050
reseed_rngs,
5151
rvs_to_value_vars,
52+
unmask_masked_data,
5253
walk_model,
5354
)
5455
from pymc.testing import assert_no_rvs
@@ -269,6 +270,25 @@ def test_convert_observed_data(input_dtype):
269270
assert isinstance(wrapped, TensorVariable)
270271

271272

273+
def test_unmask_masked_data():
274+
# test with non-masked data
275+
data = np.array([1, 2, 3])
276+
result = unmask_masked_data(data)
277+
expected = np.array([1, 2, 3])
278+
np.testing.assert_array_equal(result, expected)
279+
280+
# test with masked float data
281+
data = np.ma.MaskedArray([1.0, 2.0, 3.0], [0, 0, 1])
282+
result = unmask_masked_data(data)
283+
expected = np.array([1.0, 2.0, np.nan])
284+
np.testing.assert_array_equal(result, expected)
285+
286+
# test with integer masked data
287+
data = np.ma.MaskedArray([1, 2, 3], [0, 0, 1])
288+
with pytest.raises(TypeError, match="Masked integer"):
289+
unmask_masked_data(data)
290+
291+
272292
def test_pandas_to_array_pandas_index():
273293
data = pd.Index([1, 2, 3])
274294
result = convert_observed_data(data)

0 commit comments

Comments
 (0)