Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

migrate style check to precommit #3111

Open
wants to merge 28 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions .devtools/license
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,11 @@ external_license = "# LICENSE: EXTERNAL"


def check_file(path: Path) -> bool:
"""Check whether `path` is complicit.
"""
Check whether `path` is complicit.

The license header needs to start on the 3rd line latest. This allows
to have a shebang on top of the file.
The license header needs to start on the 3rd line latest. This allows to
have a shebang on top of the file.
"""
num_lines = len(license.splitlines()) + 2

Expand Down
23 changes: 23 additions & 0 deletions .github/workflows/just_checks.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
name: Type & license checks

on: [push, pull_request]

jobs:
just-checks:
runs-on: ubuntu-latest
strategy:
matrix:
just_cmd: ["license", "mypy"]
steps:
- uses: actions/checkout@v3
- uses: extractions/setup-just@v1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: Set up Python 3.8
uses: actions/setup-python@v4
with:
python-version: '3.8'
- name: Install dependencies
run: pip install -r requirements/requirements-ci-just.txt
- name: Check ${{ matrix.just_cmd }}
run: just ${{ matrix.just_cmd }}
11 changes: 11 additions & 0 deletions .github/workflows/lint_precommit.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
name: Lint checks by pre-commit

on: [push, pull_request]

jobs:
pre-commit:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v3
- uses: pre-commit/action@v3.0.1
24 changes: 0 additions & 24 deletions .github/workflows/style_type_checks.yml

This file was deleted.

52 changes: 52 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
default_language_version:
python: python3

ci:
autofix_prs: true
autoupdate_commit_msg: "[pre-commit.ci] pre-commit suggestions"
autoupdate_schedule: monthly
# submodules: true

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
# - id: end-of-file-fixer
# - id: trailing-whitespace
- id: check-json
- id: check-yaml
- id: check-toml
# - id: check-docstring-first
# - id: check-executables-have-shebangs
# - id: check-case-conflict
# - id: check-added-large-files
- id: detect-private-key

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.3.3
hooks:
- id: ruff
name: Run Ruff & fixing (Flake8)
args: ["--fix"]

- repo: https://github.com/PyCQA/docformatter
rev: v1.5.0
hooks:
- id: docformatter
name: Format docstrings with Docformatter
additional_dependencies: [tomli]
args: ["--in-place"]

- repo: https://github.com/psf/black
rev: 24.2.0
hooks:
- id: black
name: Format code with Black

#- repo: https://github.com/executablebooks/mdformat
# rev: 0.7.17
# hooks:
# - id: mdformat
# additional_dependencies:
# - mdformat-gfm
# - mdformat_frontmatter
3 changes: 0 additions & 3 deletions Justfile
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,6 @@ compile_notebooks:
release:
python setup.py sdist

black:
black --check --diff --color src test examples

mypy:
python setup.py type_check

Expand Down
6 changes: 3 additions & 3 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,9 @@
}

if os.environ.get("GITHUB_REF_NAME") == "dev":
html_theme_options[
"announcement"
] = "<strong>Note:</strong> You are looking at the development docs."
html_theme_options["announcement"] = (
"<strong>Note:</strong> You are looking at the development docs."
)


# Add any paths that contain custom static files (such as style sheets) here,
Expand Down
4 changes: 2 additions & 2 deletions examples/anomaly_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@

"""
This example shows how to do anomaly detection with DeepAR.
The model is first trained and then time-points with the largest negative log-likelihood are plotted.

The model is first trained and then time-points with the largest negative log-
likelihood are plotted.
"""
import numpy as np
from itertools import islice
from functools import partial
import mxnet as mx
import matplotlib.pyplot as plt
Expand Down
4 changes: 2 additions & 2 deletions examples/warm_start.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
# permissions and limitations under the License.

"""
This example show how to intialize the network with parameters from a model that was previously trained.

