From d9504854ae717e23c530002bfcb6f5d37349c3e9 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Wed, 5 May 2021 15:18:58 -0500 Subject: [PATCH] Add a temporary MacOS test --- .github/workflows/jaxtests.yml | 2 +- pymc3/tests/test_sampling_jax.py | 31 +++++++++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/.github/workflows/jaxtests.yml b/.github/workflows/jaxtests.yml index 2e2f16b33ad..97ce760b672 100644 --- a/.github/workflows/jaxtests.yml +++ b/.github/workflows/jaxtests.yml @@ -9,7 +9,7 @@ jobs: pytest: strategy: matrix: - os: [ubuntu-latest] + os: [macos-latest, ubuntu-latest] floatx: [float64] test-subset: - pymc3/tests/test_sampling_jax.py diff --git a/pymc3/tests/test_sampling_jax.py b/pymc3/tests/test_sampling_jax.py index b2d39d130e6..030688427c9 100644 --- a/pymc3/tests/test_sampling_jax.py +++ b/pymc3/tests/test_sampling_jax.py @@ -1,4 +1,5 @@ import aesara +import aesara.tensor as at import numpy as np import pymc3 as pm @@ -33,3 +34,33 @@ def test_transform_samples(): assert -11 < trace.posterior["a"].mean() < -8 assert 1.5 < trace.posterior["sigma"].mean() < 2.5 + + +def test_example(): + + np.random.seed(23248) + + n = 250 + k_true = 5 + d = 9 + err_sd = 2 + M = np.random.binomial(1, 0.25, size=(k_true, n)) + Q = np.hstack( + [np.random.exponential(2 * k_true - k, size=(d, 1)) for k in range(k_true)] + ) * np.random.binomial(1, 0.75, size=(d, k_true)) + Y = np.round(1000 * np.dot(Q, M) + np.random.normal(size=(d, n)) * err_sd) / 1000 + + k = 2 + + with pm.Model() as PPCA: + W = pm.Normal("W", size=(d, k)) + F = pm.Normal("F", size=(k, n)) + psi = pm.HalfNormal("psi", 1.0) + X = pm.Normal("X", mu=at.dot(W, F), sigma=psi, observed=Y) + W_plot = pm.Deterministic("W_plot", W[1:3, 0]) + F_plot = pm.Deterministic("F_plot", F[0, 1:3]) + + trace = sample_numpyro_nuts() + + # trace.posterior['W_plot'] = trace.posterior.W[:, :, 1:3, 0] + # trace.posterior['F_plot'] = trace.posterior.F[:, :, 0, 1:3]