From d67c9a37a9f1fcda6651e8a242eeb4c22ae80a30 Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Fri, 3 Sep 2021 18:36:42 +0300 Subject: [PATCH] add missing data test --- pymc3/tests/test_bart.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/pymc3/tests/test_bart.py b/pymc3/tests/test_bart.py index 0a32de4d116..5d221633a44 100644 --- a/pymc3/tests/test_bart.py +++ b/pymc3/tests/test_bart.py @@ -65,3 +65,15 @@ def test_bart_random(): assert_almost_equal(pred_first, pred_all[0, :10], decimal=4) assert pred_all.shape == (2, 50) assert pred_first.shape == (10,) + + +def test_missing_data(): + X = np.random.normal(0, 1, size=(2, 50)).T + Y = np.random.normal(0, 1, size=50) + X[10:20, 0] = np.nan + + with pm.Model() as model: + mu = pm.BART("mu", X, Y, m=10) + sigma = pm.HalfNormal("sigma", 1) + y = pm.Normal("y", mu, sigma, observed=Y) + idata = pm.sample(random_seed=3415)