Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Transform jax samples #4427

Merged
merged 5 commits into from
Feb 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions .github/workflows/jaxtests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# CI workflow that runs the JAX sampling test subset in its own job.
name: jax-sampling

on:
  pull_request:
  push:
    branches: [master]

jobs:
  pytest:
    strategy:
      matrix:
        os: [ubuntu-latest]
        floatx: [float64]
        test-subset:
          - pymc3/tests/test_sampling_jax.py
      fail-fast: false
    runs-on: ${{ matrix.os }}
    env:
      TEST_SUBSET: ${{ matrix.test-subset }}
      THEANO_FLAGS: floatX=${{ matrix.floatx }},gcc__cxxflags='-march=native'
    defaults:
      run:
        # Login shell, as used by every `run` step below.
        shell: bash -l {0}
    steps:
      - uses: actions/checkout@v2
      - name: Cache conda
        uses: actions/cache@v1
        env:
          # Increase this value to reset cache if environment-dev-py39.yml has not changed
          CACHE_NUMBER: 0
        with:
          path: ~/conda_pkgs_dir
          key: ${{ runner.os }}-conda-${{ env.CACHE_NUMBER }}-${{ hashFiles('conda-envs/environment-dev-py39.yml') }}
      - name: Cache multiple paths
        uses: actions/cache@v2
        env:
          # Increase this value to reset cache if requirements.txt has not changed
          CACHE_NUMBER: 0
        with:
          path: |
            ~/.cache/pip
            $RUNNER_TOOL_CACHE/Python/*
            ~\AppData\Local\pip\Cache
          # The matrix defines no `python-version`, so the key is built from
          # CACHE_NUMBER instead; bumping it above now really resets this cache.
          key: ${{ runner.os }}-build-${{ env.CACHE_NUMBER }}-${{ hashFiles('requirements.txt') }}
      - uses: conda-incubator/setup-miniconda@v2
        with:
          activate-environment: pymc3-dev-py39
          channel-priority: strict
          environment-file: conda-envs/environment-dev-py39.yml
          use-only-tar-bz2: true # IMPORTANT: This needs to be set for caching to work properly!
      - name: Install pymc3
        run: |
          conda activate pymc3-dev-py39
          pip install -e .
          python --version
      - name: Install jax specific dependencies
        run: |
          conda activate pymc3-dev-py39
          pip install numpyro tensorflow_probability
      - name: Run tests
        run: |
          python -m pytest -vv --cov=pymc3 --cov-report=xml --cov-report term --durations=50 $TEST_SUBSET
1 change: 1 addition & 0 deletions .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ jobs:
--ignore=pymc3/tests/test_quadpotential.py
--ignore=pymc3/tests/test_random.py
--ignore=pymc3/tests/test_sampling.py
--ignore=pymc3/tests/test_sampling_jax.py
--ignore=pymc3/tests/test_shape_handling.py
--ignore=pymc3/tests/test_shared.py
--ignore=pymc3/tests/test_smc.py
Expand Down
3 changes: 2 additions & 1 deletion RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
### Breaking Changes

### New Features
+ Automatic imputations now also work with `ndarray` data, not just `pd.Series` or `pd.DataFrame` (see[#4439](https://github.com/pymc-devs/pymc3/pull/4439)).
- Automatic imputations now also work with `ndarray` data, not just `pd.Series` or `pd.DataFrame` (see [#4439](https://github.com/pymc-devs/pymc3/pull/4439)).
- `pymc3.sampling_jax.sample_numpyro_nuts` now returns samples from transformed random variables, rather than from the unconstrained representation (see [#4427](https://github.com/pymc-devs/pymc3/pull/4427)).

### Maintenance
- `math.log1mexp_numpy` no longer raises RuntimeWarning when given very small inputs. These were commonly observed during NUTS sampling (see [#4428](https://github.com/pymc-devs/pymc3/pull/4428)).
Expand Down
45 changes: 44 additions & 1 deletion pymc3/sampling_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import re
import warnings

from collections import defaultdict

# Pin the XLA host-platform device count to 100. Any other flags already
# present in XLA_FLAGS are preserved; only a pre-existing device-count flag is
# dropped so that it is not specified twice.
xla_flags = os.getenv("XLA_FLAGS", "")
xla_flags = re.sub(r"--xla_force_host_platform_device_count=\S+", "", xla_flags).split()
os.environ["XLA_FLAGS"] = " ".join(
    ["--xla_force_host_platform_device_count={}".format(100)] + xla_flags
)
Expand Down Expand Up @@ -121,6 +123,7 @@ def sample_numpyro_nuts(
random_seed=10,
model=None,
progress_bar=True,
keep_untransformed=False,
):
from numpyro.infer import MCMC, NUTS

Expand Down Expand Up @@ -175,8 +178,48 @@ def _sample(current_state, seed):
# print("Sampling time = ", tic4 - tic3)

posterior = {k: v for k, v in zip(rv_names, mcmc_samples)}
tic3 = pd.Timestamp.now()
posterior = _transform_samples(posterior, model, keep_untransformed=keep_untransformed)
tic4 = pd.Timestamp.now()

az_trace = az.from_dict(posterior=posterior)
tic3 = pd.Timestamp.now()
print("Compilation + sampling time = ", tic3 - tic2)
print("Transformation time = ", tic4 - tic3)

return az_trace # , leapfrogs_taken, tic3 - tic2


def _transform_samples(samples, model, keep_untransformed=False):
    """Add the model's non-free unobserved variables to a dict of samples.

    Every variable in ``model.unobserved_RVs`` that is not one of
    ``model.free_RVs`` is computed from the free-RV samples through a
    JAX-compiled Theano graph and stored under its own name.

    Parameters
    ----------
    samples : dict
        Maps variable names to sampled values. It must contain an entry for
        every variable in ``model.free_RVs``. The dict is not modified.
    model : pymc3 model
        Model providing ``free_RVs`` and ``unobserved_RVs``.
    keep_untransformed : bool
        Forwarded as ``include_transformed`` to
        ``pm.util.get_default_varnames`` when deciding which names to keep.

    Returns
    -------
    dict
        A new dict with the computed variables added and restricted to the
        names selected by ``pm.util.get_default_varnames``.
    """
    # Everything unobserved that is not itself a free RV has to be computed.
    free_rv_names = {rv.name for rv in model.free_RVs}
    ops_to_compute = [rv for rv in model.unobserved_RVs if rv.name not in free_rv_names]

    # Create function graph for these:
    fgraph = theano.graph.fg.FunctionGraph(model.free_RVs, ops_to_compute)

    # Jaxify, which returns a list of functions, one for each op
    jax_fns = jax_funcify(fgraph)

    # Put together the inputs, in the same order as the graph's inputs
    inputs = [samples[rv.name] for rv in model.free_RVs]

    # Work on a shallow copy so the caller's dict is left untouched.
    transformed = dict(samples)
    for cur_op, cur_jax_fn in zip(ops_to_compute, jax_fns):
        # Each function takes one positional argument per free RV. The double
        # vmap maps it over the two leading axes of every input
        # (presumably chain and draw -- TODO confirm against the caller).
        transformed[cur_op.name] = jax.vmap(jax.vmap(cur_jax_fn))(*inputs)

    # Discard unwanted transformed variables, if desired:
    vars_to_keep = set(
        pm.util.get_default_varnames(list(transformed), include_transformed=keep_untransformed)
    )
    return {name: values for name, values in transformed.items() if name in vars_to_keep}
19 changes: 19 additions & 0 deletions pymc3/tests/test_sampling_jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import numpy as np

import pymc3 as pm

from pymc3.sampling_jax import sample_numpyro_nuts


def test_transform_samples():
    """Check that sampled ``sigma`` equals ``exp`` of sampled ``sigma_log__``."""
    with pm.Model():
        scale = pm.HalfNormal("sigma")
        pm.Normal("b", sigma=scale)
        # Keep the untransformed variables so both can be compared below.
        idata = sample_numpyro_nuts(keep_untransformed=True)

    posterior = idata.posterior
    unconstrained = posterior["sigma_log__"].values
    constrained = posterior["sigma"].values

    assert np.allclose(np.exp(unconstrained), constrained)
20 changes: 12 additions & 8 deletions scripts/check_all_tests_are_covered.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,17 @@
from pathlib import Path

if __name__ == "__main__":
pytest_ci_job = Path(".github") / "workflows/pytest.yml"
txt = pytest_ci_job.read_text()
ignored_tests = set(re.findall(r"(?<=--ignore=)(pymc3/tests.*\.py)", txt))
non_ignored_tests = set(re.findall(r"(?<!--ignore=)(pymc3/tests.*\.py)", txt))
# Workflow files whose pytest invocations are cross-checked against each other.
testing_workflows = ["jaxtests.yml", "pytest.yml"]
ignored = set()
non_ignored = set()
for wfyml in testing_workflows:
    pytest_ci_job = Path(".github") / "workflows" / wfyml
    txt = pytest_ci_job.read_text()
    # Accumulate BOTH sets across workflows. Overwriting `ignored` on each
    # iteration would keep only the last file's --ignore entries and make the
    # check depend on the order of `testing_workflows`.
    ignored |= set(re.findall(r"(?<=--ignore=)(pymc3/tests.*\.py)", txt))
    non_ignored |= set(re.findall(r"(?<!--ignore=)(pymc3/tests.*\.py)", txt))
assert (
ignored_tests <= non_ignored_tests
), f"The following tests are ignored by the first job but not run by the others: {ignored_tests.difference(non_ignored_tests)}"
ignored <= non_ignored
), f"The following tests are ignored by the first job but not run by the others: {ignored.difference(non_ignored)}"
assert (
ignored_tests >= non_ignored_tests
), f"The following tests are run by multiple jobs: {non_ignored_tests.difference(ignored_tests)}"
ignored >= non_ignored
), f"The following tests are run by multiple jobs: {non_ignored.difference(ignored)}"