Skip to content

Commit

Permalink
Add a temporary MacOS test
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed May 5, 2021
1 parent 073da0e commit d950485
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .github/workflows/jaxtests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ jobs:
pytest:
strategy:
matrix:
os: [ubuntu-latest]
os: [macos-latest, ubuntu-latest]
floatx: [float64]
test-subset:
- pymc3/tests/test_sampling_jax.py
Expand Down
31 changes: 31 additions & 0 deletions pymc3/tests/test_sampling_jax.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import aesara
import aesara.tensor as at
import numpy as np

import pymc3 as pm
Expand Down Expand Up @@ -33,3 +34,33 @@ def test_transform_samples():

assert -11 < trace.posterior["a"].mean() < -8
assert 1.5 < trace.posterior["sigma"].mean() < 2.5


def test_example():

np.random.seed(23248)

n = 250
k_true = 5
d = 9
err_sd = 2
M = np.random.binomial(1, 0.25, size=(k_true, n))
Q = np.hstack(
[np.random.exponential(2 * k_true - k, size=(d, 1)) for k in range(k_true)]
) * np.random.binomial(1, 0.75, size=(d, k_true))
Y = np.round(1000 * np.dot(Q, M) + np.random.normal(size=(d, n)) * err_sd) / 1000

k = 2

with pm.Model() as PPCA:
W = pm.Normal("W", size=(d, k))
F = pm.Normal("F", size=(k, n))
psi = pm.HalfNormal("psi", 1.0)
X = pm.Normal("X", mu=at.dot(W, F), sigma=psi, observed=Y)
W_plot = pm.Deterministic("W_plot", W[1:3, 0])
F_plot = pm.Deterministic("F_plot", F[0, 1:3])

trace = sample_numpyro_nuts()

# trace.posterior['W_plot'] = trace.posterior.W[:, :, 1:3, 0]
# trace.posterior['F_plot'] = trace.posterior.F[:, :, 0, 1:3]

0 comments on commit d950485

Please sign in to comment.