Skip to content

Commit

Permalink
Fix numpy-2.0 test failures
Browse files Browse the repository at this point in the history
  • Loading branch information
shchur committed Oct 17, 2024
1 parent 93cd9ba commit dd88a09
Show file tree
Hide file tree
Showing 5 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/gluonts/dataset/pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,4 +351,4 @@ def is_uniform(index: pd.PeriodIndex) -> bool:
False
"""

return cast(bool, np.all(np.diff(index.asi8) == index.freq.n))
return bool(np.all(np.diff(index.asi8) == index.freq.n))
2 changes: 1 addition & 1 deletion src/gluonts/itertools.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def _location_for(self, idx, side="right") -> _SubIndex:
else:
local_idx = idx - self._offsets[part_no - 1]

return _SubIndex(part_no, local_idx)
return _SubIndex(int(part_no), int(local_idx))

def __getitem__(self, idx):
if isinstance(idx, slice):
Expand Down
2 changes: 1 addition & 1 deletion src/gluonts/model/forecast_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def _unpack(batched) -> Iterator:
This assumes that arrays are wrapped in a nested structure of lists and
tuples, and each array has the same shape::
>>> a = np.arange(5)
>>> a = np.arange(5, dtype="O")
>>> batched = [a, (a, [a, a, a])]
>>> list(_unpack(batched))
[[0, (0, [0, 0, 0])],
Expand Down
2 changes: 1 addition & 1 deletion src/gluonts/mx/trainer/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def base_path() -> str:
best_epoch_info = {
"params_path": "{}-{}.params".format(base_path(), "init"),
"epoch_no": -1,
"score": np.Inf,
"score": float("inf"),
}

optimizer = mx.optimizer.Adam(
Expand Down
4 changes: 2 additions & 2 deletions src/gluonts/mx/trainer/learning_rate_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,15 @@ def should_update(self, metric: float) -> bool:

@dataclass
class Min(Objective):
best: float = np.Inf
best: float = float("inf")

def should_update(self, metric: float) -> bool:
return metric < self.best


@dataclass
class Max(Objective):
best: float = -np.Inf
best: float = -float("inf")

def should_update(self, metric: float) -> bool:
return metric > self.best
Expand Down

0 comments on commit dd88a09

Please sign in to comment.