diff --git a/numpyro/infer/mcmc.py b/numpyro/infer/mcmc.py index c2a4972cb..650072845 100644 --- a/numpyro/infer/mcmc.py +++ b/numpyro/infer/mcmc.py @@ -9,7 +9,7 @@ import numpy as np -from jax import jit, lax, local_device_count, pmap, random, vmap, device_get +from jax import device_get, jit, lax, local_device_count, pmap, random, vmap import jax.numpy as jnp from jax.tree_util import tree_flatten, tree_map diff --git a/pyproject.toml b/pyproject.toml index d34dec8c3..de1df6bc0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,8 @@ indent-width = 4 [tool.ruff.lint] # Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default. -select = ["E4", "E7", "E9", "F"] +# We also add isort. +select = ["E4", "E7", "E9", "F", "I"] ignore = ["E203"] # Allow fix for all enabled rules (when `--fix`) is provided. @@ -63,11 +64,26 @@ line-ending = "auto" "numpyro/contrib/tfp/distributions.py" = ["F811"] "numpyro/distributions/kl.py" = ["F811"] -[tool.pytest.ini_options] -addopts = [ - "-v", - "--color=yes", +[tool.ruff.lint.isort] +combine-as-imports = true +known-first-party = ["funsor", "numpyro", "test"] +known-third-party = ["opt_einsum"] +force-sort-within-sections = true +section-order = [ + "future", + "standard-library", + "third-party", + "known-jax", + "first-party", + "local-folder", ] + +[tool.ruff.lint.isort.sections] +known-jax = ["flax", "haiku", "jax", "optax", "tensorflow_probability"] + + +[tool.pytest.ini_options] +addopts = ["-v", "--color=yes"] filterwarnings = [ "error", "ignore:numpy.ufunc size changed,:RuntimeWarning", diff --git a/test/contrib/test_module.py b/test/contrib/test_module.py index 20c8717f3..803af666c 100644 --- a/test/contrib/test_module.py +++ b/test/contrib/test_module.py @@ -4,13 +4,13 @@ from copy import deepcopy import numpy as np +from numpy.testing import assert_allclose import pytest + from jax import random from jax.tree_util import tree_all, tree_map -from numpy.testing import assert_allclose import numpyro -import numpyro.distributions as dist from numpyro import handlers from numpyro.contrib.module import ( ParamShape, @@ -20,6 +20,7 @@ random_flax_module, random_haiku_module, ) +import numpyro.distributions as dist from numpyro.infer import MCMC, NUTS pytestmark = pytest.mark.filterwarnings("ignore:jax.tree_.+ is deprecated:FutureWarning")