diff --git a/.devtools/license b/.devtools/license
index f89c47f3c1..a63992e968 100755
--- a/.devtools/license
+++ b/.devtools/license
@@ -46,10 +46,11 @@ external_license = "# LICENSE: EXTERNAL"
def check_file(path: Path) -> bool:
- """Check whether `path` is complicit.
+ """
+ Check whether `path` is complicit.
- The license header needs to start on the 3rd line latest. This allows
- to have a shebang on top of the file.
+ The license header needs to start on the 3rd line latest. This allows to
+ have a shebang on top of the file.
"""
num_lines = len(license.splitlines()) + 2
diff --git a/.github/workflows/just_checks.yml b/.github/workflows/just_checks.yml
new file mode 100644
index 0000000000..564463c8ca
--- /dev/null
+++ b/.github/workflows/just_checks.yml
@@ -0,0 +1,23 @@
+name: Type & license checks
+
+on: [push, pull_request]
+
+jobs:
+ just-checks:
+ runs-on: ubuntu-latest
+ strategy:
+ matrix:
+ just_cmd: ["license", "mypy"]
+ steps:
+ - uses: actions/checkout@v3
+ - uses: extractions/setup-just@v1
+ env:
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+ - name: Set up Python 3.8
+ uses: actions/setup-python@v4
+ with:
+ python-version: '3.8'
+ - name: Install dependencies
+ run: pip install -r requirements/requirements-ci-just.txt
+ - name: Check ${{ matrix.just_cmd }}
+ run: just ${{ matrix.just_cmd }}
diff --git a/.github/workflows/lint_precommit.yml b/.github/workflows/lint_precommit.yml
new file mode 100644
index 0000000000..285de16c92
--- /dev/null
+++ b/.github/workflows/lint_precommit.yml
@@ -0,0 +1,11 @@
+name: Lint checks by pre-commit
+
+on: [push, pull_request]
+
+jobs:
+ pre-commit:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v3
+ - uses: actions/setup-python@v3
+ - uses: pre-commit/action@v3.0.1
diff --git a/.github/workflows/style_type_checks.yml b/.github/workflows/style_type_checks.yml
deleted file mode 100644
index 48c2178edc..0000000000
--- a/.github/workflows/style_type_checks.yml
+++ /dev/null
@@ -1,24 +0,0 @@
-name: Style and type checks
-
-on: [push, pull_request]
-
-jobs:
- style-type-checks:
- runs-on: ubuntu-latest
-
- steps:
- - uses: actions/checkout@v3
- - uses: extractions/setup-just@v1
- - uses: actions/setup-python@v4
- - name: Install dependencies
- run: |
- pip install .
- # todo: install also `black[jupyter]`
- pip install click "black==24.02" "mypy==1.8.0" \
- types-python-dateutil types-waitress types-PyYAML
- - name: Style check
- run: just black
- - name: Type check
- run: just mypy
- - name: Check license headers
- run: just license
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000000..be0752d118
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,52 @@
+default_language_version:
+ python: python3
+
+ci:
+ autofix_prs: true
+ autoupdate_commit_msg: "[pre-commit.ci] pre-commit suggestions"
+ autoupdate_schedule: monthly
+ # submodules: true
+
+repos:
+ - repo: https://github.com/pre-commit/pre-commit-hooks
+ rev: v4.5.0
+ hooks:
+ # - id: end-of-file-fixer
+ # - id: trailing-whitespace
+ - id: check-json
+ - id: check-yaml
+ - id: check-toml
+ # - id: check-docstring-first
+ # - id: check-executables-have-shebangs
+ # - id: check-case-conflict
+ # - id: check-added-large-files
+ - id: detect-private-key
+
+ - repo: https://github.com/astral-sh/ruff-pre-commit
+ rev: v0.3.3
+ hooks:
+ - id: ruff
+ name: Run Ruff & fixing (Flake8)
+ args: ["--fix"]
+
+ - repo: https://github.com/PyCQA/docformatter
+ rev: v1.5.0
+ hooks:
+ - id: docformatter
+ name: Format docstrings with Docformatter
+ additional_dependencies: [tomli]
+ args: ["--in-place"]
+
+ - repo: https://github.com/psf/black
+ rev: 24.2.0
+ hooks:
+ - id: black
+ name: Format code with Black
+
+ #- repo: https://github.com/executablebooks/mdformat
+ # rev: 0.7.17
+ # hooks:
+ # - id: mdformat
+ # additional_dependencies:
+ # - mdformat-gfm
+ # - mdformat_frontmatter
diff --git a/Justfile b/Justfile
index 986927c91c..bd0800c3b8 100644
--- a/Justfile
+++ b/Justfile
@@ -33,9 +33,6 @@ compile_notebooks:
release:
python setup.py sdist
-black:
- black --check --diff --color src test examples
-
mypy:
python setup.py type_check
diff --git a/docs/conf.py b/docs/conf.py
index d74e5c357e..c4b8ad8de6 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -95,9 +95,9 @@
}
if os.environ.get("GITHUB_REF_NAME") == "dev":
- html_theme_options[
- "announcement"
- ] = "Note: You are looking at the development docs."
+ html_theme_options["announcement"] = (
+ "Note: You are looking at the development docs."
+ )
# Add any paths that contain custom static files (such as style sheets) here,
diff --git a/examples/anomaly_detection.py b/examples/anomaly_detection.py
index 094404486f..455d1fddcb 100644
--- a/examples/anomaly_detection.py
+++ b/examples/anomaly_detection.py
@@ -13,11 +13,11 @@
"""
This example shows how to do anomaly detection with DeepAR.
-The model is first trained and then time-points with the largest negative log-likelihood are plotted.
+The model is first trained and then time-points with the largest negative log-
+likelihood are plotted.
"""
import numpy as np
-from itertools import islice
from functools import partial
import mxnet as mx
import matplotlib.pyplot as plt
diff --git a/examples/warm_start.py b/examples/warm_start.py
index fd425615a8..fba7be65a0 100644
--- a/examples/warm_start.py
+++ b/examples/warm_start.py
@@ -12,8 +12,8 @@
# permissions and limitations under the License.
"""
-This example show how to intialize the network with parameters from a model that was previously trained.
-
+This example show how to intialize the network with parameters from a model
+that was previously trained.
"""
from gluonts.dataset.repository import get_dataset, dataset_recipes
diff --git a/pyproject.toml b/pyproject.toml
index 96524e4daa..8205dc6171 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -20,13 +20,13 @@ filterwarnings = "ignore"
[tool.ruff]
line-length = 79
-
lint.ignore = [
"E501", # line-length is handled by black
"E741" # TODO: remove usage of `l`
]
-
-exclude = ["src/gluonts/nursery"]
+exclude = ["src/gluonts/nursery"] # todo
+[tool.ruff.lint.per-file-ignores]
+"examples/*" = ["E402"]
[tool.docformatter]
diff --git a/requirements/requirements-articial-dataset.py b/requirements/requirements-articial-dataset.txt
similarity index 100%
rename from requirements/requirements-articial-dataset.py
rename to requirements/requirements-articial-dataset.txt
diff --git a/requirements/requirements-ci-just.txt b/requirements/requirements-ci-just.txt
new file mode 100644
index 0000000000..c481648433
--- /dev/null
+++ b/requirements/requirements-ci-just.txt
@@ -0,0 +1,5 @@
+click
+mypy==1.8.0
+types-python-dateutil
+types-waitress
+types-PyYAML
\ No newline at end of file
diff --git a/setup.py b/setup.py
index eb628f654f..8c43857d15 100644
--- a/setup.py
+++ b/setup.py
@@ -60,7 +60,8 @@ def run(self):
import mypy.api
excluded_folders = [
- str(p.parent.relative_to(ROOT)) for p in ROOT.glob("src/**/.typeunsafe")
+ str(p.parent.relative_to(ROOT))
+ for p in ROOT.glob("src/**/.typeunsafe")
]
if len(excluded_folders) > 0:
diff --git a/src/gluonts/itertools.py b/src/gluonts/itertools.py
index 44398e5eb6..fab1f35099 100644
--- a/src/gluonts/itertools.py
+++ b/src/gluonts/itertools.py
@@ -298,7 +298,7 @@ def split_into(xs: Sequence, n: int) -> Sequence:
# e.g. 10 by 3 -> 4, 3, 3
relative_splits[:remainder] += 1
- return split(xs, np.cumsum(relative_splits))
+ return split(xs, np.cumsum(relative_splits)) # type: ignore[arg-type]
@dataclass
diff --git a/src/gluonts/model/forecast.py b/src/gluonts/model/forecast.py
index c69ae385fc..1b56b1fc02 100644
--- a/src/gluonts/model/forecast.py
+++ b/src/gluonts/model/forecast.py
@@ -178,14 +178,14 @@ class Quantile:
@classmethod
def from_float(cls, quantile: float) -> "Quantile":
assert isinstance(quantile, float)
- return cls(value=quantile, name=str(quantile))
+ return cls(value=quantile, name=str(quantile)) # type: ignore[call-arg]
@classmethod
def from_str(cls, quantile: str) -> "Quantile":
assert isinstance(quantile, str)
try:
- return cls(value=float(quantile), name=quantile)
+ return cls(value=float(quantile), name=quantile) # type: ignore[call-arg]
except ValueError:
m = re.match(r"^p(\d+)$", quantile)
diff --git a/src/gluonts/mx/trainer/learning_rate_scheduler.py b/src/gluonts/mx/trainer/learning_rate_scheduler.py
index 5f33445b1e..9a165748d2 100644
--- a/src/gluonts/mx/trainer/learning_rate_scheduler.py
+++ b/src/gluonts/mx/trainer/learning_rate_scheduler.py
@@ -247,8 +247,8 @@ def __init__(
0 <= min_lr <= base_lr
), "The value of `min_lr` should be >= 0 and <= base_lr"
- self.lr_scheduler = MetricAttentiveScheduler(
- patience=Patience(
+ self.lr_scheduler = MetricAttentiveScheduler( # type: ignore[call-arg]
+ patience=Patience( # type: ignore[call-arg]
patience=patience, objective=Objective.from_str(objective)
),
learning_rate=base_lr,
diff --git a/src/gluonts/torch/model/patch_tst/module.py b/src/gluonts/torch/model/patch_tst/module.py
index 5f5c3e9fbc..14f481ea8f 100644
--- a/src/gluonts/torch/model/patch_tst/module.py
+++ b/src/gluonts/torch/model/patch_tst/module.py
@@ -229,7 +229,7 @@ def forward(
# shift time features by `prediction_length` so that they are
# aligned with the target input.
time_feat = take_last(
- torch.cat((past_time_feat, future_time_feat), dim=1),
+ torch.cat((past_time_feat, future_time_feat), dim=1), # type: ignore[arg-type]
dim=1,
num=self.context_length,
)
diff --git a/src/gluonts/torch/model/tft/layers.py b/src/gluonts/torch/model/tft/layers.py
index d19ed7f059..9977a3d1fa 100644
--- a/src/gluonts/torch/model/tft/layers.py
+++ b/src/gluonts/torch/model/tft/layers.py
@@ -27,7 +27,7 @@ def forward(self, features: torch.Tensor) -> List[torch.Tensor]: # type: ignore
concat_features = super().forward(features=features)
if self._num_features > 1:
- return torch.chunk(concat_features, self._num_features, dim=-1)
+ return torch.chunk(concat_features, self._num_features, dim=-1) # type: ignore[return-value]
else:
return [concat_features]
diff --git a/src/gluonts/torch/modules/feature.py b/src/gluonts/torch/modules/feature.py
index 852d3372c9..46cef7993d 100644
--- a/src/gluonts/torch/modules/feature.py
+++ b/src/gluonts/torch/modules/feature.py
@@ -38,7 +38,7 @@ def forward(self, features: torch.Tensor) -> torch.Tensor:
features, self._num_features, dim=-1
)
else:
- cat_feature_slices = [features]
+ cat_feature_slices = [features] # type: ignore[assignment]
return torch.cat(
[
diff --git a/src/gluonts/zebras/_time_frame.py b/src/gluonts/zebras/_time_frame.py
index b845252586..dba0994e53 100644
--- a/src/gluonts/zebras/_time_frame.py
+++ b/src/gluonts/zebras/_time_frame.py
@@ -453,11 +453,11 @@ def split(
# If past_length is not provided, it will equal to `index`, since
# `len(tf.split(5).past) == 5`
- past_length: int = maybe.unwrap_or(past_length, index)
+ past_length: int = maybe.unwrap_or(past_length, index) # type: ignore[annotation-unchecked]
# Same logic applies to future_length, except that we deduct from the
# right. (We can't use past_length, since it can be unequal to index).
- future_length: int = maybe.unwrap_or(future_length, len(self) - index)
+ future_length: int = maybe.unwrap_or(future_length, len(self) - index) # type: ignore[annotation-unchecked]
if self.index is None:
new_index = None
diff --git a/test/core/test_component.py b/test/core/test_component.py
index 241a245520..26524e5707 100644
--- a/test/core/test_component.py
+++ b/test/core/test_component.py
@@ -24,8 +24,8 @@ def __init__(self, x: float, y: float) -> None:
self.x = x
self.y = y
- assert type(self.x) == float
- assert type(self.y) == float
+ assert isinstance(self.x, float)
+ assert isinstance(self.y, float)
def __eq__(self, that):
return self.x == that.x and self.y == that.y
@@ -41,9 +41,9 @@ def __init__(self, a: int, c: Complex, **kwargs) -> None:
self.b = kwargs["b"]
self.c = c
- assert type(self.a) == int
- assert type(self.b) == float
- assert type(self.c) == Complex
+ assert isinstance(self.a, int)
+ assert isinstance(self.b, float)
+ assert isinstance(self.c, Complex)
def __eq__(self, that):
return self.a == that.a and self.b == that.b and self.c == that.c
diff --git a/test/ev/test_metrics_compared_to_previous_approach.py b/test/ev/test_metrics_compared_to_previous_approach.py
index 5adc2e5d2b..9faa6167fb 100644
--- a/test/ev/test_metrics_compared_to_previous_approach.py
+++ b/test/ev/test_metrics_compared_to_previous_approach.py
@@ -121,7 +121,7 @@ def get_data_batches(predictor, test_data):
)
seasonality = get_seasonality(freq=forecast.start_date.freqstr)
- freq = forecast.start_date.freqstr
+ # freq = forecast.start_date.freqstr
other_data = {
"label": np.array([label["target"]]),
"seasonal_error": np.array(
diff --git a/test/ext/hierarchicalforecast/test_predictor.py b/test/ext/hierarchicalforecast/test_predictor.py
index 51d962289b..9a69f5b881 100644
--- a/test/ext/hierarchicalforecast/test_predictor.py
+++ b/test/ext/hierarchicalforecast/test_predictor.py
@@ -110,7 +110,7 @@ def test_predictor_serialization(
):
train_datasets = sine7(nonnegative=True)
- (train_dataset, test_dataset, metadata) = (
+ _, _, metadata = (
train_datasets.train,
train_datasets.test,
train_datasets.metadata,
@@ -152,7 +152,7 @@ def test_predictor_working(
bias=1,
)
- (train_dataset, test_dataset, metadata) = (
+ _, test_dataset, metadata = (
train_datasets.train,
train_datasets.test,
train_datasets.metadata,
diff --git a/test/mx/model/deepvar_hierarchical/test_projection.py b/test/mx/model/deepvar_hierarchical/test_projection.py
index 4e57e341b5..ba68a5076c 100644
--- a/test/mx/model/deepvar_hierarchical/test_projection.py
+++ b/test/mx/model/deepvar_hierarchical/test_projection.py
@@ -146,4 +146,4 @@ def test_projection_mat(D):
)
def test_projection_mat_expected_fail(D):
with pytest.raises(AssertionError):
- p = projection_mat(S=S, D=D)
+ projection_mat(S=S, D=D)
diff --git a/test/test_itertools.py b/test/test_itertools.py
index b9071d9ad0..6c69e29bb9 100644
--- a/test/test_itertools.py
+++ b/test/test_itertools.py
@@ -289,7 +289,7 @@ def test_join_items():
]
with pytest.raises(Exception):
- oin_items(left, right, "strict")
+ oin_items(left, right, "strict") # noqa: F821
@pytest.mark.parametrize(
diff --git a/test/zebras/test_batch.py b/test/zebras/test_batch.py
index bd00613ed8..689770d94b 100644
--- a/test/zebras/test_batch.py
+++ b/test/zebras/test_batch.py
@@ -11,7 +11,6 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
-import pytest
import numpy as np