diff --git a/docs/index.md b/docs/index.md index d15a624..f7c5157 100644 --- a/docs/index.md +++ b/docs/index.md @@ -5,6 +5,7 @@ :hidden: self api-reference.md +testing-utils.md contributing.md contributors.md ``` diff --git a/docs/testing-utils.md b/docs/testing-utils.md new file mode 100644 index 0000000..49aeb30 --- /dev/null +++ b/docs/testing-utils.md @@ -0,0 +1,14 @@ +# Testing Utilities + +These additional functions are meant to be used while unit testing Array API +compliant packages: + +```{eval-rst} +.. currentmodule:: array_api_extra.testing +.. autosummary:: + :nosignatures: + :toctree: generated + + lazy_xp_function + patch_lazy_xp_functions +``` diff --git a/pixi.lock b/pixi.lock index f82df2e..00c8fee 100644 --- a/pixi.lock +++ b/pixi.lock @@ -1609,14 +1609,21 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/certifi-2024.12.14-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/linux-64/cffi-1.17.1-py312h06ac9bb_0.conda - conda: https://prefix.dev/conda-forge/noarch/charset-normalizer-3.4.1-pyhd8ed1ab_0.conda + - conda: https://prefix.dev/conda-forge/noarch/click-8.1.8-pyh707e725_0.conda + - conda: https://prefix.dev/conda-forge/noarch/cloudpickle-3.1.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/colorama-0.4.6-pyhd8ed1ab_1.conda + - conda: https://prefix.dev/conda-forge/noarch/dask-core-2025.1.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/docutils-0.21.2-pyhd8ed1ab_1.conda + - conda: https://prefix.dev/conda-forge/noarch/exceptiongroup-1.2.2-pyhd8ed1ab_1.conda + - conda: https://prefix.dev/conda-forge/noarch/fsspec-2024.12.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/furo-2024.8.6-pyhd8ed1ab_2.conda - conda: https://prefix.dev/conda-forge/noarch/h2-4.1.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/hpack-4.1.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/hyperframe-6.1.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/idna-3.10-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/imagesize-1.4.1-pyhd8ed1ab_0.tar.bz2 + - conda: https://prefix.dev/conda-forge/noarch/importlib-metadata-8.6.1-pyha770c72_0.conda + - conda: https://prefix.dev/conda-forge/noarch/iniconfig-2.0.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/jinja2-3.1.5-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/linux-64/ld_impl_linux-64-2.43-h712a8e2_2.conda - conda: https://prefix.dev/conda-forge/linux-64/libexpat-2.6.4-h5888daf_0.conda @@ -1632,6 +1639,7 @@ environments: - conda: https://prefix.dev/conda-forge/linux-64/libxcrypt-4.4.36-hd590300_1.conda - conda: https://prefix.dev/conda-forge/linux-64/libzlib-1.3.1-hb9d3cd8_2.conda - conda: https://prefix.dev/conda-forge/linux-64/llvm-openmp-19.1.7-h024ca30_0.conda + - conda: https://prefix.dev/conda-forge/noarch/locket-1.0.0-pyhd8ed1ab_0.tar.bz2 - conda: https://prefix.dev/conda-forge/noarch/markdown-it-py-3.0.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/linux-64/markupsafe-3.0.2-py312h178313f_1.conda - conda: https://prefix.dev/conda-forge/noarch/mdit-py-plugins-0.4.2-pyhd8ed1ab_1.conda @@ -1640,9 +1648,12 @@ environments: - conda: https://prefix.dev/conda-forge/linux-64/ncurses-6.5-h2d0b736_2.conda - conda: https://prefix.dev/conda-forge/linux-64/openssl-3.4.0-h7b32b05_1.conda - conda: https://prefix.dev/conda-forge/noarch/packaging-24.2-pyhd8ed1ab_2.conda + - conda: https://prefix.dev/conda-forge/noarch/partd-1.4.2-pyhd8ed1ab_0.conda + - conda: https://prefix.dev/conda-forge/noarch/pluggy-1.5.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/pycparser-2.22-pyh29332c3_1.conda - conda: https://prefix.dev/conda-forge/noarch/pygments-2.19.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/pysocks-1.7.1-pyha55dd90_7.conda + - conda: https://prefix.dev/conda-forge/noarch/pytest-8.3.4-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/linux-64/python-3.12.8-h9e4cc4f_1_cpython.conda - conda: https://prefix.dev/conda-forge/linux-64/python_abi-3.12-5_cp312.conda - conda: https://prefix.dev/conda-forge/noarch/pytz-2024.1-pyhd8ed1ab_0.conda @@ -1663,9 +1674,13 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/sphinxcontrib-serializinghtml-1.1.10-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/linux-64/tk-8.6.13-noxft_h4845f30_101.conda - conda: https://prefix.dev/conda-forge/noarch/tomli-2.2.1-pyhd8ed1ab_1.conda + - conda: https://prefix.dev/conda-forge/noarch/toolz-1.0.0-pyhd8ed1ab_1.conda + - conda: https://prefix.dev/conda-forge/noarch/typing-extensions-4.12.2-hd8ed1ab_1.conda + - conda: https://prefix.dev/conda-forge/noarch/typing_extensions-4.12.2-pyha770c72_1.conda - conda: https://prefix.dev/conda-forge/noarch/tzdata-2025a-h78e105d_0.conda - conda: https://prefix.dev/conda-forge/noarch/urllib3-2.3.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/linux-64/yaml-0.2.5-h7f98852_2.tar.bz2 + - conda: https://prefix.dev/conda-forge/noarch/zipp-3.21.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/linux-64/zstandard-0.23.0-py312hef9b889_1.conda - conda: https://prefix.dev/conda-forge/linux-64/zstd-1.5.6-ha6fb4c9_0.conda - pypi: . @@ -1680,14 +1695,21 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/certifi-2024.12.14-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/osx-arm64/cffi-1.17.1-py312h0fad829_0.conda - conda: https://prefix.dev/conda-forge/noarch/charset-normalizer-3.4.1-pyhd8ed1ab_0.conda + - conda: https://prefix.dev/conda-forge/noarch/click-8.1.8-pyh707e725_0.conda + - conda: https://prefix.dev/conda-forge/noarch/cloudpickle-3.1.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/colorama-0.4.6-pyhd8ed1ab_1.conda + - conda: https://prefix.dev/conda-forge/noarch/dask-core-2025.1.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/docutils-0.21.2-pyhd8ed1ab_1.conda + - conda: https://prefix.dev/conda-forge/noarch/exceptiongroup-1.2.2-pyhd8ed1ab_1.conda + - conda: https://prefix.dev/conda-forge/noarch/fsspec-2024.12.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/furo-2024.8.6-pyhd8ed1ab_2.conda - conda: https://prefix.dev/conda-forge/noarch/h2-4.1.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/hpack-4.1.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/hyperframe-6.1.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/idna-3.10-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/imagesize-1.4.1-pyhd8ed1ab_0.tar.bz2 + - conda: https://prefix.dev/conda-forge/noarch/importlib-metadata-8.6.1-pyha770c72_0.conda + - conda: https://prefix.dev/conda-forge/noarch/iniconfig-2.0.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/jinja2-3.1.5-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/osx-arm64/libcxx-19.1.7-ha82da77_0.conda - conda: https://prefix.dev/conda-forge/osx-arm64/libexpat-2.6.4-h286801f_0.conda @@ -1695,6 +1717,7 @@ environments: - conda: https://prefix.dev/conda-forge/osx-arm64/liblzma-5.6.3-h39f12f2_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/libsqlite-3.48.0-h3f77e49_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/libzlib-1.3.1-h8359307_2.conda + - conda: https://prefix.dev/conda-forge/noarch/locket-1.0.0-pyhd8ed1ab_0.tar.bz2 - conda: https://prefix.dev/conda-forge/noarch/markdown-it-py-3.0.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/markupsafe-3.0.2-py312h998013c_1.conda - conda: https://prefix.dev/conda-forge/noarch/mdit-py-plugins-0.4.2-pyhd8ed1ab_1.conda @@ -1703,9 +1726,12 @@ environments: - conda: https://prefix.dev/conda-forge/osx-arm64/ncurses-6.5-h5e97a16_2.conda - conda: https://prefix.dev/conda-forge/osx-arm64/openssl-3.4.0-h81ee809_1.conda - conda: https://prefix.dev/conda-forge/noarch/packaging-24.2-pyhd8ed1ab_2.conda + - conda: https://prefix.dev/conda-forge/noarch/partd-1.4.2-pyhd8ed1ab_0.conda + - conda: https://prefix.dev/conda-forge/noarch/pluggy-1.5.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/pycparser-2.22-pyh29332c3_1.conda - conda: https://prefix.dev/conda-forge/noarch/pygments-2.19.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/pysocks-1.7.1-pyha55dd90_7.conda + - conda: https://prefix.dev/conda-forge/noarch/pytest-8.3.4-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/python-3.12.8-hc22306f_1_cpython.conda - conda: https://prefix.dev/conda-forge/osx-arm64/python_abi-3.12-5_cp312.conda - conda: https://prefix.dev/conda-forge/noarch/pytz-2024.1-pyhd8ed1ab_0.conda @@ -1726,9 +1752,13 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/sphinxcontrib-serializinghtml-1.1.10-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/tk-8.6.13-h5083fa2_1.conda - conda: https://prefix.dev/conda-forge/noarch/tomli-2.2.1-pyhd8ed1ab_1.conda + - conda: https://prefix.dev/conda-forge/noarch/toolz-1.0.0-pyhd8ed1ab_1.conda + - conda: https://prefix.dev/conda-forge/noarch/typing-extensions-4.12.2-hd8ed1ab_1.conda + - conda: https://prefix.dev/conda-forge/noarch/typing_extensions-4.12.2-pyha770c72_1.conda - conda: https://prefix.dev/conda-forge/noarch/tzdata-2025a-h78e105d_0.conda - conda: https://prefix.dev/conda-forge/noarch/urllib3-2.3.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/osx-arm64/yaml-0.2.5-h3422bc3_2.tar.bz2 + - conda: https://prefix.dev/conda-forge/noarch/zipp-3.21.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/zstandard-0.23.0-py312h15fbf35_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/zstd-1.5.6-hb46c0d2_0.conda - pypi: . @@ -1743,20 +1773,28 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/certifi-2024.12.14-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/win-64/cffi-1.17.1-py312h4389bb4_0.conda - conda: https://prefix.dev/conda-forge/noarch/charset-normalizer-3.4.1-pyhd8ed1ab_0.conda + - conda: https://prefix.dev/conda-forge/noarch/click-8.1.8-pyh7428d3b_0.conda + - conda: https://prefix.dev/conda-forge/noarch/cloudpickle-3.1.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/colorama-0.4.6-pyhd8ed1ab_1.conda + - conda: https://prefix.dev/conda-forge/noarch/dask-core-2025.1.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/docutils-0.21.2-pyhd8ed1ab_1.conda + - conda: https://prefix.dev/conda-forge/noarch/exceptiongroup-1.2.2-pyhd8ed1ab_1.conda + - conda: https://prefix.dev/conda-forge/noarch/fsspec-2024.12.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/furo-2024.8.6-pyhd8ed1ab_2.conda - conda: https://prefix.dev/conda-forge/noarch/h2-4.1.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/hpack-4.1.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/hyperframe-6.1.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/idna-3.10-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/imagesize-1.4.1-pyhd8ed1ab_0.tar.bz2 + - conda: https://prefix.dev/conda-forge/noarch/importlib-metadata-8.6.1-pyha770c72_0.conda + - conda: https://prefix.dev/conda-forge/noarch/iniconfig-2.0.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/jinja2-3.1.5-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/win-64/libexpat-2.6.4-he0c23c2_0.conda - conda: https://prefix.dev/conda-forge/win-64/libffi-3.4.2-h8ffe710_5.tar.bz2 - conda: https://prefix.dev/conda-forge/win-64/liblzma-5.6.3-h2466b09_1.conda - conda: https://prefix.dev/conda-forge/win-64/libsqlite-3.48.0-h67fdade_1.conda - conda: https://prefix.dev/conda-forge/win-64/libzlib-1.3.1-h2466b09_2.conda + - conda: https://prefix.dev/conda-forge/noarch/locket-1.0.0-pyhd8ed1ab_0.tar.bz2 - conda: https://prefix.dev/conda-forge/noarch/markdown-it-py-3.0.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/win-64/markupsafe-3.0.2-py312h31fea79_1.conda - conda: https://prefix.dev/conda-forge/noarch/mdit-py-plugins-0.4.2-pyhd8ed1ab_1.conda @@ -1764,9 +1802,12 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/myst-parser-4.0.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/win-64/openssl-3.4.0-ha4e3fda_1.conda - conda: https://prefix.dev/conda-forge/noarch/packaging-24.2-pyhd8ed1ab_2.conda + - conda: https://prefix.dev/conda-forge/noarch/partd-1.4.2-pyhd8ed1ab_0.conda + - conda: https://prefix.dev/conda-forge/noarch/pluggy-1.5.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/pycparser-2.22-pyh29332c3_1.conda - conda: https://prefix.dev/conda-forge/noarch/pygments-2.19.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/pysocks-1.7.1-pyh09c184e_7.conda + - conda: https://prefix.dev/conda-forge/noarch/pytest-8.3.4-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/win-64/python-3.12.8-h3f84c4b_1_cpython.conda - conda: https://prefix.dev/conda-forge/win-64/python_abi-3.12-5_cp312.conda - conda: https://prefix.dev/conda-forge/noarch/pytz-2024.1-pyhd8ed1ab_0.conda @@ -1786,6 +1827,9 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/sphinxcontrib-serializinghtml-1.1.10-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/win-64/tk-8.6.13-h5226925_1.conda - conda: https://prefix.dev/conda-forge/noarch/tomli-2.2.1-pyhd8ed1ab_1.conda + - conda: https://prefix.dev/conda-forge/noarch/toolz-1.0.0-pyhd8ed1ab_1.conda + - conda: https://prefix.dev/conda-forge/noarch/typing-extensions-4.12.2-hd8ed1ab_1.conda + - conda: https://prefix.dev/conda-forge/noarch/typing_extensions-4.12.2-pyha770c72_1.conda - conda: https://prefix.dev/conda-forge/noarch/tzdata-2025a-h78e105d_0.conda - conda: https://prefix.dev/conda-forge/win-64/ucrt-10.0.22621.0-h57928b3_1.conda - conda: https://prefix.dev/conda-forge/noarch/urllib3-2.3.0-pyhd8ed1ab_0.conda @@ -1794,6 +1838,7 @@ environments: - conda: https://prefix.dev/conda-forge/win-64/vs2015_runtime-14.42.34433-hfef2bbc_24.conda - conda: https://prefix.dev/conda-forge/noarch/win_inet_pton-1.1.0-pyh7428d3b_8.conda - conda: https://prefix.dev/conda-forge/win-64/yaml-0.2.5-h8ffe710_2.tar.bz2 + - conda: https://prefix.dev/conda-forge/noarch/zipp-3.21.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/win-64/zstandard-0.23.0-py312h7606c53_1.conda - conda: https://prefix.dev/conda-forge/win-64/zstd-1.5.6-h0ea2cb4_0.conda - pypi: . @@ -1821,12 +1866,16 @@ environments: - conda: https://prefix.dev/conda-forge/linux-64/cffi-1.17.1-py312h06ac9bb_0.conda - conda: https://prefix.dev/conda-forge/noarch/cfgv-3.3.1-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/charset-normalizer-3.4.1-pyhd8ed1ab_0.conda + - conda: https://prefix.dev/conda-forge/noarch/click-8.1.8-pyh707e725_0.conda + - conda: https://prefix.dev/conda-forge/noarch/cloudpickle-3.1.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/colorama-0.4.6-pyhd8ed1ab_1.conda + - conda: https://prefix.dev/conda-forge/noarch/dask-core-2025.1.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/dill-0.3.9-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/distlib-0.3.9-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/docutils-0.21.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/exceptiongroup-1.2.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/filelock-3.17.0-pyhd8ed1ab_0.conda + - conda: https://prefix.dev/conda-forge/noarch/fsspec-2024.12.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/h2-4.1.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/hpack-4.1.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/hyperframe-6.1.0-pyhd8ed1ab_0.conda @@ -1834,6 +1883,7 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/identify-2.6.6-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/idna-3.10-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/imagesize-1.4.1-pyhd8ed1ab_0.tar.bz2 + - conda: https://prefix.dev/conda-forge/noarch/importlib-metadata-8.6.1-pyha770c72_0.conda - conda: https://prefix.dev/conda-forge/noarch/iniconfig-2.0.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/isort-5.13.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/jinja2-3.1.5-pyhd8ed1ab_0.conda @@ -1858,6 +1908,7 @@ environments: - conda: https://prefix.dev/conda-forge/linux-64/libxml2-2.13.5-h8d12d68_1.conda - conda: https://prefix.dev/conda-forge/linux-64/libzlib-1.3.1-hb9d3cd8_2.conda - conda: https://prefix.dev/conda-forge/linux-64/llvm-openmp-19.1.7-h024ca30_0.conda + - conda: https://prefix.dev/conda-forge/noarch/locket-1.0.0-pyhd8ed1ab_0.tar.bz2 - conda: https://prefix.dev/conda-forge/linux-64/markupsafe-3.0.2-py312h178313f_1.conda - conda: https://prefix.dev/conda-forge/noarch/mccabe-0.7.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/linux-64/mkl-2024.2.2-ha957f24_16.conda @@ -1870,6 +1921,7 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/numpydoc-1.8.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/linux-64/openssl-3.4.0-h7b32b05_1.conda - conda: https://prefix.dev/conda-forge/noarch/packaging-24.2-pyhd8ed1ab_2.conda + - conda: https://prefix.dev/conda-forge/noarch/partd-1.4.2-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/platformdirs-4.3.6-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/pluggy-1.5.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/pre-commit-4.1.0-pyha770c72_0.conda @@ -1898,6 +1950,7 @@ environments: - conda: https://prefix.dev/conda-forge/linux-64/tk-8.6.13-noxft_h4845f30_101.conda - conda: https://prefix.dev/conda-forge/noarch/tomli-2.2.1-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/tomlkit-0.13.2-pyha770c72_1.conda + - conda: https://prefix.dev/conda-forge/noarch/toolz-1.0.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/typing-extensions-4.12.2-hd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/typing_extensions-4.12.2-pyha770c72_1.conda - conda: https://prefix.dev/conda-forge/noarch/tzdata-2025a-h78e105d_0.conda @@ -1905,6 +1958,7 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/urllib3-2.3.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/virtualenv-20.29.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/linux-64/yaml-0.2.5-h7f98852_2.tar.bz2 + - conda: https://prefix.dev/conda-forge/noarch/zipp-3.21.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/linux-64/zlib-1.3.1-hb9d3cd8_2.conda - conda: https://prefix.dev/conda-forge/linux-64/zstandard-0.23.0-py312hef9b889_1.conda - conda: https://prefix.dev/conda-forge/linux-64/zstd-1.5.6-ha6fb4c9_0.conda @@ -1925,12 +1979,16 @@ environments: - conda: https://prefix.dev/conda-forge/osx-arm64/cffi-1.17.1-py312h0fad829_0.conda - conda: https://prefix.dev/conda-forge/noarch/cfgv-3.3.1-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/charset-normalizer-3.4.1-pyhd8ed1ab_0.conda + - conda: https://prefix.dev/conda-forge/noarch/click-8.1.8-pyh707e725_0.conda + - conda: https://prefix.dev/conda-forge/noarch/cloudpickle-3.1.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/colorama-0.4.6-pyhd8ed1ab_1.conda + - conda: https://prefix.dev/conda-forge/noarch/dask-core-2025.1.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/dill-0.3.9-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/distlib-0.3.9-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/docutils-0.21.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/exceptiongroup-1.2.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/filelock-3.17.0-pyhd8ed1ab_0.conda + - conda: https://prefix.dev/conda-forge/noarch/fsspec-2024.12.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/h2-4.1.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/hpack-4.1.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/hyperframe-6.1.0-pyhd8ed1ab_0.conda @@ -1938,6 +1996,7 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/identify-2.6.6-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/idna-3.10-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/imagesize-1.4.1-pyhd8ed1ab_0.tar.bz2 + - conda: https://prefix.dev/conda-forge/noarch/importlib-metadata-8.6.1-pyha770c72_0.conda - conda: https://prefix.dev/conda-forge/noarch/iniconfig-2.0.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/isort-5.13.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/jinja2-3.1.5-pyhd8ed1ab_0.conda @@ -1955,6 +2014,7 @@ environments: - conda: https://prefix.dev/conda-forge/osx-arm64/libuv-1.50.0-h5505292_0.conda - conda: https://prefix.dev/conda-forge/osx-arm64/libzlib-1.3.1-h8359307_2.conda - conda: https://prefix.dev/conda-forge/osx-arm64/llvm-openmp-19.1.7-hdb05f8b_0.conda + - conda: https://prefix.dev/conda-forge/noarch/locket-1.0.0-pyhd8ed1ab_0.tar.bz2 - conda: https://prefix.dev/conda-forge/osx-arm64/markupsafe-3.0.2-py312h998013c_1.conda - conda: https://prefix.dev/conda-forge/noarch/mccabe-0.7.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/mypy_extensions-1.0.0-pyha770c72_1.conda @@ -1966,6 +2026,7 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/numpydoc-1.8.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/openssl-3.4.0-h81ee809_1.conda - conda: https://prefix.dev/conda-forge/noarch/packaging-24.2-pyhd8ed1ab_2.conda + - conda: https://prefix.dev/conda-forge/noarch/partd-1.4.2-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/platformdirs-4.3.6-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/pluggy-1.5.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/pre-commit-4.1.0-pyha770c72_0.conda @@ -1993,6 +2054,7 @@ environments: - conda: https://prefix.dev/conda-forge/osx-arm64/tk-8.6.13-h5083fa2_1.conda - conda: https://prefix.dev/conda-forge/noarch/tomli-2.2.1-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/tomlkit-0.13.2-pyha770c72_1.conda + - conda: https://prefix.dev/conda-forge/noarch/toolz-1.0.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/typing-extensions-4.12.2-hd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/typing_extensions-4.12.2-pyha770c72_1.conda - conda: https://prefix.dev/conda-forge/noarch/tzdata-2025a-h78e105d_0.conda @@ -2000,6 +2062,7 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/urllib3-2.3.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/virtualenv-20.29.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/osx-arm64/yaml-0.2.5-h3422bc3_2.tar.bz2 + - conda: https://prefix.dev/conda-forge/noarch/zipp-3.21.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/zlib-1.3.1-h8359307_2.conda - conda: https://prefix.dev/conda-forge/osx-arm64/zstandard-0.23.0-py312h15fbf35_1.conda - conda: https://prefix.dev/conda-forge/osx-arm64/zstd-1.5.6-hb46c0d2_0.conda @@ -2020,18 +2083,23 @@ environments: - conda: https://prefix.dev/conda-forge/win-64/cffi-1.17.1-py312h4389bb4_0.conda - conda: https://prefix.dev/conda-forge/noarch/cfgv-3.3.1-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/charset-normalizer-3.4.1-pyhd8ed1ab_0.conda + - conda: https://prefix.dev/conda-forge/noarch/click-8.1.8-pyh7428d3b_0.conda + - conda: https://prefix.dev/conda-forge/noarch/cloudpickle-3.1.1-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/colorama-0.4.6-pyhd8ed1ab_1.conda + - conda: https://prefix.dev/conda-forge/noarch/dask-core-2025.1.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/dill-0.3.9-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/distlib-0.3.9-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/docutils-0.21.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/exceptiongroup-1.2.2-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/filelock-3.17.0-pyhd8ed1ab_0.conda + - conda: https://prefix.dev/conda-forge/noarch/fsspec-2024.12.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/h2-4.1.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/hpack-4.1.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/hyperframe-6.1.0-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/identify-2.6.6-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/idna-3.10-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/imagesize-1.4.1-pyhd8ed1ab_0.tar.bz2 + - conda: https://prefix.dev/conda-forge/noarch/importlib-metadata-8.6.1-pyha770c72_0.conda - conda: https://prefix.dev/conda-forge/noarch/iniconfig-2.0.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/win-64/intel-openmp-2024.2.1-h57928b3_1083.conda - conda: https://prefix.dev/conda-forge/noarch/isort-5.13.2-pyhd8ed1ab_1.conda @@ -2048,6 +2116,7 @@ environments: - conda: https://prefix.dev/conda-forge/win-64/libwinpthread-12.0.0.r4.gg4f2fc60ca-h57928b3_9.conda - conda: https://prefix.dev/conda-forge/win-64/libxml2-2.13.5-he286e8c_1.conda - conda: https://prefix.dev/conda-forge/win-64/libzlib-1.3.1-h2466b09_2.conda + - conda: https://prefix.dev/conda-forge/noarch/locket-1.0.0-pyhd8ed1ab_0.tar.bz2 - conda: https://prefix.dev/conda-forge/win-64/markupsafe-3.0.2-py312h31fea79_1.conda - conda: https://prefix.dev/conda-forge/noarch/mccabe-0.7.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/win-64/mkl-2024.2.2-h66d3029_15.conda @@ -2059,6 +2128,7 @@ environments: - conda: https://prefix.dev/conda-forge/noarch/numpydoc-1.8.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/win-64/openssl-3.4.0-ha4e3fda_1.conda - conda: https://prefix.dev/conda-forge/noarch/packaging-24.2-pyhd8ed1ab_2.conda + - conda: https://prefix.dev/conda-forge/noarch/partd-1.4.2-pyhd8ed1ab_0.conda - conda: https://prefix.dev/conda-forge/noarch/platformdirs-4.3.6-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/pluggy-1.5.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/pre-commit-4.1.0-pyha770c72_0.conda @@ -2086,6 +2156,7 @@ environments: - conda: https://prefix.dev/conda-forge/win-64/tk-8.6.13-h5226925_1.conda - conda: https://prefix.dev/conda-forge/noarch/tomli-2.2.1-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/tomlkit-0.13.2-pyha770c72_1.conda + - conda: https://prefix.dev/conda-forge/noarch/toolz-1.0.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/typing-extensions-4.12.2-hd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/noarch/typing_extensions-4.12.2-pyha770c72_1.conda - conda: https://prefix.dev/conda-forge/noarch/tzdata-2025a-h78e105d_0.conda @@ -2098,6 +2169,7 @@ environments: - conda: https://prefix.dev/conda-forge/win-64/vs2015_runtime-14.42.34433-hfef2bbc_24.conda - conda: https://prefix.dev/conda-forge/noarch/win_inet_pton-1.1.0-pyh7428d3b_8.conda - conda: https://prefix.dev/conda-forge/win-64/yaml-0.2.5-h8ffe710_2.tar.bz2 + - conda: https://prefix.dev/conda-forge/noarch/zipp-3.21.0-pyhd8ed1ab_1.conda - conda: https://prefix.dev/conda-forge/win-64/zstandard-0.23.0-py312h7606c53_1.conda - conda: https://prefix.dev/conda-forge/win-64/zstd-1.5.6-h0ea2cb4_0.conda - pypi: . @@ -3695,7 +3767,7 @@ packages: - pypi: . name: array-api-extra version: 0.6.1.dev0 - sha256: 22c9e9830a088aff4480ecea8495d2ebcf91f65596886a12012bebfb238181d6 + sha256: bb6cd89a7f100a73d3f853de571b2f4fff0e70de8df0d113f2f5c1559744e6b6 requires_dist: - array-api-compat>=1.10.0,<2 requires_python: '>=3.10' diff --git a/pyproject.toml b/pyproject.toml index 2fae4f8..d15aba8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,7 @@ array-api-compat = ">=1.10.0,<2" array-api-extra = { path = ".", editable = true } [tool.pixi.feature.lint.dependencies] +typing-extensions = "*" pre-commit = "*" pylint = "*" basedmypy = "*" @@ -63,6 +64,9 @@ numpydoc = ">=1.8.0,<2" array-api-strict = "*" numpy = "*" pytest = "*" +dask-core = "*" # No distributed, tornado, etc. +# NOTE: don't add jax, pytorch, sparse, cupy here +# as they slow down mypy and are not portable across target OSs [tool.pixi.feature.lint.tasks] pre-commit-install = "pre-commit install" @@ -98,6 +102,10 @@ furo = ">=2023.08.17" myst-parser = ">=0.13" sphinx-copybutton = "*" sphinx-autodoc-typehints = "*" +# Needed to import parsed modules with autodoc +dask-core = "*" +pytest = "*" +typing-extensions = "*" [tool.pixi.feature.docs.tasks] docs = { cmd = "sphinx-build . build/", cwd = "docs" } @@ -180,8 +188,10 @@ markers = ["skip_xp_backend(library, *, reason=None): Skip test for a specific b [tool.coverage] run.source = ["array_api_extra"] -report.exclude_also = ['\.\.\.'] - +report.exclude_also = [ + '\.\.\.', + 'if TYPE_CHECKING:', +] # mypy @@ -221,6 +231,8 @@ reportMissingImports = false reportMissingTypeStubs = false # false positives for input validation reportUnreachable = false +# ruff handles this +reportUnusedParameter = false executionEnvironments = [ { root = "tests", reportPrivateUsage = false }, @@ -282,7 +294,10 @@ messages_control.disable = [ "design", # ignore heavily opinionated design checks "fixme", # allow FIXME comments "line-too-long", # ruff handles this + "unused-argument", # ruff handles this "missing-function-docstring", # numpydoc handles this + "import-error", # mypy handles this + "import-outside-toplevel", # optional dependencies ] diff --git a/src/array_api_extra/_lib/_testing.py b/src/array_api_extra/_lib/_testing.py index cc0d055..ac4ae58 100644 --- a/src/array_api_extra/_lib/_testing.py +++ b/src/array_api_extra/_lib/_testing.py @@ -2,6 +2,7 @@ Testing utilities. Note that this is private API; don't expect it to be stable. +See also ..testing for public testing utilities. """ import math diff --git a/src/array_api_extra/testing.py b/src/array_api_extra/testing.py new file mode 100644 index 0000000..e124ed7 --- /dev/null +++ b/src/array_api_extra/testing.py @@ -0,0 +1,262 @@ +""" +Public testing utilities. + +See also _lib._testing for additional private testing utilities. +""" + +# https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972 +from __future__ import annotations + +from collections.abc import Callable, Iterable, Sequence +from functools import wraps +from types import ModuleType +from typing import TYPE_CHECKING, Any, TypeVar, cast + +import pytest + +from array_api_extra._lib._utils._compat import is_dask_namespace, is_jax_namespace + +__all__ = ["lazy_xp_function", "patch_lazy_xp_functions"] + +if TYPE_CHECKING: + # TODO move ParamSpec outside TYPE_CHECKING + # depends on scikit-learn abandoning Python 3.9 + # https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972 + from typing import ParamSpec + + from dask.typing import Graph, Key, SchedulerGetCallable + from typing_extensions import override + + P = ParamSpec("P") +else: + SchedulerGetCallable = object + + # Sphinx hacks + class P: # pylint: disable=missing-class-docstring + args: tuple + kwargs: dict + + def override(func: Callable[P, T]) -> Callable[P, T]: + return func + + +T = TypeVar("T") + + +def lazy_xp_function( # type: ignore[no-any-explicit] + func: Callable[..., Any], + *, + allow_dask_compute: int = 0, + jax_jit: bool = True, + static_argnums: int | Sequence[int] | None = None, + static_argnames: str | Iterable[str] | None = None, +) -> None: # numpydoc ignore=GL07 + """ + Tag a function to be tested on lazy backends. + + Tag a function, which must be imported in the test module globals, so that when any + tests defined in the same module are executed with ``xp=jax.numpy`` the function is + replaced with a jitted version of itself, and when it is executed with + ``xp=dask.array`` the function will raise if it attempts to materialize the graph. + This will be later expanded to provide test coverage for other lazy backends. + + In order for the tag to be effective, the test or a fixture must call + :func:`patch_lazy_xp_functions`. + + Parameters + ---------- + func : callable + Function to be tested. + allow_dask_compute : int, optional + Number of times `func` is allowed to internally materialize the Dask graph. This + is typically triggered by ``bool()``, ``float()``, or ``np.asarray()``. + + Set to 1 if you are aware that `func` converts the input parameters to numpy and + want to let it do so at least for the time being, knowing that it is going to be + extremely detrimental for performance. + + If a test needs values higher than 1 to pass, it is a canary that the conversion + to numpy/bool/float is happening multiple times, which translates to multiple + computations of the whole graph. Short of making the function fully lazy, you + should at least add explicit calls to ``np.asarray()`` early in the function. + *Note:* the counter of `allow_dask_compute` resets after each call to `func`, so + a test function that invokes `func` multiple times should still work with this + parameter set to 1. + + Default: 0, meaning that `func` must be fully lazy and never materialize the + graph. + jax_jit : bool, optional + Set to True to replace `func` with ``jax.jit(func)`` after calling the + :func:`patch_lazy_xp_functions` test helper with ``xp=jax.numpy``. Set to False + if `func` is only compatible with eager (non-jitted) JAX. Default: True. + static_argnums : int | Sequence[int], optional + Passed to jax.jit. Positional arguments to treat as static (compile-time + constant). Default: infer from `static_argnames` using + `inspect.signature(func)`. + static_argnames : str | Iterable[str], optional + Passed to jax.jit. Named arguments to treat as static (compile-time constant). + Default: infer from `static_argnums` using `inspect.signature(func)`. + + See Also + -------- + patch_lazy_xp_functions : Companion function to call from the test or fixture. + jax.jit : JAX function to compile a function for performance. + + Examples + -------- + In ``test_mymodule.py``:: + + from array_api_extra.testing import lazy_xp_function from mymodule import myfunc + + lazy_xp_function(myfunc) + + def test_myfunc(xp): + a = xp.asarray([1, 2]) + # When xp=jax.numpy, this is the same as `b = jax.jit(myfunc)(a)` + # When xp=dask.array, crash on compute() or persist() + b = myfunc(a) + + Notes + ----- + A test function can circumvent this monkey-patching system by calling `func` as an + attribute of the original module. You need to sanitize your code to make sure this + does not happen. + + Example:: + + import mymodule from mymodule import myfunc + + lazy_xp_function(myfunc) + + def test_myfunc(xp): + a = xp.asarray([1, 2]) b = myfunc(a) # This is jitted when xp=jax.numpy c = + mymodule.myfunc(a) # This is not + """ + func.allow_dask_compute = allow_dask_compute # type: ignore[attr-defined] # pyright: ignore[reportFunctionMemberAccess] + if jax_jit: + func.lazy_jax_jit_kwargs = { # type: ignore[attr-defined] # pyright: ignore[reportFunctionMemberAccess] + "static_argnums": static_argnums, + "static_argnames": static_argnames, + } + + +def patch_lazy_xp_functions( + request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch, *, xp: ModuleType +) -> None: + """ + Test lazy execution of functions tagged with :func:`lazy_xp_function`. + + If ``xp==jax.numpy``, search for all functions which have been tagged with + :func:`lazy_xp_function` in the globals of the module that defines the current test + and wrap them with :func:`jax.jit`. Unwrap them at the end of the test. + + If ``xp==dask.array``, wrap the functions with a decorator that disables + ``compute()`` and ``persist()``. + + This function should be typically called by your library's `xp` fixture that runs + tests on multiple backends:: + + @pytest.fixture(params=[numpy, array_api_strict, jax.numpy, dask.array]) + def xp(request, monkeypatch): + patch_lazy_xp_functions(request, monkeypatch, xp=request.param) + return request.param + + but it can be otherwise be called by the test itself too. + + Parameters + ---------- + request : pytest.FixtureRequest + Pytest fixture, as acquired by the test itself or by one of its fixtures. + monkeypatch : pytest.MonkeyPatch + Pytest fixture, as acquired by the test itself or by one of its fixtures. + xp : module + Array namespace to be tested. + + See Also + -------- + lazy_xp_function : Tag a function to be tested on lazy backends. + pytest.FixtureRequest : `request` test function parameter. + """ + globals_ = cast("dict[str, Any]", request.module.__dict__) # type: ignore[no-any-explicit] + + if is_dask_namespace(xp): + for name, func in globals_.items(): + n = getattr(func, "allow_dask_compute", None) + if n is not None: + assert isinstance(n, int) + wrapped = _allow_dask_compute(func, n) + monkeypatch.setitem(globals_, name, wrapped) + + elif is_jax_namespace(xp): + import jax + + for name, func in globals_.items(): + kwargs = cast( # type: ignore[no-any-explicit] + "dict[str, Any] | None", getattr(func, "lazy_jax_jit_kwargs", None) + ) + if kwargs is not None: + # suppress unused-ignore to run mypy in -e lint as well as -e dev + wrapped = cast(Callable[..., Any], jax.jit(func, **kwargs)) # type: ignore[no-any-explicit,no-untyped-call,unused-ignore] + monkeypatch.setitem(globals_, name, wrapped) + + +class CountingDaskScheduler(SchedulerGetCallable): + """ + Dask scheduler that counts how many times `dask.compute` is called. + + If the number of times exceeds 'max_count', it raises an error. + This is a wrapper around Dask's own 'synchronous' scheduler. + + Parameters + ---------- + max_count : int + Maximum number of allowed calls to `dask.compute`. + msg : str + Assertion to raise when the count exceeds `max_count`. + """ + + count: int + max_count: int + msg: str + + def __init__(self, max_count: int, msg: str): # numpydoc ignore=GL08 + self.count = 0 + self.max_count = max_count + self.msg = msg + + @override + def __call__(self, dsk: Graph, keys: Sequence[Key] | Key, **kwargs: Any) -> Any: # type: ignore[no-any-decorated,no-any-explicit] # numpydoc ignore=GL08 + import dask + + self.count += 1 + # This should yield a nice traceback to the + # offending line in the user's code + assert self.count <= self.max_count, self.msg + + return dask.get(dsk, keys, **kwargs) # type: ignore[attr-defined,no-untyped-call] # pyright: ignore[reportPrivateImportUsage] + + +def _allow_dask_compute( + func: Callable[P, T], n: int +) -> Callable[P, T]: # numpydoc ignore=PR01,RT01 + """ + Wrap `func` to raise if it attempts to call `dask.compute` more than `n` times. + """ + import dask.config + + func_name = getattr(func, "__name__", str(func)) + n_str = f"only up to {n}" if n else "no" + msg = ( + f"Called `dask.compute()` or `dask.persist()` {n + 1} times, " + f"but {n_str} calls are allowed. Set " + f"`lazy_xp_function({func_name}, allow_dask_compute={n + 1})` " + "to allow for more (but note that this will harm performance). " + ) + + @wraps(func) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08 + scheduler = CountingDaskScheduler(n, msg) + with dask.config.set({"scheduler": scheduler}): + return func(*args, **kwargs) + + return wrapper diff --git a/tests/conftest.py b/tests/conftest.py index 39904ae..4402c06 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,6 +13,7 @@ from array_api_extra._lib._utils._compat import array_namespace from array_api_extra._lib._utils._compat import device as get_device from array_api_extra._lib._utils._typing import Device +from array_api_extra.testing import patch_lazy_xp_functions T = TypeVar("T") P = ParamSpec("P") @@ -96,7 +97,9 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08 @pytest.fixture -def xp(library: Backend) -> ModuleType: # numpydoc ignore=PR01,RT03 +def xp( + library: Backend, request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch +) -> ModuleType: # numpydoc ignore=PR01,RT03 """ Parameterized fixture that iterates on all libraries. @@ -107,6 +110,9 @@ def xp(library: Backend) -> ModuleType: # numpydoc ignore=PR01,RT03 if library == Backend.NUMPY_READONLY: return NumPyReadOnly() # type: ignore[return-value] # pyright: ignore[reportReturnType] xp = pytest.importorskip(library.value) + + patch_lazy_xp_functions(request, monkeypatch, xp=xp) + if library == Backend.JAX: import jax diff --git a/tests/test_at.py b/tests/test_at.py index 4bd09c6..744e3aa 100644 --- a/tests/test_at.py +++ b/tests/test_at.py @@ -1,7 +1,8 @@ +import pickle from collections.abc import Callable, Generator from contextlib import contextmanager from types import ModuleType -from typing import cast +from typing import Any, cast import numpy as np import pytest @@ -11,7 +12,46 @@ from array_api_extra._lib._at import _AtOp from array_api_extra._lib._testing import xp_assert_equal from array_api_extra._lib._utils._compat import array_namespace, is_writeable_array -from array_api_extra._lib._utils._typing import Array +from array_api_extra._lib._utils._typing import Array, Index +from array_api_extra.testing import lazy_xp_function + + +def at_op( # type: ignore[no-any-explicit] + x: Array, + idx: Index, + op: _AtOp, + y: Array | object, + **kwargs: Any, # Test the default copy=None +) -> Array: + """ + Wrapper around at(x, idx).op(y, copy=copy, xp=xp). + + This is a hack to allow wrapping `at()` with `lazy_xp_function`. + For clarity, at() itself works inside jax.jit without hacks; this is + just a workaround for when one wants to apply jax.jit to `at()` directly, + which is not a common use case. + """ + if isinstance(idx, (slice | tuple)): + return _at_op(x, None, pickle.dumps(idx), op, y, **kwargs) + return _at_op(x, idx, None, op, y, **kwargs) + + +def _at_op( # type: ignore[no-any-explicit] + x: Array, + idx: Index | None, + idx_pickle: bytes | None, + op: _AtOp, + y: Array | object, + **kwargs: Any, +) -> Array: + """jitted helper of at_op""" + if idx_pickle: + idx = pickle.loads(idx_pickle) + meth = cast(Callable[..., Array], getattr(at(x, idx), op.value)) # type: ignore[no-any-explicit] + return meth(y, **kwargs) + + +lazy_xp_function(_at_op, static_argnames=("op", "idx_pickle", "copy", "xp")) @contextmanager @@ -43,7 +83,7 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]: ], ) @pytest.mark.parametrize( - ("op", "arg", "expect"), + ("op", "y", "expect"), [ (_AtOp.SET, 40.0, [10.0, 40.0, 40.0]), (_AtOp.ADD, 40.0, [10.0, 60.0, 70.0]), @@ -55,21 +95,52 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]: (_AtOp.MAX, 25.0, [10.0, 25.0, 30.0]), ], ) +@pytest.mark.parametrize( + ("bool_mask", "shaped_y"), + [ + (False, False), + (False, True), + pytest.param( + True, + False, + marks=( + pytest.mark.skip_xp_backend(Backend.JAX, reason="TODO special case"), + pytest.mark.skip_xp_backend(Backend.DASK, reason="TODO special case"), + ), + ), + pytest.param( + True, + True, + marks=( + pytest.mark.skip_xp_backend( + Backend.JAX, reason="bool mask update with shaped rhs" + ), + pytest.mark.skip_xp_backend( + Backend.DASK, reason="bool mask update with shaped rhs" + ), + ), + ), + ], +) def test_update_ops( xp: ModuleType, kwargs: dict[str, bool | None], expect_copy: bool | None, op: _AtOp, - arg: float, + y: float, expect: list[float], + bool_mask: bool, + shaped_y: bool, ): - array = xp.asarray([10.0, 20.0, 30.0]) + x = xp.asarray([10.0, 20.0, 30.0]) + idx = xp.asarray([False, True, True]) if bool_mask else slice(1, None) + if shaped_y: + y = xp.asarray([y, y]) - with assert_copy(array, expect_copy): - func = cast(Callable[..., Array], getattr(at(array)[1:], op.value)) # type: ignore[no-any-explicit] - y = func(arg, **kwargs) - assert isinstance(y, type(array)) - xp_assert_equal(y, xp.asarray(expect)) + with assert_copy(x, expect_copy): + z = at_op(x, idx, op, y, **kwargs) + assert isinstance(z, type(x)) + xp_assert_equal(z, xp.asarray(expect)) def test_copy_invalid(): @@ -121,7 +192,6 @@ def test_iops_incompatible_dtype(op: _AtOp, copy: bool): UFuncTypeError: Cannot cast ufunc 'divide' output from dtype('float64') to dtype('int64') with casting rule 'same_kind' """ - a = np.asarray([2, 4]) - func = cast(Callable[..., Array], getattr(at(a)[:], op.value)) # type: ignore[no-any-explicit] + x = np.asarray([2, 4]) with pytest.raises(TypeError, match="Cannot cast ufunc"): - func(1.1, copy=copy) + at_op(x, slice(None), op, 1.1, copy=copy) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index fbb530b..ef1a1fc 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -22,10 +22,24 @@ from array_api_extra._lib._testing import xp_assert_close, xp_assert_equal from array_api_extra._lib._utils._compat import device as get_device from array_api_extra._lib._utils._typing import Array, Device +from array_api_extra.testing import lazy_xp_function # some xp backends are untyped # mypy: disable-error-code=no-untyped-def +lazy_xp_function(atleast_nd, static_argnames=("ndim", "xp")) +lazy_xp_function(cov, static_argnames="xp") +# FIXME .device attribute https://github.com/data-apis/array-api-compat/pull/238 +lazy_xp_function(create_diagonal, jax_jit=False, static_argnames=("offset", "xp")) +lazy_xp_function(expand_dims, static_argnames=("axis", "xp")) +lazy_xp_function(kron, static_argnames="xp") +lazy_xp_function(nunique, static_argnames="xp") +lazy_xp_function(pad, static_argnames=("pad_width", "mode", "constant_values", "xp")) +# FIXME calls in1d which calls xp.unique_values without size +lazy_xp_function(setdiff1d, jax_jit=False, static_argnames=("assume_unique", "xp")) +# FIXME .device attribute https://github.com/data-apis/array-api-compat/pull/238 +lazy_xp_function(sinc, jax_jit=False, static_argnames="xp") + @pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no expand_dims") class TestAtLeastND: diff --git a/tests/test_testing.py b/tests/test_testing.py index ec5023b..7d4ed0a 100644 --- a/tests/test_testing.py +++ b/tests/test_testing.py @@ -6,6 +6,13 @@ from array_api_extra._lib import Backend from array_api_extra._lib._testing import xp_assert_close, xp_assert_equal +from array_api_extra._lib._utils._compat import ( + array_namespace, + is_dask_namespace, + is_jax_namespace, +) +from array_api_extra._lib._utils._typing import Array +from array_api_extra.testing import lazy_xp_function # mypy: disable-error-code=no-any-decorated # pyright: reportUnknownParameterType=false,reportMissingParameterType=false @@ -88,3 +95,110 @@ def test_assert_close_equal_none_shape(xp: ModuleType, func: Callable[..., None] func(xp.asarray(2), a) with pytest.raises(AssertionError): func(xp.asarray([3]), a) + + +def good_lazy(x: Array) -> Array: + """A function that behaves well in dask and jax.jit""" + return x * 2.0 + + +def non_materializable(x: Array) -> Array: + """ + This function materializes the input array, so it will fail when wrapped in jax.jit + and it will trigger an expensive computation in dask. + """ + xp = array_namespace(x) + # On dask, this triggers two computations of the whole graph + if xp.any(x < 0.0) or xp.any(x > 10.0): + msg = "Values must be in the [0, 10] range" + raise ValueError(msg) + return x + + +def non_materializable2(x: Array) -> Array: + return non_materializable(x) + + +def non_materializable3(x: Array) -> Array: + return non_materializable(x) + + +def non_materializable4(x: Array) -> Array: + return non_materializable(x) + + +lazy_xp_function(good_lazy) +# Works on JAX and Dask +lazy_xp_function(non_materializable2, jax_jit=False, allow_dask_compute=2) +# Works on JAX, but not Dask +lazy_xp_function(non_materializable3, jax_jit=False, allow_dask_compute=1) +# Works neither on Dask nor JAX +lazy_xp_function(non_materializable4) + + +def test_lazy_xp_function(xp: ModuleType): + x = xp.asarray([1.0, 2.0]) + + xp_assert_equal(good_lazy(x), xp.asarray([2.0, 4.0])) + # Not wrapped + xp_assert_equal(non_materializable(x), xp.asarray([1.0, 2.0])) + # Wrapping explicitly disabled + xp_assert_equal(non_materializable2(x), xp.asarray([1.0, 2.0])) + + if is_jax_namespace(xp): + xp_assert_equal(non_materializable3(x), xp.asarray([1.0, 2.0])) + with pytest.raises( + TypeError, match="Attempted boolean conversion of traced array" + ): + non_materializable4(x) # Wrapped + + elif is_dask_namespace(xp): + with pytest.raises( + AssertionError, + match=r"dask\.compute.* 2 times, but only up to 1 calls are allowed", + ): + non_materializable3(x) + with pytest.raises( + AssertionError, + match=r"dask\.compute.* 1 times, but no calls are allowed", + ): + non_materializable4(x) + + else: + xp_assert_equal(non_materializable3(x), xp.asarray([1.0, 2.0])) + xp_assert_equal(non_materializable4(x), xp.asarray([1.0, 2.0])) + + +def static_params(x: Array, n: int, flag: bool = False) -> Array: + """Function with static parameters that must not be jitted""" + if flag and n > 0: # This fails if n or flag are jitted arrays + return x * 2.0 + return x * 3.0 + + +def static_params1(x: Array, n: int, flag: bool = False) -> Array: + return static_params(x, n, flag) + + +def static_params2(x: Array, n: int, flag: bool = False) -> Array: + return static_params(x, n, flag) + + +def static_params3(x: Array, n: int, flag: bool = False) -> Array: + return static_params(x, n, flag) + + +lazy_xp_function(static_params1, static_argnums=(1, 2)) +lazy_xp_function(static_params2, static_argnames=("n", "flag")) +lazy_xp_function(static_params3, static_argnums=1, static_argnames="flag") + + +@pytest.mark.parametrize("func", [static_params1, static_params2, static_params3]) +def test_lazy_xp_function_static_params(xp: ModuleType, func: Callable[..., Array]): # type: ignore[no-any-explicit] + x = xp.asarray([1.0, 2.0]) + xp_assert_equal(func(x, 1), xp.asarray([3.0, 6.0])) + xp_assert_equal(func(x, 1, True), xp.asarray([2.0, 4.0])) + xp_assert_equal(func(x, 1, False), xp.asarray([3.0, 6.0])) + xp_assert_equal(func(x, 0, False), xp.asarray([3.0, 6.0])) + xp_assert_equal(func(x, 1, flag=True), xp.asarray([2.0, 4.0])) + xp_assert_equal(func(x, n=1, flag=True), xp.asarray([2.0, 4.0])) diff --git a/tests/test_utils.py b/tests/test_utils.py index fff3f0f..f710056 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -7,9 +7,13 @@ from array_api_extra._lib._utils._compat import device as get_device from array_api_extra._lib._utils._helpers import in1d from array_api_extra._lib._utils._typing import Device +from array_api_extra.testing import lazy_xp_function # mypy: disable-error-code=no-untyped-usage +# FIXME calls xp.unique_values without size +lazy_xp_function(in1d, jax_jit=False, static_argnames=("assume_unique", "invert", "xp")) + class TestIn1D: @pytest.mark.skip_xp_backend(Backend.DASK, reason="no argsort")