Skip to content

Commit

Permalink
Split github workflows for lower latency, add ruff
Browse files Browse the repository at this point in the history
  • Loading branch information
rdyro committed Dec 10, 2024
1 parent f076ae1 commit 5bc9812
Show file tree
Hide file tree
Showing 7 changed files with 85 additions and 8 deletions.
63 changes: 60 additions & 3 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ jobs:
build-and-test:
name: "Python ${{ matrix.python-version }} on ${{ matrix.os }} jax=${{ matrix.jax-version }}"
runs-on: "${{ matrix.os }}"

strategy:
matrix:
python-version: ["3.9", "3.10", "3.11"]
Expand All @@ -22,7 +21,6 @@ jobs:
- python-version: "3.9"
os: "ubuntu-latest"
jax-version: "0.4.27" # Keep version in sync with pyproject.toml and copy.bara.sky!

steps:
- uses: "actions/checkout@v2"
- uses: "actions/setup-python@v4"
Expand All @@ -33,13 +31,72 @@ jobs:
- name: Run CI tests
run: JAX_VERSION="${{ matrix.jax-version }}" bash test.sh
shell: bash
doctests:
name: "Doctests on ${{ matrix.os }} with Python ${{ matrix.python-version }}"
runs-on: "${{ matrix.os }}"
strategy:
matrix:
python-version: ["3.11"] # only build docs with a somewhat latest python
os: [ubuntu-latest]
steps:
- uses: "actions/checkout@v2"
- uses: "actions/setup-python@v4"
with:
python-version: "${{ matrix.python-version }}"
cache: "pip"
cache-dependency-path: 'pyproject.toml'
- name: Build docs and run doctests
run: |
python3 -m pip install --quiet --editable ".[docs]"
cd docs
make html
make doctest # run doctests
shell: bash
linting:
name: "Lint check with flake8 and pylint"
runs-on: "ubuntu-latest"
steps:
- uses: "actions/checkout@v2"
- uses: "actions/setup-python@v4"
with:
python-version: "3.11"
cache: "pip"
cache-dependency-path: "pyproject.toml"
- name: Install linting dependencies
run: |
pip install -U pip setuptools wheel
pip install -U flake8 pytest-xdist pylint pylint-exit
- name: Lint with flake8
run: |
python3 -m flake8 --select=E9,F63,F7,F82,E225,E251 --show-source --statistics
- name: Lint module files with pylint
run: |
PYLINT_ARGS="-efail -wfail -cfail -rfail"
python3 -m pylint --rcfile=.pylintrc $(find optax -name '*.py' | grep -v 'test.py' | xargs) -d E1102 || pylint-exit $PYLINT_ARGS $?
- name: Lint test files with pylint
run: |
PYLINT_ARGS="-efail -wfail -cfail -rfail"
python3 -m pylint --rcfile=.pylintrc $(find optax -name '*_test.py' | xargs) -d W0212,E1102 || pylint-exit $PYLINT_ARGS $?
ruff-lint:
name: "Lint check with ruff"
runs-on: "ubuntu-latest"
steps:
- uses: "actions/checkout@v2"
- uses: "actions/setup-python@v4"
with:
python-version: "3.11"
cache: "pip"
cache-dependency-path: "pyproject.toml"
- name: Install ruff and lint check
run: |
pip install -U ruff
ruff check .
markdown-link-check:
name: "Check links in markdown files"
runs-on: "ubuntu-latest"
steps:
- name: Checkout repository
uses: actions/checkout@v4

- name: Check links
uses: gaurav-nelson/github-action-markdown-link-check@v1
with:
Expand Down
4 changes: 2 additions & 2 deletions examples/contrib/reduce_on_plateau.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@
"source": [
"opt = optax.chain(\n",
" optax.adam(LEARNING_RATE),\n",
" reduce_on_plateau(\n",
" contrib.reduce_on_plateau(\n",
" patience=PATIENCE,\n",
" cooldown=COOLDOWN,\n",
" factor=FACTOR,\n",
Expand Down Expand Up @@ -759,7 +759,7 @@
}
],
"source": [
"transform = reduce_on_plateau(\n",
"transform = contrib.reduce_on_plateau(\n",
" patience=PATIENCE,\n",
" cooldown=COOLDOWN,\n",
" factor=FACTOR,\n",
Expand Down
2 changes: 1 addition & 1 deletion examples/linear_assignment_problem.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
"outputs": [],
"source": [
"import networkx as nx\n",
"from jax import numpy as jnp, random\n",
"from jax import random\n",
"import optax\n",
"from matplotlib import pyplot as plt"
]
Expand Down
2 changes: 1 addition & 1 deletion examples/nanolm.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -629,7 +629,7 @@
}
],
"source": [
"plt.title(f\"Convergence of adamw (train loss)\")\n",
"plt.title(\"Convergence of adamw (train loss)\")\n",
"plt.plot(all_train_losses, label=\"train\", lw=3)\n",
"plt.plot(\n",
" jnp.arange(0, len(all_eval_losses) * N_FREQ_EVAL, N_FREQ_EVAL),\n",
Expand Down
2 changes: 1 addition & 1 deletion optax/schedules/_inject.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

def _convert_floats(x, dtype):
"""Convert float-like inputs to dtype, rest pass through."""
if jax.dtypes.scalar_type_of(x) == float:
if jax.dtypes.scalar_type_of(x) is float:
return jnp.asarray(x, dtype=dtype)
return x

Expand Down
17 changes: 17 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,20 @@ dp-accounting = [
[tool.setuptools.packages.find]
include = ["README.md", "LICENSE"]
exclude = ["*_test.py"]

#[tool.ruff]
#extend-exclude = ["*.ipynb"]

[tool.ruff.lint]
select = [
"F",
"E",
]
ignore = [
"E731", # lambdas are allowed
"E501", # don't check line lengths
"F401", # allow unused imports
"E402", # allow modules not at top of file
"E741", # allow "l" as a variable name
"E703", # allow semicolons (for jupyter notebooks)
]
3 changes: 3 additions & 0 deletions test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -101,4 +101,7 @@ make html
make doctest # run doctests
cd ..

pip install -U ruff
ruff check .

echo "All tests passed. Congrats!"

0 comments on commit 5bc9812

Please sign in to comment.