Skip to content

Commit

Permalink
Split jax tests into their own workflow
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelosthege committed Feb 12, 2021
1 parent 78d15f4 commit 07b715c
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 8 deletions.
64 changes: 64 additions & 0 deletions .github/workflows/jaxtests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
name: jax-sampling

on:
pull_request:
push:
branches: [master]

jobs:
pytest:
strategy:
matrix:
os: [ubuntu-latest]
floatx: [float64]
test-subset:
- pymc3/tests/test_sampling_jax.py
fail-fast: false
runs-on: ${{ matrix.os }}
env:
TEST_SUBSET: ${{ matrix.test-subset }}
THEANO_FLAGS: floatX=${{ matrix.floatx }},gcc__cxxflags='-march=native'
defaults:
run:
shell: bash -l {0}
steps:
- uses: actions/checkout@v2
- name: Cache conda
uses: actions/cache@v1
env:
# Increase this value to reset cache if environment-dev-py39.yml has not changed
CACHE_NUMBER: 0
with:
path: ~/conda_pkgs_dir
key: ${{ runner.os }}-conda-${{ env.CACHE_NUMBER }}-${{
hashFiles('conda-envs/environment-dev-py39.yml') }}
- name: Cache multiple paths
uses: actions/cache@v2
env:
# Increase this value to reset cache if requirements.txt has not changed
CACHE_NUMBER: 0
with:
path: |
~/.cache/pip
$RUNNER_TOOL_CACHE/Python/*
~\AppData\Local\pip\Cache
key: ${{ runner.os }}-build-${{ matrix.python-version }}-${{
hashFiles('requirements.txt') }}
- uses: conda-incubator/setup-miniconda@v2
with:
activate-environment: pymc3-dev-py39
channel-priority: strict
environment-file: conda-envs/environment-dev-py39.yml
use-only-tar-bz2: true # IMPORTANT: This needs to be set for caching to work properly!
- name: Install pymc3
run: |
conda activate pymc3-dev-py39
pip install -e .
python --version
- name: Install jax specific dependencies
run: |
conda activate pymc3-dev-py39
pip install numpyro tensorflow_probability
- name: Run tests
run: |
python -m pytest -vv --cov=pymc3 --cov-report=xml --cov-report term --durations=50 $TEST_SUBSET
1 change: 1 addition & 0 deletions .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ jobs:
--ignore=pymc3/tests/test_quadpotential.py
--ignore=pymc3/tests/test_random.py
--ignore=pymc3/tests/test_sampling.py
--ignore=pymc3/tests/test_sampling_jax.py
--ignore=pymc3/tests/test_shape_handling.py
--ignore=pymc3/tests/test_shared.py
--ignore=pymc3/tests/test_smc.py
Expand Down
20 changes: 12 additions & 8 deletions scripts/check_all_tests_are_covered.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,17 @@
from pathlib import Path

if __name__ == "__main__":
pytest_ci_job = Path(".github") / "workflows/pytest.yml"
txt = pytest_ci_job.read_text()
ignored_tests = set(re.findall(r"(?<=--ignore=)(pymc3/tests.*\.py)", txt))
non_ignored_tests = set(re.findall(r"(?<!--ignore=)(pymc3/tests.*\.py)", txt))
testing_workflows = ["jaxtests.yml", "pytest.yml"]
ignored = set()
non_ignored = set()
for wfyml in testing_workflows:
pytest_ci_job = Path(".github") / "workflows" / wfyml
txt = pytest_ci_job.read_text()
ignored = set(re.findall(r"(?<=--ignore=)(pymc3/tests.*\.py)", txt))
non_ignored = non_ignored.union(set(re.findall(r"(?<!--ignore=)(pymc3/tests.*\.py)", txt)))
assert (
ignored_tests <= non_ignored_tests
), f"The following tests are ignored by the first job but not run by the others: {ignored_tests.difference(non_ignored_tests)}"
ignored <= non_ignored
), f"The following tests are ignored by the first job but not run by the others: {ignored.difference(non_ignored)}"
assert (
ignored_tests >= non_ignored_tests
), f"The following tests are run by multiple jobs: {non_ignored_tests.difference(ignored_tests)}"
ignored >= non_ignored
), f"The following tests are run by multiple jobs: {non_ignored.difference(ignored)}"

0 comments on commit 07b715c

Please sign in to comment.