Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Skip tests for optional/extra dependencies when not installed #1113

Merged
merged 26 commits into from
Mar 31, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
124e473
Skip tests for optional/extra dependencies when not installed
hectormz Mar 9, 2020
6b05556
Update changelog
hectormz Mar 9, 2020
ad4a1e1
Fix pystan_version()
hectormz Mar 9, 2020
c8ca98c
Fix pylint issues
hectormz Mar 9, 2020
5f6e120
Relocate numba specific tests in test_diagnostics (and skip if not in…
hectormz Mar 9, 2020
bf207c3
Relocate numba specific tests in test_utils (and skip if not installed)
hectormz Mar 10, 2020
d82b18c
Relocate numba specific tests in test_stats (and skip if not installed)
hectormz Mar 10, 2020
042aed8
Add CI environment variable
hectormz Mar 22, 2020
33d485f
Add custom importorskip for CI
hectormz Mar 22, 2020
0765797
Use internal `importorskip`
hectormz Mar 22, 2020
54a18f3
Displayed skipped files in pytest output (due to importorskip)
hectormz Mar 22, 2020
2a50fe4
Move CI env variable detection into `importorskip`
hectormz Mar 22, 2020
a418aaa
Test `importorskip` for local/ci machines with `monkeypatch`
hectormz Mar 22, 2020
25f4cbb
Ignore vscode config
hectormz Mar 22, 2020
9b34ae7
Properly test for local machine with monkeypatch deleting CI env
hectormz Mar 23, 2020
30fb64e
Add back in some parts of `pytest.importorskip` to our `importorskip`…
hectormz Mar 23, 2020
8b4463d
Clarify reason text when test skipped for lack of pystan/pystan3
hectormz Mar 23, 2020
dfbc2d3
Refactor helper function to test running on CI machine
hectormz Mar 23, 2020
4981d78
Ensure individual tests for external requirements only skip locally
hectormz Mar 23, 2020
a02e781
Use `importlib.import_module` in `importorskip`
hectormz Mar 23, 2020
173e359
Attempt to fix failing CI imports
hectormz Mar 23, 2020
b8ad150
Revert import method
hectormz Mar 23, 2020
a480cd6
Correct skip logic
hectormz Mar 30, 2020
6c387dc
Breakup import statements
hectormz Mar 30, 2020
32bee02
Correct `pydocstyle` issues
hectormz Mar 30, 2020
3edc64e
Correct pystan skip logic
hectormz Mar 30, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .azure-pipelines/azure-pipelines-base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ jobs:
variables:
- name: NUMBA_DISABLE_JIT
value: 1
- name: ARVIZ_CI_MACHINE
value: 1
timeoutInMinutes: 360
strategy:
matrix:
Expand Down
2 changes: 2 additions & 0 deletions .azure-pipelines/azure-pipelines-external.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ jobs:
variables:
- name: NUMBA_DISABLE_JIT
value: 1
- name: ARVIZ_CI_MACHINE
value: 1
timeoutInMinutes: 360
strategy:
matrix:
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ target/

# IDE configs
.idea/
.vscode/

saved_animations/

Expand Down
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
* Add `num_chains` and `pred_dims` arguments to io_pyro #1090
* Integrate jointplot into pairplot, add point-estimate and overlay of plot kinds (#1079)
* Allow xarray.Dataarray input for plots.(#1120)
* Skip test for optional/extra dependencies when not installed (#1113)
### Maintenance and fixes
* Fixed behaviour of `credible_interval=None` in `plot_posterior` (#1115)
* Fixed hist kind of `plot_dist` with multidimensional input (#1115)
Expand Down Expand Up @@ -212,4 +213,3 @@
## v0.3.0 (2018 Dec 14)

* First Beta Release

100 changes: 9 additions & 91 deletions arviz/tests/base_tests/test_diagnostics.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,28 @@
"""Test Diagnostic methods"""
# pylint: disable=redefined-outer-name, no-member, too-many-public-methods
import os

import numpy as np
from numpy.testing import assert_almost_equal, assert_array_almost_equal
import pandas as pd
import pytest
from numpy.testing import assert_almost_equal, assert_array_almost_equal

from ...data import load_arviz_data, from_cmdstan
from ...data import from_cmdstan, load_arviz_data
from ...plots.plot_utils import xarray_var_iter
from ...stats import bfmi, rhat, ess, mcse, geweke
from ...rcparams import rcParams
from ...stats import bfmi, ess, geweke, mcse, rhat
from ...stats.diagnostics import (
ks_summary,
_conv_quantile,
_ess,
_ess_quantile,
_multichain_statistics,
_mc_error,
_multichain_statistics,
_rhat,
_rhat_rank,
_z_scale,
_conv_quantile,
_split_chains,
_z_scale,
ks_summary,
)
from ...utils import Numba
from ...rcparams import rcParams

# For tests only, recommended value should be closer to 1.01-1.05
# See discussion in https://github.com/stan-dev/rstan/pull/618
Expand Down Expand Up @@ -536,85 +536,3 @@ def test_split_chain_dims(self, chains, draws):
if chains is None:
chains = 1
assert split_data.shape == (chains * 2, draws // 2)


def test_numba_bfmi():
"""Numba test for bfmi."""
state = Numba.numba_flag
school = load_arviz_data("centered_eight")
data_md = np.random.rand(100, 100, 10)
Numba.disable_numba()
non_numba = bfmi(school.posterior["mu"].values)
non_numba_md = bfmi(data_md)
Numba.enable_numba()
with_numba = bfmi(school.posterior["mu"].values)
with_numba_md = bfmi(data_md)
assert np.allclose(non_numba_md, with_numba_md)
assert np.allclose(with_numba, non_numba)
assert state == Numba.numba_flag


@pytest.mark.parametrize("method", ("rank", "split", "folded", "z_scale", "identity"))
def test_numba_rhat(method):
"""Numba test for mcse."""
state = Numba.numba_flag
school = np.random.rand(100, 100)
Numba.disable_numba()
non_numba = rhat(school, method=method)
Numba.enable_numba()
with_numba = rhat(school, method=method)
assert np.allclose(with_numba, non_numba)
assert Numba.numba_flag == state


@pytest.mark.parametrize("method", ("mean", "sd", "quantile"))
def test_numba_mcse(method, prob=None):
"""Numba test for mcse."""
state = Numba.numba_flag
school = np.random.rand(100, 100)
if method == "quantile":
prob = 0.80
Numba.disable_numba()
non_numba = mcse(school, method=method, prob=prob)
Numba.enable_numba()
with_numba = mcse(school, method=method, prob=prob)
assert np.allclose(with_numba, non_numba)
assert Numba.numba_flag == state


def test_ks_summary_numba():
"""Numba test for ks_summary."""
state = Numba.numba_flag
data = np.random.randn(100, 100)
Numba.disable_numba()
non_numba = (ks_summary(data)["Count"]).values
Numba.enable_numba()
with_numba = (ks_summary(data)["Count"]).values
assert np.allclose(non_numba, with_numba)
assert Numba.numba_flag == state


def test_geweke_numba():
"""Numba test for geweke."""
state = Numba.numba_flag
data = np.random.randn(100)
Numba.disable_numba()
non_numba = geweke(data)
Numba.enable_numba()
with_numba = geweke(data)
assert np.allclose(non_numba, with_numba)
assert Numba.numba_flag == state


@pytest.mark.parametrize("batches", (1, 20))
@pytest.mark.parametrize("circular", (True, False))
def test_mcse_error_numba(batches, circular):
"""Numba test for mcse_error."""
data = np.random.randn(100, 100)
state = Numba.numba_flag
Numba.disable_numba()
non_numba = _mc_error(data, batches=batches, circular=circular)
Numba.enable_numba()
with_numba = _mc_error(data, batches=batches, circular=circular)
assert np.allclose(non_numba, with_numba)
assert state == Numba.numba_flag
104 changes: 104 additions & 0 deletions arviz/tests/base_tests/test_diagnostics_numba.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
"""Test Diagnostic methods"""
import importlib

# pylint: disable=redefined-outer-name, no-member, too-many-public-methods
import numpy as np
import pytest

from ...data import load_arviz_data
from ..helpers import running_on_ci
from ...rcparams import rcParams
from ...stats import bfmi, geweke, mcse, rhat
from ...stats.diagnostics import _mc_error, ks_summary
from ...utils import Numba
from .test_diagnostics import data # pylint: disable=unused-import


pytestmark = pytest.mark.skipif( # pylint: disable=invalid-name
(importlib.util.find_spec("numba") is None) & ~running_on_ci(),
reason="test requires numba which is not installed",
)

rcParams["data.load"] = "eager"


def test_numba_bfmi():
"""Numba test for bfmi."""
state = Numba.numba_flag
school = load_arviz_data("centered_eight")
data_md = np.random.rand(100, 100, 10)
Numba.disable_numba()
non_numba = bfmi(school.posterior["mu"].values)
non_numba_md = bfmi(data_md)
Numba.enable_numba()
with_numba = bfmi(school.posterior["mu"].values)
with_numba_md = bfmi(data_md)
assert np.allclose(non_numba_md, with_numba_md)
assert np.allclose(with_numba, non_numba)
assert state == Numba.numba_flag


@pytest.mark.parametrize("method", ("rank", "split", "folded", "z_scale", "identity"))
def test_numba_rhat(method):
"""Numba test for mcse."""
state = Numba.numba_flag
school = np.random.rand(100, 100)
Numba.disable_numba()
non_numba = rhat(school, method=method)
Numba.enable_numba()
with_numba = rhat(school, method=method)
assert np.allclose(with_numba, non_numba)
assert Numba.numba_flag == state


@pytest.mark.parametrize("method", ("mean", "sd", "quantile"))
def test_numba_mcse(method, prob=None):
"""Numba test for mcse."""
state = Numba.numba_flag
school = np.random.rand(100, 100)
if method == "quantile":
prob = 0.80
Numba.disable_numba()
non_numba = mcse(school, method=method, prob=prob)
Numba.enable_numba()
with_numba = mcse(school, method=method, prob=prob)
assert np.allclose(with_numba, non_numba)
assert Numba.numba_flag == state


def test_ks_summary_numba():
"""Numba test for ks_summary."""
state = Numba.numba_flag
data = np.random.randn(100, 100)
Numba.disable_numba()
non_numba = (ks_summary(data)["Count"]).values
Numba.enable_numba()
with_numba = (ks_summary(data)["Count"]).values
assert np.allclose(non_numba, with_numba)
assert Numba.numba_flag == state


def test_geweke_numba():
"""Numba test for geweke."""
state = Numba.numba_flag
data = np.random.randn(100)
Numba.disable_numba()
non_numba = geweke(data)
Numba.enable_numba()
with_numba = geweke(data)
assert np.allclose(non_numba, with_numba)
assert Numba.numba_flag == state


@pytest.mark.parametrize("batches", (1, 20))
@pytest.mark.parametrize("circular", (True, False))
def test_mcse_error_numba(batches, circular):
"""Numba test for mcse_error."""
data = np.random.randn(100, 100)
state = Numba.numba_flag
Numba.disable_numba()
non_numba = _mc_error(data, batches=batches, circular=circular)
Numba.enable_numba()
with_numba = _mc_error(data, batches=batches, circular=circular)
assert np.allclose(non_numba, with_numba)
assert state == Numba.numba_flag
18 changes: 18 additions & 0 deletions arviz/tests/base_tests/test_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import pytest
from _pytest.outcomes import Skipped

from ..helpers import importorskip


def test_importorskip_local(monkeypatch):
"""Test ``importorskip`` run on local machine with non-existent module, which should skip."""
monkeypatch.delenv("ARVIZ_CI_MACHINE", raising=False)
with pytest.raises(Skipped):
importorskip("non-existent-function")


def test_importorskip_ci(monkeypatch):
"""Test ``importorskip`` run on CI machine with non-existent module, which should fail."""
monkeypatch.setenv("ARVIZ_CI_MACHINE", 1)
with pytest.raises(ModuleNotFoundError):
importorskip("non-existent-function")
19 changes: 13 additions & 6 deletions arviz/tests/base_tests/test_plot_utils.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
# pylint: disable=redefined-outer-name
import importlib

import numpy as np
import xarray as xr
import pytest
import xarray as xr

from ...data import from_dict
from ..helpers import running_on_ci
from ...plots.plot_utils import (
make_2d,
xarray_to_ndarray,
xarray_var_iter,
get_bins,
get_coords,
filter_plotters_list,
format_sig_figs,
get_bins,
get_coords,
get_plotting_function,
make_2d,
matplotlib_kwarg_dealiaser,
xarray_to_ndarray,
xarray_var_iter,
)
from ...rcparams import rc_context

Expand Down Expand Up @@ -194,6 +197,10 @@ def test_filter_plotter_list_warning():
assert len(plotters_filtered) == 5


@pytest.mark.skipif(
(importlib.util.find_spec("bokeh") is None) & ~running_on_ci(),
reason="test requires bokeh which is not installed",
)
def test_bokeh_import():
"""Tests that correct method is returned on bokeh import"""
plot = get_plotting_function("plot_dist", "distplot", "bokeh")
Expand Down
27 changes: 16 additions & 11 deletions arviz/tests/base_tests/test_plots_bokeh.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
"""Tests use the 'bokeh' backend."""
# pylint: disable=redefined-outer-name,too-many-lines
from copy import deepcopy
import bokeh.plotting as bkp
from pandas import DataFrame

import numpy as np
import pytest
from pandas import DataFrame # pylint: disable=wrong-import-position

from ...data import from_dict, load_arviz_data
from ..helpers import ( # pylint: disable=unused-import
from ...data import from_dict, load_arviz_data # pylint: disable=wrong-import-position
from ..helpers import ( # pylint: disable=unused-import, wrong-import-position
create_model,
create_multidimensional_model,
eight_schools_params,
importorskip,
models,
create_model,
multidim_models,
create_multidimensional_model,
)
from ...rcparams import rcParams, rc_context
from ...plots import (
from ...rcparams import rc_context, rcParams # pylint: disable=wrong-import-position
from ...plots import ( # pylint: disable=wrong-import-position
plot_autocorr,
plot_compare,
plot_density,
Expand All @@ -31,14 +32,18 @@
plot_loo_pit,
plot_mcse,
plot_pair,
plot_rank,
plot_trace,
plot_parallel,
plot_posterior,
plot_ppc,
plot_rank,
plot_trace,
plot_violin,
)
from ...stats import compare, loo, waic
from ...stats import compare, loo, waic # pylint: disable=wrong-import-position

# Skip tests if bokeh not installed
bkp = importorskip("bokeh.plotting") # pylint: disable=invalid-name


rcParams["data.load"] = "eager"

Expand Down
8 changes: 2 additions & 6 deletions arviz/tests/base_tests/test_plots_matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,12 +153,8 @@ def test_plot_trace(models, kwargs):
assert axes.shape


@pytest.mark.parametrize(
"compact", [True, False],
)
@pytest.mark.parametrize(
"combined", [True, False],
)
@pytest.mark.parametrize("compact", [True, False])
@pytest.mark.parametrize("combined", [True, False])
def test_plot_trace_legend(compact, combined):
idata = load_arviz_data("rugby")
axes = plot_trace(
Expand Down
Loading