lint: freeze & run Black version 24.02 #3131

Merged: 13 commits, Apr 5, 2024
Changes from 8 commits
22 changes: 8 additions & 14 deletions .github/workflows/style_type_checks.yml

@@ -9,22 +9,16 @@ jobs:
     steps:
       - uses: actions/checkout@v3
       - uses: extractions/setup-just@v1
-        env:
-          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-      - name: Set up Python 3.8
-        uses: actions/setup-python@v4
-        with:
-          python-version: '3.8'
+      - uses: actions/setup-python@v4
       - name: Install dependencies
         run: |
           pip install .
-          pip install click black mypy
-          pip install types-python-dateutil
-          pip install types-waitress
-          pip install types-PyYAML
-      - name: Style and type checks
-        run: |
-          just black
-          just mypy
+          # todo: install also `black[jupyter]`
+          pip install click "black==24.01" "mypy==1.8.0" \
+            types-python-dateutil types-waitress types-PyYAML
+      - name: Style check
+        run: just black
+      - name: Type check
+        run: just mypy
       - name: Check license headers
         run: just license
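As a side note for contributors: the two new CI steps can be mirrored locally before pushing. Below is a minimal sketch; the script and its name are our illustration, not part of this PR, and it assumes the same pins are installed (black==24.01, mypy==1.8.0).

# check_style.py: hypothetical local mirror of the CI steps above.
# Assumes the PR's pinned tools are installed: black==24.01, mypy==1.8.0.
import subprocess
import sys
from typing import List


def run(cmd: List[str]) -> int:
    """Echo and run one check, returning its exit code."""
    print("$", " ".join(cmd))
    return subprocess.run(cmd).returncode


if __name__ == "__main__":
    checks = [
        # same invocation as `just black` after this PR
        [sys.executable, "-m", "black", "--check", "--color", "--preview",
         "src", "test", "examples"],
        # `just mypy` delegates to setup.py's type_check command
        [sys.executable, "setup.py", "type_check"],
    ]
    sys.exit(sum(run(cmd) != 0 for cmd in checks))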
2 changes: 1 addition & 1 deletion Justfile

@@ -34,7 +34,7 @@ release:
     python setup.py sdist

 black:
-    black --check src test examples
+    black --check --color --preview src test examples

 mypy:
     python setup.py type_check
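Every Python diff below follows from the new `--preview` flag: Black 24's preview style "hugs" a sole collection-literal or comprehension argument against the enclosing brackets instead of giving each bracket its own line (to our understanding, the relevant preview feature is hug_parens_with_braces_and_square_brackets). A runnable illustration with hypothetical names, not code from this PR:

# Illustration of Black 24's preview bracket hugging (names are ours).
target, start = [1.0, 2.0], "2024-01-01"
train_data = []

# Stable style, as Black emitted before this PR:
train_data.append(
    {
        "target": target,
        "start": start,
    }
)

# Preview "hugging" style, as `black --check --preview` accepts after this PR:
train_data.append({
    "target": target,
    "start": start,
})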
1 change: 1 addition & 0 deletions examples/anomaly_detection.py

@@ -14,6 +14,7 @@
 """
 This example shows how to do anomaly detection with DeepAR.
 The model is first trained and then time-points with the largest negative log-likelihood are plotted.
 """
+
 import numpy as np
 from itertools import islice
22 changes: 10 additions & 12 deletions examples/benchmark_m4.py

@@ -95,17 +95,15 @@ def evaluate(dataset_name, estimator):

     df = pd.DataFrame(results)

-    sub_df = df[
-        [
-            "dataset",
-            "estimator",
-            "RMSE",
-            "mean_wQuantileLoss",
-            "MASE",
-            "sMAPE",
-            "OWA",
-            "MSIS",
-        ]
-    ]
+    sub_df = df[[
+        "dataset",
+        "estimator",
+        "RMSE",
+        "mean_wQuantileLoss",
+        "MASE",
+        "sMAPE",
+        "OWA",
+        "MSIS",
+    ]]

     print(sub_df.to_string())
1 change: 1 addition & 0 deletions examples/warm_start.py

@@ -13,6 +13,7 @@

 """
 This example show how to intialize the network with parameters from a model that was previously trained.
 """
+
 from gluonts.dataset.repository import get_dataset, dataset_recipes
1 change: 1 addition & 0 deletions pyproject.toml

@@ -1,4 +1,5 @@
 [tool.black]
+target-version = ['py38']
 line-length = 79

 [tool.pytest.ini_options]
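For reference, the same configuration can be exercised through Black's Python API. A minimal sketch, our own illustration assuming a Black 24.x install; `black.Mode` here mirrors the target-version and line-length from this file plus the Justfile's `--preview` flag:

import black

# Mirror pyproject.toml (target-version = ['py38'], line-length = 79)
# and the Justfile's --preview flag.
mode = black.Mode(
    target_versions={black.TargetVersion.PY38},
    line_length=79,
    preview=True,
)

# Feed in deliberately messy source and print the formatted result.
src = "xs = cls( [ x for x in range( 3 ) ] )\n"
print(black.format_str(src, mode=mode), end="")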
12 changes: 5 additions & 7 deletions src/gluonts/core/component.py

@@ -355,13 +355,11 @@ def init_wrapper(*args, **kwargs):
        # __init_args__ is not already set in order to avoid overriding a
        # value set by a subclass initializer in super().__init__ calls
        if not getattr(self, "__init_args__", {}):
-            self.__init_args__ = OrderedDict(
-                {
-                    name: arg
-                    for name, arg in sorted(all_args.items())
-                    if not skip_encoding(arg)
-                }
-            )
+            self.__init_args__ = OrderedDict({
+                name: arg
+                for name, arg in sorted(all_args.items())
+                if not skip_encoding(arg)
+            })
        self.__class__.__getnewargs_ex__ = validated_getnewargs_ex
        self.__class__.__repr__ = validated_repr

12 changes: 5 additions & 7 deletions src/gluonts/dataset/arrow/dec.py

@@ -23,13 +23,11 @@ class ArrowDecoder:

     @classmethod
     def from_schema(cls, schema):
-        return cls(
-            [
-                (column.name[: -len("._np_shape")], column.name)
-                for column in schema
-                if column.name.endswith("._np_shape")
-            ]
-        )
+        return cls([
+            (column.name[: -len("._np_shape")], column.name)
+            for column in schema
+            if column.name.endswith("._np_shape")
+        ])

     def decode(self, batch, row_number: int):
         return next(self.decode_batch(batch.slice(row_number, row_number + 1)))
10 changes: 4 additions & 6 deletions src/gluonts/dataset/arrow/file.py

@@ -229,12 +229,10 @@ def __post_init__(self):
         self.decoder = ArrowDecoder.from_schema(self.reader.schema_arrow)

         if not self._row_group_sizes:
-            self._row_group_sizes = np.cumsum(
-                [
-                    self.reader.metadata.row_group(row_group).num_rows
-                    for row_group in range(self.reader.metadata.num_row_groups)
-                ]
-            )
+            self._row_group_sizes = np.cumsum([
+                self.reader.metadata.row_group(row_group).num_rows
+                for row_group in range(self.reader.metadata.num_row_groups)
+            ])

     def location_for(self, idx):
         if idx == 0:
10 changes: 4 additions & 6 deletions src/gluonts/dataset/artificial/recipe.py

@@ -714,12 +714,10 @@ def __call__(self, x, field_name, global_state, **kwargs):
             probs = [self.prob_fun(x, length=c) for c in self.cardinalities]
             global_state[field_name] = probs
         probs = global_state[field_name]
-        cats = np.array(
-            [
-                np.random.choice(np.arange(len(probs[i])), p=probs[i])
-                for i in range(len(probs))
-            ]
-        )
+        cats = np.array([
+            np.random.choice(np.arange(len(probs[i])), p=probs[i])
+            for i in range(len(probs))
+        ])
         return cats


20 changes: 9 additions & 11 deletions src/gluonts/dataset/pandas.py

@@ -221,17 +221,15 @@ def __len__(self) -> int:
         return len(self._data_entries)

     def __repr__(self) -> str:
-        info = ", ".join(
-            [
-                f"size={len(self)}",
-                f"freq={self.freq}",
-                f"num_feat_dynamic_real={self.num_feat_dynamic_real}",
-                f"num_past_feat_dynamic_real={self.num_past_feat_dynamic_real}",
-                f"num_feat_static_real={self.num_feat_static_real}",
-                f"num_feat_static_cat={self.num_feat_static_cat}",
-                f"static_cardinalities={self.static_cardinalities}",
-            ]
-        )
+        info = ", ".join([
+            f"size={len(self)}",
+            f"freq={self.freq}",
+            f"num_feat_dynamic_real={self.num_feat_dynamic_real}",
+            f"num_past_feat_dynamic_real={self.num_past_feat_dynamic_real}",
+            f"num_feat_static_real={self.num_feat_static_real}",
+            f"num_feat_static_cat={self.num_feat_static_cat}",
+            f"static_cardinalities={self.static_cardinalities}",
+        ])
         return f"PandasDataset<{info}>"

     @classmethod
28 changes: 12 additions & 16 deletions src/gluonts/dataset/repository/_lstnet.py

@@ -161,14 +161,12 @@ def generate_lstnet_dataset(
     for cat, ts in enumerate(timeseries):
         sliced_ts = ts[:training_end]
         if len(sliced_ts) > 0:
-            train_ts.append(
-                {
-                    "target": sliced_ts.values,
-                    "start": sliced_ts.index[0],
-                    "feat_static_cat": [cat],
-                    "item_id": cat,
-                }
-            )
+            train_ts.append({
+                "target": sliced_ts.values,
+                "start": sliced_ts.index[0],
+                "feat_static_cat": [cat],
+                "item_id": cat,
+            })

     assert len(train_ts) == ds_info.num_series

@@ -186,14 +184,12 @@
                 prediction_start_date + ds_info.prediction_length
             )
             sliced_ts = ts[:prediction_end_date]
-            test_ts.append(
-                {
-                    "target": sliced_ts.values,
-                    "start": sliced_ts.index[0],
-                    "feat_static_cat": [cat],
-                    "item_id": cat,
-                }
-            )
+            test_ts.append({
+                "target": sliced_ts.values,
+                "start": sliced_ts.index[0],
+                "feat_static_cat": [cat],
+                "item_id": cat,
+            })

     assert len(test_ts) == ds_info.num_series * ds_info.rolling_evaluations

30 changes: 13 additions & 17 deletions src/gluonts/dataset/repository/_m3.py

@@ -163,23 +163,19 @@ def normalize_category(c: str):
         start = str(pd.Period(time_stamp, freq=subset.freq))
         cat = [i, cat_map[category]]

-        train_data.append(
-            {
-                "target": target[: -subset.prediction_length],
-                "start": start,
-                "feat_static_cat": cat,
-                "item_id": series,
-            }
-        )
-
-        test_data.append(
-            {
-                "target": target,
-                "start": start,
-                "feat_static_cat": cat,
-                "item_id": series,
-            }
-        )
+        train_data.append({
+            "target": target[: -subset.prediction_length],
+            "start": start,
+            "feat_static_cat": cat,
+            "item_id": series,
+        })
+
+        test_data.append({
+            "target": target,
+            "start": start,
+            "feat_static_cat": cat,
+            "item_id": series,
+        })

     meta = MetaData(
         **metadata(
36 changes: 16 additions & 20 deletions src/gluonts/dataset/repository/_tsf_datasets.py

@@ -201,27 +201,23 @@ def convert_data(
         #   timestamps
         # - `item_id` is added for all datasets ... many datasets provide
         #   the "series_name"
-        test_data.append(
-            {
-                "target": data_entry["target"],
-                "start": str(
-                    data_entry.get("start_timestamp", default_start_timestamp)
-                ),
-                "item_id": data_entry.get("series_name", i),
-                "feat_static_cat": [i],
-            }
-        )
+        test_data.append({
+            "target": data_entry["target"],
+            "start": str(
+                data_entry.get("start_timestamp", default_start_timestamp)
+            ),
+            "item_id": data_entry.get("series_name", i),
+            "feat_static_cat": [i],
+        })

-        train_data.append(
-            {
-                "target": data_entry["target"][:-train_offset],
-                "start": str(
-                    data_entry.get("start_timestamp", default_start_timestamp)
-                ),
-                "item_id": data_entry.get("series_name", i),
-                "feat_static_cat": [i],
-            }
-        )
+        train_data.append({
+            "target": data_entry["target"][:-train_offset],
+            "start": str(
+                data_entry.get("start_timestamp", default_start_timestamp)
+            ),
+            "item_id": data_entry.get("series_name", i),
+            "feat_static_cat": [i],
+        })

     return train_data, test_data

14 changes: 6 additions & 8 deletions src/gluonts/dataset/schema/translate.py

@@ -141,14 +141,12 @@ class TokenStream:

     @classmethod
     def from_str(cls, s):
-        stream = cls(
-            [
-                Token(name, value, match)
-                for match in re.finditer(cls.RX, s)
-                for name, value in valfilter(bool, match.groupdict()).items()
-                if name != "WHITESPACE"
-            ]
-        )
+        stream = cls([
+            Token(name, value, match)
+            for match in re.finditer(cls.RX, s)
+            for name, value in valfilter(bool, match.groupdict()).items()
+            if name != "WHITESPACE"
+        ])

         for token in stream:
             if token.name == "INVALID":
9 changes: 3 additions & 6 deletions src/gluonts/ev/metrics.py

@@ -124,12 +124,9 @@ def update(self, data: Mapping[str, np.ndarray]) -> Self:
         return self

     def get(self) -> np.ndarray:
-        return self.post_process(
-            **{
-                name: evaluator.get()
-                for name, evaluator in self.metrics.items()
-            }
-        )
+        return self.post_process(**{
+            name: evaluator.get() for name, evaluator in self.metrics.items()
+        })


 @runtime_checkable