diff --git a/.azure-pipelines/azure-pipelines-base.yml b/.azure-pipelines/azure-pipelines-base.yml index a3eff465f9..53cb1ef709 100644 --- a/.azure-pipelines/azure-pipelines-base.yml +++ b/.azure-pipelines/azure-pipelines-base.yml @@ -5,6 +5,8 @@ jobs: variables: - name: NUMBA_DISABLE_JIT value: 1 + - name: ARVIZ_CI_MACHINE + value: 1 timeoutInMinutes: 360 strategy: matrix: diff --git a/.azure-pipelines/azure-pipelines-external.yml b/.azure-pipelines/azure-pipelines-external.yml index 67542506b6..12afd50c50 100644 --- a/.azure-pipelines/azure-pipelines-external.yml +++ b/.azure-pipelines/azure-pipelines-external.yml @@ -5,6 +5,8 @@ jobs: variables: - name: NUMBA_DISABLE_JIT value: 1 + - name: ARVIZ_CI_MACHINE + value: 1 timeoutInMinutes: 360 strategy: matrix: diff --git a/.gitignore b/.gitignore index 9e50d2b96d..f4fca72d24 100644 --- a/.gitignore +++ b/.gitignore @@ -62,6 +62,7 @@ target/ # IDE configs .idea/ +.vscode/ saved_animations/ diff --git a/CHANGELOG.md b/CHANGELOG.md index 75cad75182..b2bf074d13 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ * Add `num_chains` and `pred_dims` arguments to io_pyro #1090 * Integrate jointplot into pairplot, add point-estimate and overlay of plot kinds (#1079) * Allow xarray.Dataarray input for plots.(#1120) +* Skip test for optional/extra dependencies when not installed (#1113) ### Maintenance and fixes * Fixed behaviour of `credible_interval=None` in `plot_posterior` (#1115) * Fixed hist kind of `plot_dist` with multidimensional input (#1115) @@ -212,4 +213,3 @@ ## v0.3.0 (2018 Dec 14) * First Beta Release - diff --git a/arviz/tests/base_tests/test_diagnostics.py b/arviz/tests/base_tests/test_diagnostics.py index 102510bcec..01b1f0d65f 100644 --- a/arviz/tests/base_tests/test_diagnostics.py +++ b/arviz/tests/base_tests/test_diagnostics.py @@ -1,28 +1,28 @@ """Test Diagnostic methods""" # pylint: disable=redefined-outer-name, no-member, too-many-public-methods import os + import numpy as np -from numpy.testing import assert_almost_equal, assert_array_almost_equal import pandas as pd import pytest +from numpy.testing import assert_almost_equal, assert_array_almost_equal -from ...data import load_arviz_data, from_cmdstan +from ...data import from_cmdstan, load_arviz_data from ...plots.plot_utils import xarray_var_iter -from ...stats import bfmi, rhat, ess, mcse, geweke +from ...rcparams import rcParams +from ...stats import bfmi, ess, geweke, mcse, rhat from ...stats.diagnostics import ( - ks_summary, + _conv_quantile, _ess, _ess_quantile, - _multichain_statistics, _mc_error, + _multichain_statistics, _rhat, _rhat_rank, - _z_scale, - _conv_quantile, _split_chains, + _z_scale, + ks_summary, ) -from ...utils import Numba -from ...rcparams import rcParams # For tests only, recommended value should be closer to 1.01-1.05 # See discussion in https://github.com/stan-dev/rstan/pull/618 @@ -536,85 +536,3 @@ def test_split_chain_dims(self, chains, draws): if chains is None: chains = 1 assert split_data.shape == (chains * 2, draws // 2) - - -def test_numba_bfmi(): - """Numba test for bfmi.""" - state = Numba.numba_flag - school = load_arviz_data("centered_eight") - data_md = np.random.rand(100, 100, 10) - Numba.disable_numba() - non_numba = bfmi(school.posterior["mu"].values) - non_numba_md = bfmi(data_md) - Numba.enable_numba() - with_numba = bfmi(school.posterior["mu"].values) - with_numba_md = bfmi(data_md) - assert np.allclose(non_numba_md, with_numba_md) - assert np.allclose(with_numba, non_numba) - assert state == Numba.numba_flag - - -@pytest.mark.parametrize("method", ("rank", "split", "folded", "z_scale", "identity")) -def test_numba_rhat(method): - """Numba test for mcse.""" - state = Numba.numba_flag - school = np.random.rand(100, 100) - Numba.disable_numba() - non_numba = rhat(school, method=method) - Numba.enable_numba() - with_numba = rhat(school, method=method) - assert np.allclose(with_numba, non_numba) - assert Numba.numba_flag == state - - -@pytest.mark.parametrize("method", ("mean", "sd", "quantile")) -def test_numba_mcse(method, prob=None): - """Numba test for mcse.""" - state = Numba.numba_flag - school = np.random.rand(100, 100) - if method == "quantile": - prob = 0.80 - Numba.disable_numba() - non_numba = mcse(school, method=method, prob=prob) - Numba.enable_numba() - with_numba = mcse(school, method=method, prob=prob) - assert np.allclose(with_numba, non_numba) - assert Numba.numba_flag == state - - -def test_ks_summary_numba(): - """Numba test for ks_summary.""" - state = Numba.numba_flag - data = np.random.randn(100, 100) - Numba.disable_numba() - non_numba = (ks_summary(data)["Count"]).values - Numba.enable_numba() - with_numba = (ks_summary(data)["Count"]).values - assert np.allclose(non_numba, with_numba) - assert Numba.numba_flag == state - - -def test_geweke_numba(): - """Numba test for geweke.""" - state = Numba.numba_flag - data = np.random.randn(100) - Numba.disable_numba() - non_numba = geweke(data) - Numba.enable_numba() - with_numba = geweke(data) - assert np.allclose(non_numba, with_numba) - assert Numba.numba_flag == state - - -@pytest.mark.parametrize("batches", (1, 20)) -@pytest.mark.parametrize("circular", (True, False)) -def test_mcse_error_numba(batches, circular): - """Numba test for mcse_error.""" - data = np.random.randn(100, 100) - state = Numba.numba_flag - Numba.disable_numba() - non_numba = _mc_error(data, batches=batches, circular=circular) - Numba.enable_numba() - with_numba = _mc_error(data, batches=batches, circular=circular) - assert np.allclose(non_numba, with_numba) - assert state == Numba.numba_flag diff --git a/arviz/tests/base_tests/test_diagnostics_numba.py b/arviz/tests/base_tests/test_diagnostics_numba.py new file mode 100644 index 0000000000..71bce25ae9 --- /dev/null +++ b/arviz/tests/base_tests/test_diagnostics_numba.py @@ -0,0 +1,104 @@ +"""Test Diagnostic methods""" +import importlib + +# pylint: disable=redefined-outer-name, no-member, too-many-public-methods +import numpy as np +import pytest + +from ...data import load_arviz_data +from ..helpers import running_on_ci +from ...rcparams import rcParams +from ...stats import bfmi, geweke, mcse, rhat +from ...stats.diagnostics import _mc_error, ks_summary +from ...utils import Numba +from .test_diagnostics import data # pylint: disable=unused-import + + +pytestmark = pytest.mark.skipif( # pylint: disable=invalid-name + (importlib.util.find_spec("numba") is None) & ~running_on_ci(), + reason="test requires numba which is not installed", +) + +rcParams["data.load"] = "eager" + + +def test_numba_bfmi(): + """Numba test for bfmi.""" + state = Numba.numba_flag + school = load_arviz_data("centered_eight") + data_md = np.random.rand(100, 100, 10) + Numba.disable_numba() + non_numba = bfmi(school.posterior["mu"].values) + non_numba_md = bfmi(data_md) + Numba.enable_numba() + with_numba = bfmi(school.posterior["mu"].values) + with_numba_md = bfmi(data_md) + assert np.allclose(non_numba_md, with_numba_md) + assert np.allclose(with_numba, non_numba) + assert state == Numba.numba_flag + + +@pytest.mark.parametrize("method", ("rank", "split", "folded", "z_scale", "identity")) +def test_numba_rhat(method): + """Numba test for mcse.""" + state = Numba.numba_flag + school = np.random.rand(100, 100) + Numba.disable_numba() + non_numba = rhat(school, method=method) + Numba.enable_numba() + with_numba = rhat(school, method=method) + assert np.allclose(with_numba, non_numba) + assert Numba.numba_flag == state + + +@pytest.mark.parametrize("method", ("mean", "sd", "quantile")) +def test_numba_mcse(method, prob=None): + """Numba test for mcse.""" + state = Numba.numba_flag + school = np.random.rand(100, 100) + if method == "quantile": + prob = 0.80 + Numba.disable_numba() + non_numba = mcse(school, method=method, prob=prob) + Numba.enable_numba() + with_numba = mcse(school, method=method, prob=prob) + assert np.allclose(with_numba, non_numba) + assert Numba.numba_flag == state + + +def test_ks_summary_numba(): + """Numba test for ks_summary.""" + state = Numba.numba_flag + data = np.random.randn(100, 100) + Numba.disable_numba() + non_numba = (ks_summary(data)["Count"]).values + Numba.enable_numba() + with_numba = (ks_summary(data)["Count"]).values + assert np.allclose(non_numba, with_numba) + assert Numba.numba_flag == state + + +def test_geweke_numba(): + """Numba test for geweke.""" + state = Numba.numba_flag + data = np.random.randn(100) + Numba.disable_numba() + non_numba = geweke(data) + Numba.enable_numba() + with_numba = geweke(data) + assert np.allclose(non_numba, with_numba) + assert Numba.numba_flag == state + + +@pytest.mark.parametrize("batches", (1, 20)) +@pytest.mark.parametrize("circular", (True, False)) +def test_mcse_error_numba(batches, circular): + """Numba test for mcse_error.""" + data = np.random.randn(100, 100) + state = Numba.numba_flag + Numba.disable_numba() + non_numba = _mc_error(data, batches=batches, circular=circular) + Numba.enable_numba() + with_numba = _mc_error(data, batches=batches, circular=circular) + assert np.allclose(non_numba, with_numba) + assert state == Numba.numba_flag diff --git a/arviz/tests/base_tests/test_helpers.py b/arviz/tests/base_tests/test_helpers.py new file mode 100644 index 0000000000..a7d82be96f --- /dev/null +++ b/arviz/tests/base_tests/test_helpers.py @@ -0,0 +1,18 @@ +import pytest +from _pytest.outcomes import Skipped + +from ..helpers import importorskip + + +def test_importorskip_local(monkeypatch): + """Test ``importorskip`` run on local machine with non-existent module, which should skip.""" + monkeypatch.delenv("ARVIZ_CI_MACHINE", raising=False) + with pytest.raises(Skipped): + importorskip("non-existent-function") + + +def test_importorskip_ci(monkeypatch): + """Test ``importorskip`` run on CI machine with non-existent module, which should fail.""" + monkeypatch.setenv("ARVIZ_CI_MACHINE", 1) + with pytest.raises(ModuleNotFoundError): + importorskip("non-existent-function") diff --git a/arviz/tests/base_tests/test_plot_utils.py b/arviz/tests/base_tests/test_plot_utils.py index e11acf3f53..9250f84ae5 100644 --- a/arviz/tests/base_tests/test_plot_utils.py +++ b/arviz/tests/base_tests/test_plot_utils.py @@ -1,19 +1,22 @@ # pylint: disable=redefined-outer-name +import importlib + import numpy as np -import xarray as xr import pytest +import xarray as xr from ...data import from_dict +from ..helpers import running_on_ci from ...plots.plot_utils import ( - make_2d, - xarray_to_ndarray, - xarray_var_iter, - get_bins, - get_coords, filter_plotters_list, format_sig_figs, + get_bins, + get_coords, get_plotting_function, + make_2d, matplotlib_kwarg_dealiaser, + xarray_to_ndarray, + xarray_var_iter, ) from ...rcparams import rc_context @@ -194,6 +197,10 @@ def test_filter_plotter_list_warning(): assert len(plotters_filtered) == 5 +@pytest.mark.skipif( + (importlib.util.find_spec("bokeh") is None) & ~running_on_ci(), + reason="test requires bokeh which is not installed", +) def test_bokeh_import(): """Tests that correct method is returned on bokeh import""" plot = get_plotting_function("plot_dist", "distplot", "bokeh") diff --git a/arviz/tests/base_tests/test_plots_bokeh.py b/arviz/tests/base_tests/test_plots_bokeh.py index e961f3d8bc..b1452cba6e 100644 --- a/arviz/tests/base_tests/test_plots_bokeh.py +++ b/arviz/tests/base_tests/test_plots_bokeh.py @@ -1,21 +1,22 @@ """Tests use the 'bokeh' backend.""" # pylint: disable=redefined-outer-name,too-many-lines from copy import deepcopy -import bokeh.plotting as bkp -from pandas import DataFrame + import numpy as np import pytest +from pandas import DataFrame # pylint: disable=wrong-import-position -from ...data import from_dict, load_arviz_data -from ..helpers import ( # pylint: disable=unused-import +from ...data import from_dict, load_arviz_data # pylint: disable=wrong-import-position +from ..helpers import ( # pylint: disable=unused-import, wrong-import-position + create_model, + create_multidimensional_model, eight_schools_params, + importorskip, models, - create_model, multidim_models, - create_multidimensional_model, ) -from ...rcparams import rcParams, rc_context -from ...plots import ( +from ...rcparams import rc_context, rcParams # pylint: disable=wrong-import-position +from ...plots import ( # pylint: disable=wrong-import-position plot_autocorr, plot_compare, plot_density, @@ -31,14 +32,18 @@ plot_loo_pit, plot_mcse, plot_pair, - plot_rank, - plot_trace, plot_parallel, plot_posterior, plot_ppc, + plot_rank, + plot_trace, plot_violin, ) -from ...stats import compare, loo, waic +from ...stats import compare, loo, waic # pylint: disable=wrong-import-position + +# Skip tests if bokeh not installed +bkp = importorskip("bokeh.plotting") # pylint: disable=invalid-name + rcParams["data.load"] = "eager" diff --git a/arviz/tests/base_tests/test_plots_matplotlib.py b/arviz/tests/base_tests/test_plots_matplotlib.py index cd2af18a0a..d6143b6786 100644 --- a/arviz/tests/base_tests/test_plots_matplotlib.py +++ b/arviz/tests/base_tests/test_plots_matplotlib.py @@ -153,12 +153,8 @@ def test_plot_trace(models, kwargs): assert axes.shape -@pytest.mark.parametrize( - "compact", [True, False], -) -@pytest.mark.parametrize( - "combined", [True, False], -) +@pytest.mark.parametrize("compact", [True, False]) +@pytest.mark.parametrize("combined", [True, False]) def test_plot_trace_legend(compact, combined): idata = load_arviz_data("rugby") axes = plot_trace( diff --git a/arviz/tests/base_tests/test_stats.py b/arviz/tests/base_tests/test_stats.py index ab38aea502..b1cebd47ff 100644 --- a/arviz/tests/base_tests/test_stats.py +++ b/arviz/tests/base_tests/test_stats.py @@ -1,30 +1,28 @@ # pylint: disable=redefined-outer-name, no-member from copy import deepcopy + import numpy as np -from numpy.testing import assert_allclose, assert_array_almost_equal, assert_array_equal import pytest +from numpy.testing import assert_allclose, assert_array_almost_equal, assert_array_equal from scipy.stats import linregress -from xarray import Dataset, DataArray +from xarray import DataArray, Dataset - -from ...data import load_arviz_data, from_dict, convert_to_inference_data, concat +from ...data import concat, convert_to_inference_data, from_dict, load_arviz_data +from ...rcparams import rcParams from ...stats import ( + apply_test_function, compare, + ess, hpd, loo, - r2_score, - waic, + loo_pit, psislw, + r2_score, summary, - loo_pit, - ess, - apply_test_function, + waic, ) from ...stats.stats import _gpinv -from ...utils import Numba from ..helpers import check_multiple_attrs, multidim_models # pylint: disable=unused-import -from ...rcparams import rcParams - rcParams["data.load"] = "eager" @@ -154,16 +152,6 @@ def test_summary_var_names(centered_eight, var_names_expected): assert len(summary_df.index) == expected -@pytest.mark.parametrize("include_circ", [True, False]) -def test_summary_include_circ(centered_eight, include_circ): - assert summary(centered_eight, include_circ=include_circ) is not None - state = Numba.numba_flag - Numba.disable_numba() - assert summary(centered_eight, include_circ=include_circ) is not NotImplementedError - Numba.enable_numba() - assert state == Numba.numba_flag - - METRICS_NAMES = [ "mean", "sd", @@ -630,21 +618,3 @@ def test_apply_test_function_should_overwrite_error(centered_eight): """Test error when overwrite=False but out_name is already a present variable.""" with pytest.raises(ValueError, match="Should overwrite"): apply_test_function(centered_eight, lambda y, theta: y, out_name_data="obs") - - -def test_numba_stats(): - """Numba test for r2_score""" - state = Numba.numba_flag # Store the current state of Numba - set_1 = np.random.randn(100, 100) - set_2 = np.random.randn(100, 100) - set_3 = np.random.rand(100) - set_4 = np.random.rand(100) - Numba.disable_numba() - non_numba = r2_score(set_1, set_2) - non_numba_one_dimensional = r2_score(set_3, set_4) - Numba.enable_numba() - with_numba = r2_score(set_1, set_2) - with_numba_one_dimensional = r2_score(set_3, set_4) - assert state == Numba.numba_flag # Ensure that inital state = final state - assert np.allclose(non_numba, with_numba) - assert np.allclose(non_numba_one_dimensional, with_numba_one_dimensional) diff --git a/arviz/tests/base_tests/test_stats_numba.py b/arviz/tests/base_tests/test_stats_numba.py new file mode 100644 index 0000000000..a2d135fc68 --- /dev/null +++ b/arviz/tests/base_tests/test_stats_numba.py @@ -0,0 +1,50 @@ +# pylint: disable=redefined-outer-name, no-member +import importlib + +import numpy as np +import pytest + +from ...rcparams import rcParams +from ...stats import r2_score, summary +from ...utils import Numba +from ..helpers import ( # pylint: disable=unused-import + check_multiple_attrs, + multidim_models, + running_on_ci, +) +from .test_stats import centered_eight, non_centered_eight # pylint: disable=unused-import + +pytestmark = pytest.mark.skipif( # pylint: disable=invalid-name + (importlib.util.find_spec("numba") is None) & ~running_on_ci(), + reason="test requires numba which is not installed", +) + +rcParams["data.load"] = "eager" + + +@pytest.mark.parametrize("include_circ", [True, False]) +def test_summary_include_circ(centered_eight, include_circ): + assert summary(centered_eight, include_circ=include_circ) is not None + state = Numba.numba_flag + Numba.disable_numba() + assert summary(centered_eight, include_circ=include_circ) is not NotImplementedError + Numba.enable_numba() + assert state == Numba.numba_flag + + +def test_numba_stats(): + """Numba test for r2_score""" + state = Numba.numba_flag # Store the current state of Numba + set_1 = np.random.randn(100, 100) + set_2 = np.random.randn(100, 100) + set_3 = np.random.rand(100) + set_4 = np.random.rand(100) + Numba.disable_numba() + non_numba = r2_score(set_1, set_2) + non_numba_one_dimensional = r2_score(set_3, set_4) + Numba.enable_numba() + with_numba = r2_score(set_1, set_2) + with_numba_one_dimensional = r2_score(set_3, set_4) + assert state == Numba.numba_flag # Ensure that inital state = final state + assert np.allclose(non_numba, with_numba) + assert np.allclose(non_numba_one_dimensional, with_numba_one_dimensional) diff --git a/arviz/tests/base_tests/test_utils.py b/arviz/tests/base_tests/test_utils.py index 6160d808f4..2c9456bd32 100644 --- a/arviz/tests/base_tests/test_utils.py +++ b/arviz/tests/base_tests/test_utils.py @@ -3,15 +3,11 @@ """ # pylint: disable=redefined-outer-name, no-member from unittest.mock import Mock -import importlib import numpy as np import pytest from ...utils import ( _var_names, - numba_check, - Numba, - _numba_var, _stack, one_de, two_de, @@ -19,7 +15,6 @@ flatten_inference_data_to_dict, ) from ...data import load_arviz_data, from_dict -from ...stats.stats_utils import stats_variance_2d as svar @pytest.fixture(scope="session") @@ -81,16 +76,6 @@ def utils_with_numba_import_fail(monkeypatch): return utils -def test_utils_fixture(utils_with_numba_import_fail): - """Test of utils fixture to ensure mock is applied correctly""" - - # If Numba doesn't exist in dev environment this will raise an ImportError - import numba # pylint: disable=unused-import,W0612 - - with pytest.raises(ImportError): - utils_with_numba_import_fail.importlib.import_module("numba") - - def test_conditional_jit_decorator_no_numba(utils_with_numba_import_fail): """Tests to see if Numba jit code block is skipped with Import Failure @@ -134,31 +119,6 @@ def func(): assert func() -def test_conditional_jit_numba_decorator_keyword(monkeypatch): - """Checks else statement and JIT keyword argument""" - from arviz import utils - - # Mock import lib to return numba with hit method which returns a function that returns kwargs - numba_mock = Mock() - monkeypatch.setattr(utils.importlib, "import_module", lambda x: numba_mock) - - def jit(**kwargs): - """overwrite numba.jit function""" - return lambda fn: lambda: (fn(), kwargs) - - numba_mock.jit = jit - - @utils.conditional_jit(keyword_argument="A keyword argument") - def placeholder_func(): - """This function does nothing""" - return "output" - - # pylint: disable=unpacking-non-sequence - function_results, wrapper_result = placeholder_func() - assert wrapper_result == {"keyword_argument": "A keyword argument"} - assert function_results == "output" - - def test_conditional_vect_numba_decorator(): """Tests to see if Numba is used. @@ -201,44 +161,6 @@ def placeholder_func(): assert function_results == "output" -def test_numba_check(): - """Test for numba_check""" - numba = importlib.util.find_spec("numba") - flag = numba is not None - assert flag == numba_check() - - -def test_numba_utils(): - """Test for class Numba.""" - flag = Numba.numba_flag - assert flag == numba_check() - Numba.disable_numba() - val = Numba.numba_flag - assert not val - Numba.enable_numba() - val = Numba.numba_flag - assert val - assert flag == Numba.numba_flag - - -@pytest.mark.parametrize("axis", (0, 1)) -@pytest.mark.parametrize("ddof", (0, 1)) -def test_numba_var(axis, ddof): - """Method to test numba_var.""" - flag = Numba.numba_flag - data_1 = np.random.randn(100, 100) - data_2 = np.random.rand(100) - with_numba_1 = _numba_var(svar, np.var, data_1, axis=axis, ddof=ddof) - with_numba_2 = _numba_var(svar, np.var, data_2, ddof=ddof) - Numba.disable_numba() - non_numba_1 = _numba_var(svar, np.var, data_1, axis=axis, ddof=ddof) - non_numba_2 = _numba_var(svar, np.var, data_2, ddof=ddof) - Numba.enable_numba() - assert flag == Numba.numba_flag - assert np.allclose(with_numba_1, non_numba_1) - assert np.allclose(with_numba_2, non_numba_2) - - def test_stack(): x = np.random.randn(10, 4, 6) y = np.random.randn(100, 4, 6) diff --git a/arviz/tests/base_tests/test_utils_numba.py b/arviz/tests/base_tests/test_utils_numba.py new file mode 100644 index 0000000000..81b59a1381 --- /dev/null +++ b/arviz/tests/base_tests/test_utils_numba.py @@ -0,0 +1,97 @@ +""" +Tests for arviz.utils. +""" +import importlib + +# pylint: disable=redefined-outer-name, no-member +from unittest.mock import Mock + +import numpy as np +import pytest + +from ..helpers import running_on_ci +from ...stats.stats_utils import stats_variance_2d as svar +from ...utils import ( + Numba, + _numba_var, + numba_check, +) +from .test_utils import utils_with_numba_import_fail # pylint: disable=unused-import + +pytestmark = pytest.mark.skipif( # pylint: disable=invalid-name + (importlib.util.find_spec("numba") is None) & ~running_on_ci(), + reason="test requires numba which is not installed", +) + + +def test_utils_fixture(utils_with_numba_import_fail): + """Test of utils fixture to ensure mock is applied correctly""" + + # If Numba doesn't exist in dev environment this will raise an ImportError + import numba # pylint: disable=unused-import,W0612 + + with pytest.raises(ImportError): + utils_with_numba_import_fail.importlib.import_module("numba") + + +def test_conditional_jit_numba_decorator_keyword(monkeypatch): + """Checks else statement and JIT keyword argument""" + from arviz import utils + + # Mock import lib to return numba with hit method which returns a function that returns kwargs + numba_mock = Mock() + monkeypatch.setattr(utils.importlib, "import_module", lambda x: numba_mock) + + def jit(**kwargs): + """overwrite numba.jit function""" + return lambda fn: lambda: (fn(), kwargs) + + numba_mock.jit = jit + + @utils.conditional_jit(keyword_argument="A keyword argument") + def placeholder_func(): + """This function does nothing""" + return "output" + + # pylint: disable=unpacking-non-sequence + function_results, wrapper_result = placeholder_func() + assert wrapper_result == {"keyword_argument": "A keyword argument"} + assert function_results == "output" + + +def test_numba_check(): + """Test for numba_check""" + numba = importlib.util.find_spec("numba") + flag = numba is not None + assert flag == numba_check() + + +def test_numba_utils(): + """Test for class Numba.""" + flag = Numba.numba_flag + assert flag == numba_check() + Numba.disable_numba() + val = Numba.numba_flag + assert not val + Numba.enable_numba() + val = Numba.numba_flag + assert val + assert flag == Numba.numba_flag + + +@pytest.mark.parametrize("axis", (0, 1)) +@pytest.mark.parametrize("ddof", (0, 1)) +def test_numba_var(axis, ddof): + """Method to test numba_var.""" + flag = Numba.numba_flag + data_1 = np.random.randn(100, 100) + data_2 = np.random.rand(100) + with_numba_1 = _numba_var(svar, np.var, data_1, axis=axis, ddof=ddof) + with_numba_2 = _numba_var(svar, np.var, data_2, ddof=ddof) + Numba.disable_numba() + non_numba_1 = _numba_var(svar, np.var, data_1, axis=axis, ddof=ddof) + non_numba_2 = _numba_var(svar, np.var, data_2, ddof=ddof) + Numba.enable_numba() + assert flag == Numba.numba_flag + assert np.allclose(with_numba_1, non_numba_1) + assert np.allclose(with_numba_2, non_numba_2) diff --git a/arviz/tests/external_tests/test_data_cmdstanpy.py b/arviz/tests/external_tests/test_data_cmdstanpy.py index f49ed4faf7..29380e230f 100644 --- a/arviz/tests/external_tests/test_data_cmdstanpy.py +++ b/arviz/tests/external_tests/test_data_cmdstanpy.py @@ -12,6 +12,7 @@ check_multiple_attrs, draws, eight_schools_params, + importorskip, load_cached_models, pystan_version, ) @@ -32,9 +33,12 @@ def filepaths(self, data_directory): @pytest.fixture(scope="class") def data(self, filepaths): - from cmdstanpy import CmdStanMCMC - from cmdstanpy.stanfit import RunSet - from cmdstanpy.model import CmdStanArgs, SamplerArgs + # Skip tests if cmdstanpy not installed + cmdstanpy = importorskip("cmdstanpy") + CmdStanMCMC = cmdstanpy.CmdStanMCMC # pylint: disable=invalid-name + RunSet = cmdstanpy.stanfit.RunSet # pylint: disable=invalid-name + CmdStanArgs = cmdstanpy.model.CmdStanArgs # pylint: disable=invalid-name + SamplerArgs = cmdstanpy.model.SamplerArgs # pylint: disable=invalid-name class Data: args = CmdStanArgs( diff --git a/arviz/tests/external_tests/test_data_emcee.py b/arviz/tests/external_tests/test_data_emcee.py index 62a33538a4..324e30dc24 100644 --- a/arviz/tests/external_tests/test_data_emcee.py +++ b/arviz/tests/external_tests/test_data_emcee.py @@ -3,10 +3,8 @@ import numpy as np import pytest -import emcee # pylint: disable=unused-import - -from arviz import from_emcee -from ..helpers import ( # pylint: disable=unused-import +from arviz import from_emcee # pylint: disable=wrong-import-position +from ..helpers import ( # pylint: disable=unused-import, wrong-import-position chains, check_multiple_attrs, draws, @@ -15,8 +13,12 @@ needs_emcee3_func, eight_schools_params, load_cached_models, + importorskip, ) +# Skip all tests if emcee not installed +emcee = importorskip("emcee") + needs_emcee3 = needs_emcee3_func() diff --git a/arviz/tests/external_tests/test_data_numpyro.py b/arviz/tests/external_tests/test_data_numpyro.py index 23e53822cb..2719489043 100644 --- a/arviz/tests/external_tests/test_data_numpyro.py +++ b/arviz/tests/external_tests/test_data_numpyro.py @@ -1,18 +1,22 @@ # pylint: disable=no-member, invalid-name, redefined-outer-name import numpy as np import pytest -from jax.random import PRNGKey -from numpyro.infer import Predictive -from ...data.io_numpyro import from_numpyro -from ..helpers import ( # pylint: disable=unused-import +from ...data.io_numpyro import from_numpyro # pylint: disable=wrong-import-position +from ..helpers import ( # pylint: disable=unused-import, wrong-import-position chains, check_multiple_attrs, draws, eight_schools_params, + importorskip, load_cached_models, ) +# Skip all tests if jax or numpyro not installed +jax = importorskip("jax") +PRNGKey = jax.random.PRNGKey +numpyro = importorskip("numpyro") +Predictive = numpyro.infer.Predictive class TestDataNumPyro: @pytest.fixture(scope="class") diff --git a/arviz/tests/external_tests/test_data_pymc.py b/arviz/tests/external_tests/test_data_pymc.py index 15381bcfbc..56ae89f28f 100644 --- a/arviz/tests/external_tests/test_data_pymc.py +++ b/arviz/tests/external_tests/test_data_pymc.py @@ -1,22 +1,25 @@ # pylint: disable=no-member, invalid-name, redefined-outer-name from sys import version_info -from typing import Tuple, Dict -import pytest - +from typing import Dict, Tuple import numpy as np +import pytest from numpy import ma -import pymc3 as pm -from arviz import from_pymc3, from_pymc3_predictions, InferenceData -from ..helpers import ( # pylint: disable=unused-import +from arviz import from_pymc3, from_pymc3_predictions, InferenceData # pylint: disable=wrong-import-position + +from ..helpers import ( # pylint: disable=unused-import, wrong-import-position chains, check_multiple_attrs, draws, eight_schools_params, + importorskip, load_cached_models, ) +# Skip all tests if pymc3 not installed +pm = importorskip("pymc3") + class TestDataPyMC3: @pytest.fixture(scope="class") diff --git a/arviz/tests/external_tests/test_data_pyro.py b/arviz/tests/external_tests/test_data_pyro.py index d620a17cf2..e21eb435c0 100644 --- a/arviz/tests/external_tests/test_data_pyro.py +++ b/arviz/tests/external_tests/test_data_pyro.py @@ -2,19 +2,22 @@ import numpy as np import packaging import pytest -import torch -import pyro -from pyro.infer import Predictive -from ...data.io_pyro import from_pyro -from ..helpers import ( # pylint: disable=unused-import +from ...data.io_pyro import from_pyro # pylint: disable=wrong-import-position +from ..helpers import ( # pylint: disable=unused-import, wrong-import-position chains, check_multiple_attrs, draws, eight_schools_params, + importorskip, load_cached_models, ) +# Skip all tests if pyro or pytorch not installed +torch = importorskip("torch") +pyro = importorskip("pyro") +Predictive = pyro.infer.Predictive + class TestDataPyro: @pytest.fixture(scope="class") diff --git a/arviz/tests/external_tests/test_data_pystan.py b/arviz/tests/external_tests/test_data_pystan.py index 5f0bee08f5..7310a534f0 100644 --- a/arviz/tests/external_tests/test_data_pystan.py +++ b/arviz/tests/external_tests/test_data_pystan.py @@ -1,17 +1,31 @@ # pylint: disable=no-member, invalid-name, redefined-outer-name +import importlib from collections import OrderedDict + import numpy as np import pytest from arviz import from_pystan + from ...data.io_pystan import get_draws, get_draws_stan3 # pylint: disable=unused-import from ..helpers import ( # pylint: disable=unused-import chains, check_multiple_attrs, draws, eight_schools_params, + importorskip, load_cached_models, pystan_version, + running_on_ci, +) + +# Check if either pystan or pystan3 is installed +pystan_installed = (importlib.util.find_spec("pystan") is not None) or ( + importlib.util.find_spec("stan") is not None +) +pytestmark = pytest.mark.skipif( + not (pystan_installed | running_on_ci()), + reason="test requires pystan/pystan3 which is not installed", ) @@ -217,7 +231,8 @@ def test_get_draws(self, data): @pytest.mark.skipif(pystan_version() != 2, reason="PyStan 2.x required") def test_index_order(self, data, eight_schools_params): """Test 0-indexed data.""" - import pystan # pylint: disable=import-error + # Skip test if pystan not installed + pystan = importorskip("pystan") # pylint: disable=import-error fit = data.model.sampling(data=eight_schools_params) if pystan.__version__ >= "2.18": diff --git a/arviz/tests/external_tests/test_data_tfp.py b/arviz/tests/external_tests/test_data_tfp.py index 9191666639..880d8501d3 100644 --- a/arviz/tests/external_tests/test_data_tfp.py +++ b/arviz/tests/external_tests/test_data_tfp.py @@ -8,9 +8,13 @@ check_multiple_attrs, draws, eight_schools_params, + importorskip, load_cached_models, ) +# Skip all tests if tensorflow_probability not installed +importorskip("tensorflow_probability") + class TestDataTfp: @pytest.fixture(scope="class") diff --git a/arviz/tests/helpers.py b/arviz/tests/helpers.py index 1d7b565514..ddc31b505c 100644 --- a/arviz/tests/helpers.py +++ b/arviz/tests/helpers.py @@ -2,16 +2,18 @@ """Test helper functions.""" import gzip import importlib +import logging import os import pickle import sys -import logging -from typing import Dict, List, Tuple, Union -import pytest -import numpy as np +from typing import Any, Dict, List, Optional, Tuple, Union -from ..data import from_dict, InferenceData +import numpy as np +import pytest +from _pytest.outcomes import Skipped +from packaging.version import Version +from ..data import InferenceData, from_dict _log = logging.getLogger(__name__) @@ -571,11 +573,70 @@ def pystan_version(): """ try: import pystan # pylint: disable=import-error + + version = int(pystan.__version__[0]) except ImportError: - import stan as pystan # pylint: disable=import-error - return int(pystan.__version__[0]) + try: + import stan as pystan # pylint: disable=import-error + + version = int(pystan.__version__[0]) + except ImportError: + version = None + return version def test_precompile_models(eight_schools_params, draws, chains): """Precompile model files.""" load_cached_models(eight_schools_params, draws, chains) + + +def running_on_ci() -> bool: + """Return True if running on CI machine.""" + return os.environ.get("ARVIZ_CI_MACHINE") is not None + + +def importorskip( + modname: str, minversion: Optional[str] = None, reason: Optional[str] = None +) -> Any: + """Import and return the requested module ``modname``. + + Doesn't allow skips on CI machine. + Borrowed and modified from ``pytest.importorskip``. + :param str modname: the name of the module to import + :param str minversion: if given, the imported module's ``__version__`` + attribute must be at least this minimal version, otherwise the test is + still skipped. + :param str reason: if given, this reason is shown as the message when the + module cannot be imported. + :returns: The imported module. This should be assigned to its canonical + name. + Example:: + docutils = pytest.importorskip("docutils") + """ + # ARVIZ_CI_MACHINE is True if tests run on CI, where ARVIZ_CI_MACHINE env variable exists + ARVIZ_CI_MACHINE = running_on_ci() + if ARVIZ_CI_MACHINE: + import warnings + + compile(modname, "", "eval") # to catch syntaxerrors + + with warnings.catch_warnings(): + # make sure to ignore ImportWarnings that might happen because + # of existing directories with the same name we're trying to + # import but without a __init__.py file + warnings.simplefilter("ignore") + __import__(modname) + mod = sys.modules[modname] + if minversion is None: + return mod + verattr = getattr(mod, "__version__", None) + if minversion is not None: + if verattr is None or Version(verattr) < Version(minversion): + raise Skipped( + "module %r has __version__ %r, required is: %r" + % (modname, verattr, minversion), + allow_module_level=True, + ) + return mod + else: + return pytest.importorskip(modname=modname, minversion=minversion, reason=reason) diff --git a/pytest.ini b/pytest.ini index 14b1ba1631..09e578bdc3 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,5 +1,5 @@ [pytest] -addopts = --strict -rf --durations=20 -p no:warnings +addopts = --strict -rsf --durations=20 -p no:warnings console_output_style = count junit_family= xunit1 markers =