Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ci: update lints Ruff & docformatter #3130

Merged
merged 16 commits into from
May 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 0 additions & 24 deletions .github/workflows/flake8.yml

This file was deleted.

23 changes: 23 additions & 0 deletions .github/workflows/lints.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
name: Ruff & Docformat

on: [push, pull_request]

jobs:
lint:
runs-on: ubuntu-latest
strategy:
matrix:
check: ["ruff", "docformatter"]

steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
- name: Install tools
run: pip install "ruff==0.2.2" "docformatter[tomli]==1.5.0"
- name: Ruff (Flake8)
if: matrix.check == 'ruff'
working-directory: src/
run: ruff check .
- name: Docformatter
if: matrix.check == 'docformatter'
run: docformatter --check -r src/
2 changes: 1 addition & 1 deletion Justfile
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ release:
python setup.py sdist

black:
black --check --color src test examples
black --check --diff --color src test examples

mypy:
python setup.py type_check
Expand Down
2 changes: 1 addition & 1 deletion examples/persist_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# permissions and limitations under the License.

"""
This example shows how to serialize and deserialize a model
This example shows how to serialize and deserialize a model.
"""
import os
import pprint
Expand Down
8 changes: 3 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,15 @@ filterwarnings = "ignore"
line-length = 79

lint.ignore = [
# line-length is handled by black
"E501",

# TODO: remove usage of `l`
"E741"
"E501", # line-length is handled by black
"E741" # TODO: remove usage of `l`
]

exclude = ["src/gluonts/nursery"]


[tool.docformatter]
black = true
pre-summary-newline = true
make-summary-multi-line = true
wrap-descriptions = 79
Expand Down
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ def get_version_cmdclass(version_file) -> dict:


class TypeCheckCommand(distutils.cmd.Command):
"""A custom command to run MyPy on the project sources."""
"""
A custom command to run MyPy on the project sources.
"""

description = "run MyPy on Python source files"
user_options = []
Expand Down
8 changes: 4 additions & 4 deletions src/gluonts/dataset/artificial/recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1063,10 +1063,10 @@ def normalized_ar1(tau, x0=None, norm="minmax", sigma=1.0):
r"""
Returns an ar1 process with an auto correlation time of tau.

norm can be
None -> no normalization
'minmax' -> min_max_scaled
'standard' -> 0 mean, unit variance
norm can be:
- None -> no normalization
- 'minmax' -> min_max_scaled
- 'standard' -> 0 mean, unit variance
"""
assert norm in [None, "minmax", "standard"]
phi = lifted_numpy.exp(-1.0 / tau)
Expand Down
7 changes: 4 additions & 3 deletions src/gluonts/dataset/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,11 @@ def infer_file_type(path):


def _rglob(path: Path, pattern="*", levels=1):
"""Like ``path.rglob(pattern)`` except this limits the number of sub
directories that are traversed. ``levels = 0`` is thus the same as
``path.glob(pattern)``.
"""
Like ``path.rglob(pattern)`` except this limits the number of sub
directories that are traversed.

``levels = 0`` is thus the same as ``path.glob(pattern)``.
"""
if levels is not None:
levels -= 1
Expand Down
1 change: 1 addition & 0 deletions src/gluonts/dataset/jsonl.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ def __len__(self):
def _line_starts(self):
"""
Calculate the position for each line in the file.

This information can be used with ``file.seek`` to directly jump to a
specific line in the file.
"""
Expand Down
5 changes: 3 additions & 2 deletions src/gluonts/dataset/multivariate_grouper.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,9 @@ def _preprocess(self, dataset: Dataset) -> None:
The preprocess function iterates over the dataset to gather data that
is necessary for alignment.

This includes 1) Storing first/last timestamp in the dataset 2)
Storing the frequency of the dataset
This includes:
1. Storing first/last timestamp in the dataset
2. Storing the frequency of the dataset
"""
for data in dataset:
timestamp = data[FieldName.START]
Expand Down
11 changes: 8 additions & 3 deletions src/gluonts/dataset/schema/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ def fields(self):

@dataclass
class Get(Op):
"""Extracts the field ``name`` from the input."""
"""
Extracts the field ``name`` from the input.
"""

name: str

Expand All @@ -69,7 +71,9 @@ def fields(self):

@dataclass
class GetAttr(Op):
"""Invokes ``obj.name``"""
"""
Invokes ``obj.name``.
"""

obj: Op
name: str
Expand Down Expand Up @@ -298,7 +302,8 @@ def parse(x: Union[str, list]) -> Op:

@dataclass
class Translator:
"""Simple translation for GluonTS Datasets.
"""
Simple translation for GluonTS Datasets.

