Skip to content

Commit

Permalink
Bump dependency version ranges for numpy & lightning (#3226)
Browse files Browse the repository at this point in the history
*Issue #, if available:*

*Description of changes:*


By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.


**Please tag this pr with at least one of these labels to make our
release process faster:** BREAKING, new feature, bug fix, other change,
dev setup
  • Loading branch information
shchur authored Oct 17, 2024
1 parent 7668ce1 commit ede5664
Show file tree
Hide file tree
Showing 7 changed files with 10 additions and 10 deletions.
4 changes: 2 additions & 2 deletions requirements/requirements-pytorch.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
torch>=1.9,<3
lightning>=2.2.2,<2.4
lightning>=2.2.2,<2.5
# Capping `lightning` does not cap `pytorch_lightning`, so we cap manually
pytorch_lightning>=2.2.2,<2.4
pytorch_lightning>=2.2.2,<2.5
scipy~=1.10; python_version > "3.7.0"
scipy~=1.7.3; python_version <= "3.7.0"
2 changes: 1 addition & 1 deletion requirements/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
numpy~=1.16
numpy>=1.16,<2.2
pandas>=1.0,<3
pydantic>=1.7,<3
tqdm~=4.23
Expand Down
4 changes: 2 additions & 2 deletions src/gluonts/dataset/pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import logging
from dataclasses import dataclass, field, InitVar
from typing import Any, Iterable, Optional, Type, Union, cast
from typing import Any, Iterable, Optional, Type, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -351,4 +351,4 @@ def is_uniform(index: pd.PeriodIndex) -> bool:
False
"""

return cast(bool, np.all(np.diff(index.asi8) == index.freq.n))
return bool(np.all(np.diff(index.asi8) == index.freq.n))
2 changes: 1 addition & 1 deletion src/gluonts/itertools.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def _location_for(self, idx, side="right") -> _SubIndex:
else:
local_idx = idx - self._offsets[part_no - 1]

return _SubIndex(part_no, local_idx)
return _SubIndex(int(part_no), int(local_idx))

def __getitem__(self, idx):
if isinstance(idx, slice):
Expand Down
2 changes: 1 addition & 1 deletion src/gluonts/model/forecast_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def _unpack(batched) -> Iterator:
This assumes that arrays are wrapped in a nested structure of lists and
tuples, and each array has the same shape::
>>> a = np.arange(5)
>>> a = np.arange(5, dtype="O")
>>> batched = [a, (a, [a, a, a])]
>>> list(_unpack(batched))
[[0, (0, [0, 0, 0])],
Expand Down
2 changes: 1 addition & 1 deletion src/gluonts/mx/trainer/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def base_path() -> str:
best_epoch_info = {
"params_path": "{}-{}.params".format(base_path(), "init"),
"epoch_no": -1,
"score": np.Inf,
"score": float("inf"),
}

optimizer = mx.optimizer.Adam(
Expand Down
4 changes: 2 additions & 2 deletions src/gluonts/mx/trainer/learning_rate_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,15 @@ def should_update(self, metric: float) -> bool:

@dataclass
class Min(Objective):
best: float = np.Inf
best: float = float("inf")

def should_update(self, metric: float) -> bool:
return metric < self.best


@dataclass
class Max(Objective):
best: float = -np.Inf
best: float = -float("inf")

def should_update(self, metric: float) -> bool:
return metric > self.best
Expand Down

0 comments on commit ede5664

Please sign in to comment.