Commit

Merge branch 'dev' into add-torch-mqcnn

lostella authored May 27, 2024
2 parents 5fac466 + b1e054a commit 4be80a8
Showing 198 changed files with 1,138 additions and 855 deletions.
24 changes: 0 additions & 24 deletions .github/workflows/flake8.yml

This file was deleted.

23 changes: 23 additions & 0 deletions .github/workflows/lints.yml
@@ -0,0 +1,23 @@
name: Ruff & Docformat

on: [push, pull_request]

jobs:
  lint:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        check: ["ruff", "docformatter"]

    steps:
      - uses: actions/checkout@v3
      - uses: actions/setup-python@v4
      - name: Install tools
        run: pip install "ruff==0.2.2" "docformatter[tomli]==1.5.0"
      - name: Ruff (Flake8)
        if: matrix.check == 'ruff'
        working-directory: src/
        run: ruff check .
      - name: Docformatter
        if: matrix.check == 'docformatter'
        run: docformatter --check -r src/
2 changes: 1 addition & 1 deletion Justfile
@@ -34,7 +34,7 @@ release:
python setup.py sdist

black:
- black --check --color src test examples
+ black --check --diff --color src test examples

mypy:
python setup.py type_check
2 changes: 1 addition & 1 deletion examples/persist_model.py
@@ -12,7 +12,7 @@
# permissions and limitations under the License.

"""
- This example shows how to serialize and deserialize a model
+ This example shows how to serialize and deserialize a model.
"""
import os
import pprint
8 changes: 3 additions & 5 deletions pyproject.toml
@@ -22,17 +22,15 @@ filterwarnings = "ignore"
line-length = 79

lint.ignore = [
-     # line-length is handled by black
-     "E501",
-
-     # TODO: remove usage of `l`
-     "E741"
+     "E501",  # line-length is handled by black
+     "E741"   # TODO: remove usage of `l`
]

exclude = ["src/gluonts/nursery"]


[tool.docformatter]
black = true
pre-summary-newline = true
make-summary-multi-line = true
wrap-descriptions = 79
2 changes: 1 addition & 1 deletion requirements/requirements-extras-anomaly-evaluation.txt
@@ -1,2 +1,2 @@
numba~=0.51,<0.54
- scikit-learn~=0.22
+ scikit-learn~=1.0
2 changes: 1 addition & 1 deletion requirements/requirements-extras-sagemaker-sdk.txt
@@ -1,4 +1,4 @@
- sagemaker~=2.0
+ sagemaker~=2.0,>2.214.3
s3fs~=0.6; python_version >= "3.7.0"
s3fs~=0.5; python_version < "3.7.0"
fsspec~=0.8,<0.9; python_version < "3.7.0"
2 changes: 1 addition & 1 deletion requirements/requirements-rotbaum.txt
@@ -1,2 +1,2 @@
xgboost>=0.90,<2
- scikit-learn>=0.22,<2
+ scikit-learn~=1.0
4 changes: 3 additions & 1 deletion setup.py
@@ -41,7 +41,9 @@ def get_version_cmdclass(version_file) -> dict:


class TypeCheckCommand(distutils.cmd.Command):
"""A custom command to run MyPy on the project sources."""
"""
A custom command to run MyPy on the project sources.
"""

description = "run MyPy on Python source files"
user_options = []
14 changes: 13 additions & 1 deletion src/gluonts/core/serde/_base.py
@@ -287,6 +287,15 @@ def encode_partial(v: partial) -> Any:
}


+ decode_disallow = [
+     eval,
+     exec,
+     compile,
+     open,
+     input,
+ ]


def decode(r: Any) -> Any:
"""
Decodes a value from an intermediate representation `r`.
@@ -312,7 +321,10 @@ def decode(r: Any) -> Any:
kind = r["__kind__"]
cls = cast(Any, locate(r["class"]))

- assert cls is not None, f"Can not locate {r['class']}."
+ if cls is None:
+     raise ValueError(f"Cannot locate {r['class']}.")
+ if cls in decode_disallow:
+     raise ValueError(f"{r['class']} cannot be run.")

if kind == Kind.Type:
return cls
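For context, a standalone sketch of the guard this hunk introduces: resolve a class from its dotted path and refuse to decode disallowed builtins such as eval. The helper name below is illustrative; only the checks mirror the diff.

from pydoc import locate
from typing import Any

decode_disallow = [eval, exec, compile, open, input]

def locate_checked(class_path: str) -> Any:
    # Resolve the dotted path, then apply the same two checks as the diff:
    # fail if the class cannot be found, fail if it is a disallowed callable.
    cls = locate(class_path)
    if cls is None:
        raise ValueError(f"Cannot locate {class_path}.")
    if cls in decode_disallow:
        raise ValueError(f"{class_path} cannot be run.")
    return cls

# locate_checked("builtins.eval")  # raises ValueError: builtins.eval cannot be run.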
8 changes: 4 additions & 4 deletions src/gluonts/dataset/artificial/recipe.py
@@ -1063,10 +1063,10 @@ def normalized_ar1(tau, x0=None, norm="minmax", sigma=1.0):
r"""
Returns an ar1 process with an auto correlation time of tau.
- norm can be
- None -> no normalization
- 'minmax' -> min_max_scaled
- 'standard' -> 0 mean, unit variance
+ norm can be:
+ - None -> no normalization
+ - 'minmax' -> min_max_scaled
+ - 'standard' -> 0 mean, unit variance
"""
assert norm in [None, "minmax", "standard"]
phi = lifted_numpy.exp(-1.0 / tau)
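As a rough illustration of the docstring above, a small NumPy sketch of an AR(1) series with autocorrelation time tau and the three norm modes. It approximates the recipe for illustration and is not the GluonTS implementation.

import numpy as np

def ar1_sketch(tau, length=200, sigma=1.0, norm="minmax", seed=0):
    rng = np.random.default_rng(seed)
    phi = np.exp(-1.0 / tau)  # AR(1) coefficient giving autocorrelation time tau
    x = np.zeros(length)
    for t in range(1, length):
        x[t] = phi * x[t - 1] + sigma * rng.standard_normal()
    if norm is None:
        return x                                    # no normalization
    if norm == "minmax":
        return (x - x.min()) / (x.max() - x.min())  # scaled to [0, 1]
    if norm == "standard":
        return (x - x.mean()) / x.std()             # zero mean, unit variance
    raise ValueError(f"unknown norm: {norm}")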
7 changes: 4 additions & 3 deletions src/gluonts/dataset/common.py
@@ -128,10 +128,11 @@ def infer_file_type(path):


def _rglob(path: Path, pattern="*", levels=1):
"""Like ``path.rglob(pattern)`` except this limits the number of sub
directories that are traversed. ``levels = 0`` is thus the same as
``path.glob(pattern)``.
"""
Like ``path.rglob(pattern)`` except this limits the number of sub
directories that are traversed.
``levels = 0`` is thus the same as ``path.glob(pattern)``.
"""
if levels is not None:
levels -= 1
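A minimal standalone sketch of the behaviour this docstring describes, namely a recursive glob whose depth is capped by levels. The helper name is illustrative, not the GluonTS function.

from pathlib import Path
from typing import Iterator, Optional

def rglob_limited(path: Path, pattern: str = "*", levels: Optional[int] = 1) -> Iterator[Path]:
    # levels = 0 behaves like path.glob(pattern); each extra level descends one
    # directory deeper; levels = None recurses without limit.
    yield from path.glob(pattern)
    if levels is None or levels > 0:
        for child in path.iterdir():
            if child.is_dir():
                yield from rglob_limited(child, pattern, None if levels is None else levels - 1)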
1 change: 1 addition & 0 deletions src/gluonts/dataset/jsonl.py
@@ -146,6 +146,7 @@ def __len__(self):
def _line_starts(self):
"""
Calculate the position for each line in the file.
This information can be used with ``file.seek`` to directly jump to a
specific line in the file.
"""
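The idea behind this docstring, as a short sketch: record the byte offset of every line once, then seek straight to any line later. The file name and helper name are illustrative.

def line_starts(path):
    # One pass over the file, recording where each line begins in bytes.
    starts, offset = [], 0
    with open(path, "rb") as f:
        for line in f:
            starts.append(offset)
            offset += len(line)
    return starts

# starts = line_starts("data.jsonl")
# with open("data.jsonl", "rb") as f:
#     f.seek(starts[2])          # jump directly to the third line
#     third = f.readline()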
5 changes: 3 additions & 2 deletions src/gluonts/dataset/multivariate_grouper.py
@@ -92,8 +92,9 @@ def _preprocess(self, dataset: Dataset) -> None:
The preprocess function iterates over the dataset to gather data that
is necessary for alignment.
- This includes 1) Storing first/last timestamp in the dataset 2)
- Storing the frequency of the dataset
+ This includes:
+ 1. Storing first/last timestamp in the dataset
+ 2. Storing the frequency of the dataset
"""
for data in dataset:
timestamp = data[FieldName.START]
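A hedged sketch of the bookkeeping described above, assuming a non-empty dataset whose entries carry a pandas.Period start and a target array (field names follow the usual GluonTS convention; this is not the grouper itself).

import pandas as pd

def gather_alignment_info(dataset):
    # Assumes dataset is non-empty and every entry["start"] is a pandas.Period
    # with the same frequency.
    first, last = None, None
    for entry in dataset:
        start = entry["start"]                      # pandas.Period
        end = start + len(entry["target"]) - 1      # last timestamp of this series
        first = start if first is None else min(first, start)
        last = end if last is None else max(last, end)
    return first, last, first.freq                  # 1. first/last timestamp, 2. frequency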
11 changes: 8 additions & 3 deletions src/gluonts/dataset/schema/translate.py
@@ -44,7 +44,9 @@ def fields(self):

@dataclass
class Get(Op):
"""Extracts the field ``name`` from the input."""
"""
Extracts the field ``name`` from the input.
"""

name: str

@@ -69,7 +71,9 @@ def fields(self):

@dataclass
class GetAttr(Op):
"""Invokes ``obj.name``"""
"""
Invokes ``obj.name``.
"""

obj: Op
name: str
@@ -298,7 +302,8 @@ def parse(x: Union[str, list]) -> Op:

@dataclass
class Translator:
"""Simple translation for GluonTS Datasets.
"""
Simple translation for GluonTS Datasets.
A given translator transforms an input dictionary (data-entry) into an
output dictionary.
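To make the ops above concrete, a toy version with the semantics given in their docstrings: Get pulls a field from the input dictionary and GetAttr reads an attribute of another op's result. This is a sketch, not the GluonTS classes.

from dataclasses import dataclass

@dataclass
class Get:
    name: str

    def __call__(self, item: dict):
        return item[self.name]     # extract the field ``name`` from the input

@dataclass
class GetAttr:
    obj: object                    # another op
    name: str

    def __call__(self, item: dict):
        return getattr(self.obj(item), self.name)   # invoke ``obj.name``

# import numpy as np
# entry = {"target": np.arange(3)}
# GetAttr(Get("target"), "dtype")(entry)   # -> dtype('int64')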
4 changes: 2 additions & 2 deletions src/gluonts/dataset/split.py
@@ -86,8 +86,8 @@ def periods_between(
end: pd.Period,
) -> int:
"""
- Count how many periods fit between ``start`` and ``end``
- (inclusive). The frequency is taken from ``start``.
+ Count how many periods fit between ``start`` and ``end`` (inclusive). The
+ frequency is taken from ``start``.
For example:
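As an illustration of the inclusive count, using pandas Periods directly rather than the GluonTS helper:

import pandas as pd

start = pd.Period("2021-01-01", freq="D")
end = pd.Period("2021-01-10", freq="D")

# Subtracting two Periods of the same frequency gives an offset whose ``.n``
# is the number of steps between them; adding 1 makes the count inclusive.
n_periods = (end - start).n + 1   # -> 10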
6 changes: 4 additions & 2 deletions src/gluonts/ev/aggregations.py
@@ -34,7 +34,8 @@ def get(self) -> np.ndarray:

@dataclass
class Sum(Aggregation):
"""Map-reduce way of calculating the sum of a stream of values.
"""
Map-reduce way of calculating the sum of a stream of values.
`partial_result` represents one of two things, depending on the axis:
Case 1 - axis 0 is aggregated (axis is None or 0):
@@ -75,7 +76,8 @@ def get(self) -> np.ndarray:

@dataclass
class Mean(Aggregation):
"""Map-reduce way of calculating the mean of a stream of values.
"""
Map-reduce way of calculating the mean of a stream of values.
`partial_result` represents one of two things, depending on the axis:
Case 1 - axis 0 is aggregated (axis is None or 0):
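A compact sketch of the map-reduce pattern these docstrings describe, covering only the Case 1 situation (axis is None or 0). It mirrors the idea, not the exact GluonTS classes.

from dataclasses import dataclass
from typing import Any, Optional

import numpy as np

@dataclass
class StreamingMean:
    axis: Optional[int] = None   # None or 0 only, as in Case 1 above
    partial_sum: Any = 0.0
    n: int = 0

    def step(self, values: np.ndarray) -> None:
        # map: fold each incoming batch into the running partial result
        self.partial_sum = self.partial_sum + values.sum(axis=self.axis)
        self.n += values.size if self.axis is None else values.shape[0]

    def get(self) -> np.ndarray:
        # reduce: only now materialize the mean of everything seen so far
        return np.asarray(self.partial_sum) / self.n

# agg = StreamingMean(axis=0)
# for batch in (np.ones((2, 3)), 2 * np.ones((4, 3))):
#     agg.step(batch)
# agg.get()   # -> array([1.6667, 1.6667, 1.6667])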
0 comments on commit 4be80a8