A given translator transforms an input dictionary (data-entry) into an
output dictionary.
Expand Down
4 changes: 2 additions & 2 deletions src/gluonts/dataset/split.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ def periods_between(
end: pd.Period,
) -> int:
"""
Count how many periods fit between ``start`` and ``end``
(inclusive). The frequency is taken from ``start``.
Count how many periods fit between ``start`` and ``end`` (inclusive). The
frequency is taken from ``start``.

For example:

Expand Down
6 changes: 4 additions & 2 deletions src/gluonts/ev/aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ def get(self) -> np.ndarray:

@dataclass
class Sum(Aggregation):
"""Map-reduce way of calculating the sum of a stream of values.
"""
Map-reduce way of calculating the sum of a stream of values.

`partial_result` represents one of two things, depending on the axis:
Case 1 - axis 0 is aggregated (axis is None or 0):
Expand Down Expand Up @@ -75,7 +76,8 @@ def get(self) -> np.ndarray:

@dataclass
class Mean(Aggregation):
"""Map-reduce way of calculating the mean of a stream of values.
"""
Map-reduce way of calculating the mean of a stream of values.

`partial_result` represents one of two things, depending on the axis:
Case 1 - axis 0 is aggregated (axis is None or 0):
Expand Down
67 changes: 49 additions & 18 deletions src/gluonts/ev/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,19 @@ class MetricCollection:
metrics: List[Metric]

def update(self, data: Mapping[str, np.ndarray]) -> Self:
"""Update metrics using a single data instance."""
"""
Update metrics using a single data instance.
"""

for metric in self.metrics:
metric.update(data)

return self

def update_all(self, stream: Iterator[Mapping[str, np.ndarray]]) -> Self:
"""Update metrics using a stream of data instances."""
"""
Update metrics using a stream of data instances.
"""

for element in stream:
self.update(element)
Expand All @@ -74,12 +78,16 @@ class Metric:
name: str

def update(self, data: Mapping[str, np.ndarray]) -> Self:
"""Update metric using a single data instance."""
"""
Update metric using a single data instance.
"""

raise NotImplementedError

def update_all(self, stream: Iterator[Mapping[str, np.ndarray]]) -> Self:
"""Update metric using a stream of data instances."""
"""
Update metric using a stream of data instances.
"""

for element in stream:
self.update(element)
Expand All @@ -92,7 +100,9 @@ def get(self) -> np.ndarray:

@dataclass
class DirectMetric(Metric):
"""A Metric which uses a single function and aggregation strategy."""
"""
A Metric which uses a single function and aggregation strategy.
"""

stat: Callable
aggregate: Aggregation
Expand All @@ -108,10 +118,11 @@ def get(self) -> np.ndarray:

@dataclass
class DerivedMetric(Metric):
"""A Metric that is computed using other metrics.
"""
A Metric that is computed using other metrics.

A derived metric updates multiple, simpler metrics independently and in
the end combines their results as defined in `post_process`.
A derived metric updates multiple, simpler metrics independently and in the
end combines their results as defined in `post_process`.
"""

metrics: Dict[str, Metric]
Expand Down Expand Up @@ -237,7 +248,9 @@ def __call__(self, axis: Optional[int] = None) -> DirectMetric:

@dataclass
class MAE(BaseMetricDefinition):
"""Mean Absolute Error"""
"""
Mean Absolute Error.
"""

forecast_type: str = "0.5"

Expand All @@ -254,7 +267,9 @@ def __call__(self, axis: Optional[int] = None) -> DirectMetric:

@dataclass
class MSE(BaseMetricDefinition):
"""Mean Squared Error"""
"""
Mean Squared Error.
"""

forecast_type: str = "mean"

Expand Down Expand Up @@ -295,7 +310,9 @@ def __call__(self, axis: Optional[int] = None) -> DirectMetric:

@dataclass
class MAPE(BaseMetricDefinition):
"""Mean Absolute Percentage Error"""
"""
Mean Absolute Percentage Error.
"""

forecast_type: str = "0.5"

Expand All @@ -314,7 +331,9 @@ def __call__(self, axis: Optional[int] = None) -> DirectMetric:

@dataclass
class SMAPE(BaseMetricDefinition):
"""Symmetric Mean Absolute Percentage Error"""
"""
Symmetric Mean Absolute Percentage Error.
"""

forecast_type: str = "0.5"

Expand All @@ -334,7 +353,9 @@ def __call__(self, axis: Optional[int] = None) -> DirectMetric:

@dataclass
class MSIS(BaseMetricDefinition):
"""Mean Scaled Interval Score"""
"""
Mean Scaled Interval Score.
"""

alpha: float = 0.05

Expand All @@ -351,7 +372,9 @@ def __call__(self, axis: Optional[int] = None) -> DirectMetric:

@dataclass
class MASE(BaseMetricDefinition):
"""Mean Absolute Scaled Error"""
"""
Mean Absolute Scaled Error.
"""

forecast_type: str = "0.5"

Expand Down Expand Up @@ -382,7 +405,9 @@ def __call__(self, axis: Optional[int] = None) -> DirectMetric:

@dataclass
class ND(BaseMetricDefinition):
"""Normalized Deviation"""
"""
Normalized Deviation.
"""

forecast_type: str = "0.5"

Expand Down Expand Up @@ -410,7 +435,9 @@ def __call__(self, axis: Optional[int] = None) -> DerivedMetric:

@dataclass
class RMSE(BaseMetricDefinition):
"""Root Mean Squared Error"""
"""
Root Mean Squared Error.
"""

forecast_type: str = "mean"

Expand All @@ -435,7 +462,9 @@ def __call__(self, axis: Optional[int] = None) -> DerivedMetric:

@dataclass
class NRMSE(BaseMetricDefinition):
"""RMSE, normalized by the mean absolute label"""
"""
RMSE, normalized by the mean absolute label.
"""

forecast_type: str = "mean"

Expand Down Expand Up @@ -582,7 +611,9 @@ def __call__(self, axis: Optional[int] = None) -> DerivedMetric:

@dataclass
class OWA(BaseMetricDefinition):
"""Overall Weighted Average"""
"""
Overall Weighted Average.
"""

forecast_type: str = "0.5"

Expand Down
Loading
Loading