diff --git a/.devtools/license b/.devtools/license index f89c47f3c1..a63992e968 100755 --- a/.devtools/license +++ b/.devtools/license @@ -46,10 +46,11 @@ external_license = "# LICENSE: EXTERNAL" def check_file(path: Path) -> bool: - """Check whether `path` is complicit. + """ + Check whether `path` is complicit. - The license header needs to start on the 3rd line latest. This allows - to have a shebang on top of the file. + The license header needs to start on the 3rd line latest. This allows to + have a shebang on top of the file. """ num_lines = len(license.splitlines()) + 2 diff --git a/.github/workflows/just_checks.yml b/.github/workflows/just_checks.yml new file mode 100644 index 0000000000..564463c8ca --- /dev/null +++ b/.github/workflows/just_checks.yml @@ -0,0 +1,23 @@ +name: Type & license checks + +on: [push, pull_request] + +jobs: + just-checks: + runs-on: ubuntu-latest + strategy: + matrix: + just_cmd: ["license", "mypy"] + steps: + - uses: actions/checkout@v3 + - uses: extractions/setup-just@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: Set up Python 3.8 + uses: actions/setup-python@v4 + with: + python-version: '3.8' + - name: Install dependencies + run: pip install -r requirements/requirements-ci-just.txt + - name: Check ${{ matrix.just_cmd }} + run: just ${{ matrix.just_cmd }} diff --git a/.github/workflows/lint_precommit.yml b/.github/workflows/lint_precommit.yml new file mode 100644 index 0000000000..285de16c92 --- /dev/null +++ b/.github/workflows/lint_precommit.yml @@ -0,0 +1,11 @@ +name: Lint checks by pre-commit + +on: [push, pull_request] + +jobs: + pre-commit: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v3 + - uses: pre-commit/action@v3.0.1 diff --git a/.github/workflows/style_type_checks.yml b/.github/workflows/style_type_checks.yml deleted file mode 100644 index 48c2178edc..0000000000 --- a/.github/workflows/style_type_checks.yml +++ /dev/null @@ -1,24 +0,0 @@ -name: Style and type checks - -on: [push, pull_request] - -jobs: - style-type-checks: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v3 - - uses: extractions/setup-just@v1 - - uses: actions/setup-python@v4 - - name: Install dependencies - run: | - pip install . - # todo: install also `black[jupyter]` - pip install click "black==24.02" "mypy==1.8.0" \ - types-python-dateutil types-waitress types-PyYAML - - name: Style check - run: just black - - name: Type check - run: just mypy - - name: Check license headers - run: just license diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000000..be0752d118 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,52 @@ +default_language_version: + python: python3 + +ci: + autofix_prs: true + autoupdate_commit_msg: "[pre-commit.ci] pre-commit suggestions" + autoupdate_schedule: monthly + # submodules: true + +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 + hooks: + # - id: end-of-file-fixer + # - id: trailing-whitespace + - id: check-json + - id: check-yaml + - id: check-toml + # - id: check-docstring-first + # - id: check-executables-have-shebangs + # - id: check-case-conflict + # - id: check-added-large-files + - id: detect-private-key + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.3.3 + hooks: + - id: ruff + name: Run Ruff & fixing (Flake8) + args: ["--fix"] + + - repo: https://github.com/PyCQA/docformatter + rev: v1.5.0 + hooks: + - id: docformatter + name: Format docstrings with Docformatter + additional_dependencies: [tomli] + args: ["--in-place"] + + - repo: https://github.com/psf/black + rev: 24.2.0 + hooks: + - id: black + name: Format code with Black + + #- repo: https://github.com/executablebooks/mdformat + # rev: 0.7.17 + # hooks: + # - id: mdformat + # additional_dependencies: + # - mdformat-gfm + # - mdformat_frontmatter diff --git a/Justfile b/Justfile index 986927c91c..bd0800c3b8 100644 --- a/Justfile +++ b/Justfile @@ -33,9 +33,6 @@ compile_notebooks: release: python setup.py sdist -black: - black --check --diff --color src test examples - mypy: python setup.py type_check diff --git a/docs/conf.py b/docs/conf.py index d74e5c357e..c4b8ad8de6 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -95,9 +95,9 @@ } if os.environ.get("GITHUB_REF_NAME") == "dev": - html_theme_options[ - "announcement" - ] = "Note: You are looking at the development docs." + html_theme_options["announcement"] = ( + "Note: You are looking at the development docs." + ) # Add any paths that contain custom static files (such as style sheets) here, diff --git a/examples/anomaly_detection.py b/examples/anomaly_detection.py index 094404486f..455d1fddcb 100644 --- a/examples/anomaly_detection.py +++ b/examples/anomaly_detection.py @@ -13,11 +13,11 @@ """ This example shows how to do anomaly detection with DeepAR. -The model is first trained and then time-points with the largest negative log-likelihood are plotted. +The model is first trained and then time-points with the largest negative log- +likelihood are plotted. """ import numpy as np -from itertools import islice from functools import partial import mxnet as mx import matplotlib.pyplot as plt diff --git a/examples/warm_start.py b/examples/warm_start.py index fd425615a8..fba7be65a0 100644 --- a/examples/warm_start.py +++ b/examples/warm_start.py @@ -12,8 +12,8 @@ # permissions and limitations under the License. """ -This example show how to intialize the network with parameters from a model that was previously trained. - +This example show how to intialize the network with parameters from a model +that was previously trained. """ from gluonts.dataset.repository import get_dataset, dataset_recipes diff --git a/pyproject.toml b/pyproject.toml index 96524e4daa..8205dc6171 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,13 +20,13 @@ filterwarnings = "ignore" [tool.ruff] line-length = 79 - lint.ignore = [ "E501", # line-length is handled by black "E741" # TODO: remove usage of `l` ] - -exclude = ["src/gluonts/nursery"] +exclude = ["src/gluonts/nursery"] # todo +[tool.ruff.lint.per-file-ignores] +"examples/*" = ["E402"] [tool.docformatter] diff --git a/requirements/requirements-articial-dataset.py b/requirements/requirements-articial-dataset.txt similarity index 100% rename from requirements/requirements-articial-dataset.py rename to requirements/requirements-articial-dataset.txt diff --git a/requirements/requirements-ci-just.txt b/requirements/requirements-ci-just.txt new file mode 100644 index 0000000000..c481648433 --- /dev/null +++ b/requirements/requirements-ci-just.txt @@ -0,0 +1,5 @@ +click +mypy==1.8.0 +types-python-dateutil +types-waitress +types-PyYAML \ No newline at end of file diff --git a/setup.py b/setup.py index eb628f654f..8c43857d15 100644 --- a/setup.py +++ b/setup.py @@ -60,7 +60,8 @@ def run(self): import mypy.api excluded_folders = [ - str(p.parent.relative_to(ROOT)) for p in ROOT.glob("src/**/.typeunsafe") + str(p.parent.relative_to(ROOT)) + for p in ROOT.glob("src/**/.typeunsafe") ] if len(excluded_folders) > 0: diff --git a/src/gluonts/itertools.py b/src/gluonts/itertools.py index 44398e5eb6..fab1f35099 100644 --- a/src/gluonts/itertools.py +++ b/src/gluonts/itertools.py @@ -298,7 +298,7 @@ def split_into(xs: Sequence, n: int) -> Sequence: # e.g. 10 by 3 -> 4, 3, 3 relative_splits[:remainder] += 1 - return split(xs, np.cumsum(relative_splits)) + return split(xs, np.cumsum(relative_splits)) # type: ignore[arg-type] @dataclass diff --git a/src/gluonts/model/forecast.py b/src/gluonts/model/forecast.py index c69ae385fc..1b56b1fc02 100644 --- a/src/gluonts/model/forecast.py +++ b/src/gluonts/model/forecast.py @@ -178,14 +178,14 @@ class Quantile: @classmethod def from_float(cls, quantile: float) -> "Quantile": assert isinstance(quantile, float) - return cls(value=quantile, name=str(quantile)) + return cls(value=quantile, name=str(quantile)) # type: ignore[call-arg] @classmethod def from_str(cls, quantile: str) -> "Quantile": assert isinstance(quantile, str) try: - return cls(value=float(quantile), name=quantile) + return cls(value=float(quantile), name=quantile) # type: ignore[call-arg] except ValueError: m = re.match(r"^p(\d+)$", quantile) diff --git a/src/gluonts/mx/trainer/learning_rate_scheduler.py b/src/gluonts/mx/trainer/learning_rate_scheduler.py index 5f33445b1e..9a165748d2 100644 --- a/src/gluonts/mx/trainer/learning_rate_scheduler.py +++ b/src/gluonts/mx/trainer/learning_rate_scheduler.py @@ -247,8 +247,8 @@ def __init__( 0 <= min_lr <= base_lr ), "The value of `min_lr` should be >= 0 and <= base_lr" - self.lr_scheduler = MetricAttentiveScheduler( - patience=Patience( + self.lr_scheduler = MetricAttentiveScheduler( # type: ignore[call-arg] + patience=Patience( # type: ignore[call-arg] patience=patience, objective=Objective.from_str(objective) ), learning_rate=base_lr, diff --git a/src/gluonts/torch/model/patch_tst/module.py b/src/gluonts/torch/model/patch_tst/module.py index 5f5c3e9fbc..14f481ea8f 100644 --- a/src/gluonts/torch/model/patch_tst/module.py +++ b/src/gluonts/torch/model/patch_tst/module.py @@ -229,7 +229,7 @@ def forward( # shift time features by `prediction_length` so that they are # aligned with the target input. time_feat = take_last( - torch.cat((past_time_feat, future_time_feat), dim=1), + torch.cat((past_time_feat, future_time_feat), dim=1), # type: ignore[arg-type] dim=1, num=self.context_length, ) diff --git a/src/gluonts/torch/model/tft/layers.py b/src/gluonts/torch/model/tft/layers.py index d19ed7f059..9977a3d1fa 100644 --- a/src/gluonts/torch/model/tft/layers.py +++ b/src/gluonts/torch/model/tft/layers.py @@ -27,7 +27,7 @@ def forward(self, features: torch.Tensor) -> List[torch.Tensor]: # type: ignore concat_features = super().forward(features=features) if self._num_features > 1: - return torch.chunk(concat_features, self._num_features, dim=-1) + return torch.chunk(concat_features, self._num_features, dim=-1) # type: ignore[return-value] else: return [concat_features] diff --git a/src/gluonts/torch/modules/feature.py b/src/gluonts/torch/modules/feature.py index 852d3372c9..46cef7993d 100644 --- a/src/gluonts/torch/modules/feature.py +++ b/src/gluonts/torch/modules/feature.py @@ -38,7 +38,7 @@ def forward(self, features: torch.Tensor) -> torch.Tensor: features, self._num_features, dim=-1 ) else: - cat_feature_slices = [features] + cat_feature_slices = [features] # type: ignore[assignment] return torch.cat( [ diff --git a/src/gluonts/zebras/_time_frame.py b/src/gluonts/zebras/_time_frame.py index b845252586..dba0994e53 100644 --- a/src/gluonts/zebras/_time_frame.py +++ b/src/gluonts/zebras/_time_frame.py @@ -453,11 +453,11 @@ def split( # If past_length is not provided, it will equal to `index`, since # `len(tf.split(5).past) == 5` - past_length: int = maybe.unwrap_or(past_length, index) + past_length: int = maybe.unwrap_or(past_length, index) # type: ignore[annotation-unchecked] # Same logic applies to future_length, except that we deduct from the # right. (We can't use past_length, since it can be unequal to index). - future_length: int = maybe.unwrap_or(future_length, len(self) - index) + future_length: int = maybe.unwrap_or(future_length, len(self) - index) # type: ignore[annotation-unchecked] if self.index is None: new_index = None diff --git a/test/core/test_component.py b/test/core/test_component.py index 241a245520..26524e5707 100644 --- a/test/core/test_component.py +++ b/test/core/test_component.py @@ -24,8 +24,8 @@ def __init__(self, x: float, y: float) -> None: self.x = x self.y = y - assert type(self.x) == float - assert type(self.y) == float + assert isinstance(self.x, float) + assert isinstance(self.y, float) def __eq__(self, that): return self.x == that.x and self.y == that.y @@ -41,9 +41,9 @@ def __init__(self, a: int, c: Complex, **kwargs) -> None: self.b = kwargs["b"] self.c = c - assert type(self.a) == int - assert type(self.b) == float - assert type(self.c) == Complex + assert isinstance(self.a, int) + assert isinstance(self.b, float) + assert isinstance(self.c, Complex) def __eq__(self, that): return self.a == that.a and self.b == that.b and self.c == that.c diff --git a/test/ev/test_metrics_compared_to_previous_approach.py b/test/ev/test_metrics_compared_to_previous_approach.py index 5adc2e5d2b..9faa6167fb 100644 --- a/test/ev/test_metrics_compared_to_previous_approach.py +++ b/test/ev/test_metrics_compared_to_previous_approach.py @@ -121,7 +121,7 @@ def get_data_batches(predictor, test_data): ) seasonality = get_seasonality(freq=forecast.start_date.freqstr) - freq = forecast.start_date.freqstr + # freq = forecast.start_date.freqstr other_data = { "label": np.array([label["target"]]), "seasonal_error": np.array( diff --git a/test/ext/hierarchicalforecast/test_predictor.py b/test/ext/hierarchicalforecast/test_predictor.py index 51d962289b..9a69f5b881 100644 --- a/test/ext/hierarchicalforecast/test_predictor.py +++ b/test/ext/hierarchicalforecast/test_predictor.py @@ -110,7 +110,7 @@ def test_predictor_serialization( ): train_datasets = sine7(nonnegative=True) - (train_dataset, test_dataset, metadata) = ( + _, _, metadata = ( train_datasets.train, train_datasets.test, train_datasets.metadata, @@ -152,7 +152,7 @@ def test_predictor_working( bias=1, ) - (train_dataset, test_dataset, metadata) = ( + _, test_dataset, metadata = ( train_datasets.train, train_datasets.test, train_datasets.metadata, diff --git a/test/mx/model/deepvar_hierarchical/test_projection.py b/test/mx/model/deepvar_hierarchical/test_projection.py index 4e57e341b5..ba68a5076c 100644 --- a/test/mx/model/deepvar_hierarchical/test_projection.py +++ b/test/mx/model/deepvar_hierarchical/test_projection.py @@ -146,4 +146,4 @@ def test_projection_mat(D): ) def test_projection_mat_expected_fail(D): with pytest.raises(AssertionError): - p = projection_mat(S=S, D=D) + projection_mat(S=S, D=D) diff --git a/test/test_itertools.py b/test/test_itertools.py index b9071d9ad0..6c69e29bb9 100644 --- a/test/test_itertools.py +++ b/test/test_itertools.py @@ -289,7 +289,7 @@ def test_join_items(): ] with pytest.raises(Exception): - oin_items(left, right, "strict") + oin_items(left, right, "strict") # noqa: F821 @pytest.mark.parametrize( diff --git a/test/zebras/test_batch.py b/test/zebras/test_batch.py index bd00613ed8..689770d94b 100644 --- a/test/zebras/test_batch.py +++ b/test/zebras/test_batch.py @@ -11,7 +11,6 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. -import pytest import numpy as np