This example show how to intialize the network with parameters from a model
that was previously trained.
"""

from gluonts.dataset.repository import get_dataset, dataset_recipes
Expand Down
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@ filterwarnings = "ignore"

[tool.ruff]
line-length = 79

lint.ignore = [
"E501", # line-length is handled by black
"E741" # TODO: remove usage of `l`
]

exclude = ["src/gluonts/nursery"]
exclude = ["src/gluonts/nursery"] # todo
[tool.ruff.lint.per-file-ignores]
"examples/*" = ["E402"]


[tool.docformatter]
Expand Down
5 changes: 5 additions & 0 deletions requirements/requirements-ci-just.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
click
mypy==1.8.0
types-python-dateutil
types-waitress
types-PyYAML
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ def run(self):
import mypy.api

excluded_folders = [
str(p.parent.relative_to(ROOT)) for p in ROOT.glob("src/**/.typeunsafe")
str(p.parent.relative_to(ROOT))
for p in ROOT.glob("src/**/.typeunsafe")
]

if len(excluded_folders) > 0:
Expand Down
2 changes: 1 addition & 1 deletion src/gluonts/itertools.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ def split_into(xs: Sequence, n: int) -> Sequence:
# e.g. 10 by 3 -> 4, 3, 3
relative_splits[:remainder] += 1

return split(xs, np.cumsum(relative_splits))
return split(xs, np.cumsum(relative_splits)) # type: ignore[arg-type]


@dataclass
Expand Down
4 changes: 2 additions & 2 deletions src/gluonts/model/forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,14 +178,14 @@ class Quantile:
@classmethod
def from_float(cls, quantile: float) -> "Quantile":
assert isinstance(quantile, float)
return cls(value=quantile, name=str(quantile))
return cls(value=quantile, name=str(quantile)) # type: ignore[call-arg]

@classmethod
def from_str(cls, quantile: str) -> "Quantile":
assert isinstance(quantile, str)

try:
return cls(value=float(quantile), name=quantile)
return cls(value=float(quantile), name=quantile) # type: ignore[call-arg]
except ValueError:
m = re.match(r"^p(\d+)$", quantile)

Expand Down
4 changes: 2 additions & 2 deletions src/gluonts/mx/trainer/learning_rate_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,8 +247,8 @@ def __init__(
0 <= min_lr <= base_lr
), "The value of `min_lr` should be >= 0 and <= base_lr"

self.lr_scheduler = MetricAttentiveScheduler(
patience=Patience(
self.lr_scheduler = MetricAttentiveScheduler( # type: ignore[call-arg]
patience=Patience( # type: ignore[call-arg]
patience=patience, objective=Objective.from_str(objective)
),
learning_rate=base_lr,
Expand Down
2 changes: 1 addition & 1 deletion src/gluonts/torch/model/patch_tst/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def forward(
# shift time features by `prediction_length` so that they are
# aligned with the target input.
time_feat = take_last(
torch.cat((past_time_feat, future_time_feat), dim=1),
torch.cat((past_time_feat, future_time_feat), dim=1), # type: ignore[arg-type]
dim=1,
num=self.context_length,
)
Expand Down
2 changes: 1 addition & 1 deletion src/gluonts/torch/model/tft/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def forward(self, features: torch.Tensor) -> List[torch.Tensor]: # type: ignore
concat_features = super().forward(features=features)

if self._num_features > 1:
return torch.chunk(concat_features, self._num_features, dim=-1)
return torch.chunk(concat_features, self._num_features, dim=-1) # type: ignore[return-value]
else:
return [concat_features]

Expand Down
2 changes: 1 addition & 1 deletion src/gluonts/torch/modules/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def forward(self, features: torch.Tensor) -> torch.Tensor:
features, self._num_features, dim=-1
)
else:
cat_feature_slices = [features]
cat_feature_slices = [features] # type: ignore[assignment]

return torch.cat(
[
Expand Down
4 changes: 2 additions & 2 deletions src/gluonts/zebras/_time_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,11 +453,11 @@ def split(

# If past_length is not provided, it will equal to `index`, since
# `len(tf.split(5).past) == 5`
past_length: int = maybe.unwrap_or(past_length, index)
past_length: int = maybe.unwrap_or(past_length, index) # type: ignore[annotation-unchecked]

# Same logic applies to future_length, except that we deduct from the
# right. (We can't use past_length, since it can be unequal to index).
future_length: int = maybe.unwrap_or(future_length, len(self) - index)
future_length: int = maybe.unwrap_or(future_length, len(self) - index) # type: ignore[annotation-unchecked]

if self.index is None:
new_index = None
Expand Down
10 changes: 5 additions & 5 deletions test/core/test_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ def __init__(self, x: float, y: float) -> None:
self.x = x
self.y = y

assert type(self.x) == float
assert type(self.y) == float
assert isinstance(self.x, float)
assert isinstance(self.y, float)

def __eq__(self, that):
return self.x == that.x and self.y == that.y
Expand All @@ -41,9 +41,9 @@ def __init__(self, a: int, c: Complex, **kwargs) -> None:
self.b = kwargs["b"]
self.c = c

assert type(self.a) == int
assert type(self.b) == float
assert type(self.c) == Complex
assert isinstance(self.a, int)
assert isinstance(self.b, float)
assert isinstance(self.c, Complex)

def __eq__(self, that):
return self.a == that.a and self.b == that.b and self.c == that.c
Expand Down
2 changes: 1 addition & 1 deletion test/ev/test_metrics_compared_to_previous_approach.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def get_data_batches(predictor, test_data):
)

seasonality = get_seasonality(freq=forecast.start_date.freqstr)
freq = forecast.start_date.freqstr
# freq = forecast.start_date.freqstr
other_data = {
"label": np.array([label["target"]]),
"seasonal_error": np.array(
Expand Down
4 changes: 2 additions & 2 deletions test/ext/hierarchicalforecast/test_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def test_predictor_serialization(
):
train_datasets = sine7(nonnegative=True)

(train_dataset, test_dataset, metadata) = (
_, _, metadata = (
train_datasets.train,
train_datasets.test,
train_datasets.metadata,
Expand Down Expand Up @@ -152,7 +152,7 @@ def test_predictor_working(
bias=1,
)

(train_dataset, test_dataset, metadata) = (
_, test_dataset, metadata = (
train_datasets.train,
train_datasets.test,
train_datasets.metadata,
Expand Down
2 changes: 1 addition & 1 deletion test/mx/model/deepvar_hierarchical/test_projection.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,4 +146,4 @@ def test_projection_mat(D):
)
def test_projection_mat_expected_fail(D):
with pytest.raises(AssertionError):
p = projection_mat(S=S, D=D)
projection_mat(S=S, D=D)
2 changes: 1 addition & 1 deletion test/test_itertools.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ def test_join_items():
]

with pytest.raises(Exception):
oin_items(left, right, "strict")
oin_items(left, right, "strict") # noqa: F821


@pytest.mark.parametrize(
Expand Down
1 change: 0 additions & 1 deletion test/zebras/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

import pytest

import numpy as np

Expand Down