diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ee8cedebed..c4f99fa168 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -28,12 +28,11 @@ jobs: python -m pip install --upgrade pip pip install jaxlib pip install jax - pip install black pip install .[doc,test] pip install https://github.com/pyro-ppl/funsor/archive/master.zip pip install -r docs/requirements.txt pip freeze - - name: Lint with flake8 + - name: Lint with ruff run: | make lint - name: Build documentation diff --git a/pyproject.toml b/pyproject.toml index b4ecf3a77d..d34dec8c36 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,3 +62,24 @@ line-ending = "auto" [tool.ruff.extend-per-file-ignores] "numpyro/contrib/tfp/distributions.py" = ["F811"] "numpyro/distributions/kl.py" = ["F811"] + +[tool.pytest.ini_options] +addopts = [ + "-v", + "--color=yes", +] +filterwarnings = [ + "error", + "ignore:numpy.ufunc size changed,:RuntimeWarning", + "ignore:Using a non-tuple sequence:FutureWarning", + "ignore:jax.tree_structure is deprecated:FutureWarning", + "ignore:numpy.linalg support is experimental:UserWarning", + "ignore:scipy.linalg support is experimental:UserWarning", + "once:No GPU:UserWarning", + "once::DeprecationWarning", +] +doctest_optionflags = [ + "ELLIPSIS", + "NORMALIZE_WHITESPACE", + "IGNORE_EXCEPTION_DETAIL", +] diff --git a/setup.py b/setup.py deleted file mode 100644 index e6b13e248a..0000000000 --- a/setup.py +++ /dev/null @@ -1,103 +0,0 @@ -# Copyright Contributors to the Pyro project. -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import absolute_import, division, print_function - -import os -import sys - -from setuptools import find_packages, setup - -PROJECT_PATH = os.path.dirname(os.path.abspath(__file__)) -_jax_version_constraints = ">=0.4.14" -_jaxlib_version_constraints = ">=0.4.14" - -# Find version -for line in open(os.path.join(PROJECT_PATH, "numpyro", "version.py")): - if line.startswith("__version__ = "): - version = line.strip().split()[2][1:-1] - -# READ README.md for long description on PyPi. -try: - long_description = open("README.md", encoding="utf-8").read() -except Exception as e: - sys.stderr.write("Failed to read README.md:\n {}\n".format(e)) - sys.stderr.flush() - long_description = "" - -setup( - name="numpyro", - version=version, - description="Pyro PPL on NumPy", - packages=find_packages(include=["numpyro", "numpyro.*"]), - url="https://github.com/pyro-ppl/numpyro", - author="Uber AI Labs", - install_requires=[ - f"jax{_jax_version_constraints}", - f"jaxlib{_jaxlib_version_constraints}", - "multipledispatch", - "numpy", - "tqdm", - ], - extras_require={ - "doc": [ - "ipython", # sphinx needs this to render codes - "nbsphinx>=0.8.5", - "readthedocs-sphinx-search==0.1.0", - "sphinx", - "sphinx_rtd_theme", - "sphinx-gallery", - ], - "test": [ - "importlib-metadata<5.0", - "importlib-metadata<5.0", - "ruff>=0.1.8", - "pytest>=4.1", - "pyro-api>=0.1.1", - "scipy>=1.9", - ], - "dev": [ - "dm-haiku", - "flax", - "funsor>=0.4.1", - "graphviz", - "jaxns==2.2.6", - "matplotlib", - "optax>=0.0.6", - "pylab-sdk", # jaxns dependency - "pyyaml", # flax dependency - "requests", # pylab dependency - "tensorflow_probability>=0.18.0", - ], - "examples": [ - "arviz", - "jupyter", - "matplotlib", - "pandas", - "seaborn", - "scikit-learn", - "wordcloud", - ], - "cpu": f"jax[cpu]{_jax_version_constraints}", - # TPU and CUDA installations, currently require to add package repository URL, i.e., - # pip install numpyro[cuda] -f https://storage.googleapis.com/jax-releases/jax_releases.html - "tpu": f"jax[tpu]{_jax_version_constraints}", - "cuda": f"jax[cuda]{_jax_version_constraints}", - }, - python_requires=">=3.9", - long_description=long_description, - long_description_content_type="text/markdown", - keywords="probabilistic machine learning bayesian statistics", - license="Apache License 2.0", - classifiers=[ - "Intended Audience :: Developers", - "Intended Audience :: Education", - "Intended Audience :: Science/Research", - "License :: OSI Approved :: Apache Software License", - "Operating System :: POSIX :: Linux", - "Operating System :: MacOS :: MacOS X", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - ], -)