Skip to content

Commit

Permalink
Add isort (ruff) (#1718)
Browse files Browse the repository at this point in the history
* initial config

* fix sort

* missing fix
  • Loading branch information
juanitorduz authored Jan 9, 2024
1 parent 9ce2384 commit d28fd82
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 8 deletions.
2 changes: 1 addition & 1 deletion numpyro/infer/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import numpy as np

from jax import jit, lax, local_device_count, pmap, random, vmap, device_get
from jax import device_get, jit, lax, local_device_count, pmap, random, vmap
import jax.numpy as jnp
from jax.tree_util import tree_flatten, tree_map

Expand Down
26 changes: 21 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ indent-width = 4

[tool.ruff.lint]
# Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default.
select = ["E4", "E7", "E9", "F"]
# We also add isort.
select = ["E4", "E7", "E9", "F", "I"]
ignore = ["E203"]

# Allow fix for all enabled rules (when `--fix`) is provided.
Expand All @@ -63,11 +64,26 @@ line-ending = "auto"
"numpyro/contrib/tfp/distributions.py" = ["F811"]
"numpyro/distributions/kl.py" = ["F811"]

[tool.pytest.ini_options]
addopts = [
"-v",
"--color=yes",
[tool.ruff.lint.isort]
combine-as-imports = true
known-first-party = ["funsor", "numpyro", "test"]
known-third-party = ["opt_einsum"]
force-sort-within-sections = true
section-order = [
"future",
"standard-library",
"third-party",
"known-jax",
"first-party",
"local-folder",
]

[tool.ruff.lint.isort.sections]
known-jax = ["flax", "haiku", "jax", "optax", "tensorflow_probability"]


[tool.pytest.ini_options]
addopts = ["-v", "--color=yes"]
filterwarnings = [
"error",
"ignore:numpy.ufunc size changed,:RuntimeWarning",
Expand Down
5 changes: 3 additions & 2 deletions test/contrib/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
from copy import deepcopy

import numpy as np
from numpy.testing import assert_allclose
import pytest

from jax import random
from jax.tree_util import tree_all, tree_map
from numpy.testing import assert_allclose

import numpyro
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.contrib.module import (
ParamShape,
Expand All @@ -20,6 +20,7 @@
random_flax_module,
random_haiku_module,
)
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

pytestmark = pytest.mark.filterwarnings("ignore:jax.tree_.+ is deprecated:FutureWarning")
Expand Down

0 comments on commit d28fd82

Please sign in to comment.