Commit 98690f0

no preview
Borda committed Feb 22, 2024
1 parent 09af6c8 commit 98690f0
Showing 151 changed files with 10,202 additions and 9,330 deletions.
2 changes: 1 addition & 1 deletion Justfile
@@ -34,7 +34,7 @@ release:
python setup.py sdist

black:
-black --check --color --preview src test examples
+black --check --color src test examples

mypy:
python setup.py type_check
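The Justfile change above is the trigger for everything else in this commit: dropping `--preview` switches the `black` check to the stable style, and the remaining hunks on this page appear to be the mechanical re-format that follows. Below is a deliberately tiny Python sketch of the two layouts; the names (`raw_items`, `hugged`, `stable`) are made up for illustration, and a call this short would really be kept on one line by black, so only the bracket placement matters here. The first form shows the bracket-hugging layout associated with the preview style, the second the stable layout the code base is converted to.

raw_items = [" a ", "", " b "]

# Bracket-hugging layout (the style being removed): the opening
# parenthesis and bracket sit on the same line.
hugged = sorted([
    item.strip()
    for item in raw_items
    if item.strip()
])

# Stable black layout (what the rest of the commit converts to): each
# opening bracket gets its own line and indentation level.
stable = sorted(
    [
        item.strip()
        for item in raw_items
        if item.strip()
    ]
)

# Both spellings are equivalent at runtime; only the layout differs.
assert hugged == stable == ["a", "b"]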
22 changes: 12 additions & 10 deletions examples/benchmark_m4.py
@@ -95,15 +95,17 @@ def evaluate(dataset_name, estimator):

df = pd.DataFrame(results)

-sub_df = df[[
-    "dataset",
-    "estimator",
-    "RMSE",
-    "mean_wQuantileLoss",
-    "MASE",
-    "sMAPE",
-    "OWA",
-    "MSIS",
-]]
+sub_df = df[
+    [
+        "dataset",
+        "estimator",
+        "RMSE",
+        "mean_wQuantileLoss",
+        "MASE",
+        "sMAPE",
+        "OWA",
+        "MSIS",
+    ]
+]

print(sub_df.to_string())
12 changes: 7 additions & 5 deletions src/gluonts/core/component.py
@@ -355,11 +355,13 @@ def init_wrapper(*args, **kwargs):
# __init_args__ is not already set in order to avoid overriding a
# value set by a subclass initializer in super().__init__ calls
if not getattr(self, "__init_args__", {}):
-self.__init_args__ = OrderedDict({
-    name: arg
-    for name, arg in sorted(all_args.items())
-    if not skip_encoding(arg)
-})
+self.__init_args__ = OrderedDict(
+    {
+        name: arg
+        for name, arg in sorted(all_args.items())
+        if not skip_encoding(arg)
+    }
+)
self.__class__.__getnewargs_ex__ = validated_getnewargs_ex
self.__class__.__repr__ = validated_repr

Expand Down
12 changes: 7 additions & 5 deletions src/gluonts/dataset/arrow/dec.py
@@ -23,11 +23,13 @@ class ArrowDecoder:

@classmethod
def from_schema(cls, schema):
-return cls([
-    (column.name[: -len("._np_shape")], column.name)
-    for column in schema
-    if column.name.endswith("._np_shape")
-])
+return cls(
+    [
+        (column.name[: -len("._np_shape")], column.name)
+        for column in schema
+        if column.name.endswith("._np_shape")
+    ]
+)

def decode(self, batch, row_number: int):
return next(self.decode_batch(batch.slice(row_number, row_number + 1)))
10 changes: 6 additions & 4 deletions src/gluonts/dataset/arrow/file.py
@@ -229,10 +229,12 @@ def __post_init__(self):
self.decoder = ArrowDecoder.from_schema(self.reader.schema_arrow)

if not self._row_group_sizes:
-self._row_group_sizes = np.cumsum([
-    self.reader.metadata.row_group(row_group).num_rows
-    for row_group in range(self.reader.metadata.num_row_groups)
-])
+self._row_group_sizes = np.cumsum(
+    [
+        self.reader.metadata.row_group(row_group).num_rows
+        for row_group in range(self.reader.metadata.num_row_groups)
+    ]
+)

def location_for(self, idx):
if idx == 0:
10 changes: 6 additions & 4 deletions src/gluonts/dataset/artificial/recipe.py
@@ -714,10 +714,12 @@ def __call__(self, x, field_name, global_state, **kwargs):
probs = [self.prob_fun(x, length=c) for c in self.cardinalities]
global_state[field_name] = probs
probs = global_state[field_name]
-cats = np.array([
-    np.random.choice(np.arange(len(probs[i])), p=probs[i])
-    for i in range(len(probs))
-])
+cats = np.array(
+    [
+        np.random.choice(np.arange(len(probs[i])), p=probs[i])
+        for i in range(len(probs))
+    ]
+)
return cats


20 changes: 11 additions & 9 deletions src/gluonts/dataset/pandas.py
@@ -221,15 +221,17 @@ def __len__(self) -> int:
return len(self._data_entries)

def __repr__(self) -> str:
-info = ", ".join([
-    f"size={len(self)}",
-    f"freq={self.freq}",
-    f"num_feat_dynamic_real={self.num_feat_dynamic_real}",
-    f"num_past_feat_dynamic_real={self.num_past_feat_dynamic_real}",
-    f"num_feat_static_real={self.num_feat_static_real}",
-    f"num_feat_static_cat={self.num_feat_static_cat}",
-    f"static_cardinalities={self.static_cardinalities}",
-])
+info = ", ".join(
+    [
+        f"size={len(self)}",
+        f"freq={self.freq}",
+        f"num_feat_dynamic_real={self.num_feat_dynamic_real}",
+        f"num_past_feat_dynamic_real={self.num_past_feat_dynamic_real}",
+        f"num_feat_static_real={self.num_feat_static_real}",
+        f"num_feat_static_cat={self.num_feat_static_cat}",
+        f"static_cardinalities={self.static_cardinalities}",
+    ]
+)
return f"PandasDataset<{info}>"

@classmethod
28 changes: 16 additions & 12 deletions src/gluonts/dataset/repository/_lstnet.py
@@ -161,12 +161,14 @@ def generate_lstnet_dataset(
for cat, ts in enumerate(timeseries):
sliced_ts = ts[:training_end]
if len(sliced_ts) > 0:
-train_ts.append({
-    "target": sliced_ts.values,
-    "start": sliced_ts.index[0],
-    "feat_static_cat": [cat],
-    "item_id": cat,
-})
+train_ts.append(
+    {
+        "target": sliced_ts.values,
+        "start": sliced_ts.index[0],
+        "feat_static_cat": [cat],
+        "item_id": cat,
+    }
+)

assert len(train_ts) == ds_info.num_series

@@ -184,12 +186,14 @@ def generate_lstnet_dataset(
prediction_start_date + ds_info.prediction_length
)
sliced_ts = ts[:prediction_end_date]
-test_ts.append({
-    "target": sliced_ts.values,
-    "start": sliced_ts.index[0],
-    "feat_static_cat": [cat],
-    "item_id": cat,
-})
+test_ts.append(
+    {
+        "target": sliced_ts.values,
+        "start": sliced_ts.index[0],
+        "feat_static_cat": [cat],
+        "item_id": cat,
+    }
+)

assert len(test_ts) == ds_info.num_series * ds_info.rolling_evaluations

30 changes: 17 additions & 13 deletions src/gluonts/dataset/repository/_m3.py
@@ -163,19 +163,23 @@ def normalize_category(c: str):
start = str(pd.Period(time_stamp, freq=subset.freq))
cat = [i, cat_map[category]]

-train_data.append({
-    "target": target[: -subset.prediction_length],
-    "start": start,
-    "feat_static_cat": cat,
-    "item_id": series,
-})
-
-test_data.append({
-    "target": target,
-    "start": start,
-    "feat_static_cat": cat,
-    "item_id": series,
-})
+train_data.append(
+    {
+        "target": target[: -subset.prediction_length],
+        "start": start,
+        "feat_static_cat": cat,
+        "item_id": series,
+    }
+)
+
+test_data.append(
+    {
+        "target": target,
+        "start": start,
+        "feat_static_cat": cat,
+        "item_id": series,
+    }
+)

meta = MetaData(
**metadata(
36 changes: 20 additions & 16 deletions src/gluonts/dataset/repository/_tsf_datasets.py
@@ -201,23 +201,27 @@ def convert_data(
# timestamps
# - `item_id` is added for all datasets ... many datasets provide
# the "series_name"
-test_data.append({
-    "target": data_entry["target"],
-    "start": str(
-        data_entry.get("start_timestamp", default_start_timestamp)
-    ),
-    "item_id": data_entry.get("series_name", i),
-    "feat_static_cat": [i],
-})
+test_data.append(
+    {
+        "target": data_entry["target"],
+        "start": str(
+            data_entry.get("start_timestamp", default_start_timestamp)
+        ),
+        "item_id": data_entry.get("series_name", i),
+        "feat_static_cat": [i],
+    }
+)

-train_data.append({
-    "target": data_entry["target"][:-train_offset],
-    "start": str(
-        data_entry.get("start_timestamp", default_start_timestamp)
-    ),
-    "item_id": data_entry.get("series_name", i),
-    "feat_static_cat": [i],
-})
+train_data.append(
+    {
+        "target": data_entry["target"][:-train_offset],
+        "start": str(
+            data_entry.get("start_timestamp", default_start_timestamp)
+        ),
+        "item_id": data_entry.get("series_name", i),
+        "feat_static_cat": [i],
+    }
+)

return train_data, test_data

14 changes: 8 additions & 6 deletions src/gluonts/dataset/schema/translate.py
@@ -141,12 +141,14 @@ class TokenStream:

@classmethod
def from_str(cls, s):
-stream = cls([
-    Token(name, value, match)
-    for match in re.finditer(cls.RX, s)
-    for name, value in valfilter(bool, match.groupdict()).items()
-    if name != "WHITESPACE"
-])
+stream = cls(
+    [
+        Token(name, value, match)
+        for match in re.finditer(cls.RX, s)
+        for name, value in valfilter(bool, match.groupdict()).items()
+        if name != "WHITESPACE"
+    ]
+)

for token in stream:
if token.name == "INVALID":
9 changes: 6 additions & 3 deletions src/gluonts/ev/metrics.py
@@ -124,9 +124,12 @@ def update(self, data: Mapping[str, np.ndarray]) -> Self:
return self

def get(self) -> np.ndarray:
-return self.post_process(**{
-    name: evaluator.get() for name, evaluator in self.metrics.items()
-})
+return self.post_process(
+    **{
+        name: evaluator.get()
+        for name, evaluator in self.metrics.items()
+    }
+)


@runtime_checkable
44 changes: 27 additions & 17 deletions src/gluonts/evaluation/_base.py
@@ -293,11 +293,13 @@ def __call__(
# Thus we set dtype=np.float64 to convert masked values back to NaNs
# which are handled correctly by pandas Dataframes during
# aggregation.
-metrics_per_ts = metrics_per_ts.astype({
-    col: np.float64
-    for col in metrics_per_ts.columns
-    if col not in ["item_id", "forecast_start"]
-})
+metrics_per_ts = metrics_per_ts.astype(
+    {
+        col: np.float64
+        for col in metrics_per_ts.columns
+        if col not in ["item_id", "forecast_start"]
+    }
+)

return self.get_aggregate_metrics(metrics_per_ts)

@@ -534,18 +536,26 @@ def get_aggregate_metrics(
totals[f"QuantileLoss[{quantile}]"] / totals["abs_target_sum"]
)

-totals["mean_absolute_QuantileLoss"] = np.array([
-    totals[f"QuantileLoss[{quantile}]"] for quantile in self.quantiles
-]).mean()
-
-totals["mean_wQuantileLoss"] = np.array([
-    totals[f"wQuantileLoss[{quantile}]"] for quantile in self.quantiles
-]).mean()
-
-totals["MAE_Coverage"] = np.mean([
-    np.abs(totals[f"Coverage[{quantile}]"] - np.array([q.value]))
-    for q in self.quantiles
-])
+totals["mean_absolute_QuantileLoss"] = np.array(
+    [
+        totals[f"QuantileLoss[{quantile}]"]
+        for quantile in self.quantiles
+    ]
+).mean()
+
+totals["mean_wQuantileLoss"] = np.array(
+    [
+        totals[f"wQuantileLoss[{quantile}]"]
+        for quantile in self.quantiles
+    ]
+).mean()
+
+totals["MAE_Coverage"] = np.mean(
+    [
+        np.abs(totals[f"Coverage[{quantile}]"] - np.array([q.value]))
+        for q in self.quantiles
+    ]
+)

# Compute OWA if required
if self.calculate_owa:
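As context for the block above: `mean_wQuantileLoss` (and likewise `mean_absolute_QuantileLoss`) is just the plain mean over the per-quantile entries of `totals`. A small self-contained sketch with made-up quantile levels and loss values, not taken from any GluonTS output:

import numpy as np

# Hypothetical per-quantile weighted losses, keyed like the `totals` dict above.
quantiles = [0.1, 0.5, 0.9]
totals = {
    "wQuantileLoss[0.1]": 0.12,
    "wQuantileLoss[0.5]": 0.08,
    "wQuantileLoss[0.9]": 0.10,
}

# Same aggregation as in the diff: stack the per-quantile values and take the mean.
mean_wql = np.array(
    [totals[f"wQuantileLoss[{q}]"] for q in quantiles]
).mean()

print(mean_wql)  # roughly 0.1, up to float rounding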
