Backports for v0.14.0rc2 #3032

Merged · 7 commits · Oct 27, 2023
1 change: 1 addition & 0 deletions .github/workflows/style_type_checks.yml
@@ -22,6 +22,7 @@ jobs:
pip install click black mypy
pip install types-python-dateutil
pip install types-waitress
pip install types-PyYAML
- name: Style and type checks
run: |
just black
21 changes: 13 additions & 8 deletions docs/Makefile
@@ -8,21 +8,26 @@ SPHINXBUILD ?= sphinx-build
SOURCEDIR = .
BUILDDIR = _build

APIDOC = sphinx-apidoc
APIDOC_OPTS = --implicit-namespaces --separate --module-first
APIDOC_ROOT = gluonts

# Put it first so that "make" without argument is like "make help".
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: help Makefile

apidoc:
@rm -Rf api/$(APIDOC_ROOT)
@$(APIDOC) $(APIDOC_OPTS) -o api/$(APIDOC_ROOT) ../src/$(APIDOC_ROOT) setup* */bin/* test docs *pycache*
@rm -Rf api/$(APIDOC_ROOT)/modules.rst
@sed -i"" -e "s/$(APIDOC_ROOT) package/API Docs/" api/$(APIDOC_ROOT)/$(APIDOC_ROOT).rst
@rm -Rf api/gluonts

@sphinx-apidoc \
--implicit-namespaces \
--separate \
--module-first \
-o api/gluonts \
../src/gluonts \
../src/gluonts/nursery/* \
../src/gluonts/pydantic.py

@rm -Rf api/gluonts/modules.rst
@sed -i"" -e "s/gluonts package/API Docs/" api/gluonts/gluonts.rst

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
9 changes: 3 additions & 6 deletions requirements/requirements-pytorch.txt
@@ -1,9 +1,6 @@
torch>=1.9,<3
lightning>=1.8,<2.2
lightning>=2.0,<2.2
# Capping `lightning` does not cap `pytorch_lightning`, so we cap manually
pytorch_lightning>=1.8,<2.2
# Need to pin protobuf (for now)
# See: https://github.com/PyTorchLightning/pytorch-lightning/issues/13159
protobuf~=3.19.0
pytorch_lightning>=2.0,<2.2
scipy~=1.10; python_version > "3.7.0"
scipy~=1.7.3; python_version <= "3.7.0"
scipy~=1.7.3; python_version <= "3.7.0"
2 changes: 1 addition & 1 deletion requirements/requirements.txt
@@ -1,6 +1,6 @@
numpy~=1.16
pandas>=1.0,<3
pydantic~=1.7
pydantic>=1.7,<3
tqdm~=4.23
toolz~=0.10

24 changes: 15 additions & 9 deletions setup.py
@@ -57,18 +57,24 @@ def run(self):
# otherwise a module-not-found error is thrown
import mypy.api

folders = [
str(p.parent.resolve()) for p in ROOT.glob("src/**/.typesafe")
excluded_folders = [
str(p.parent.relative_to(ROOT)) for p in ROOT.glob("src/**/.typeunsafe")
]

print(
"The following folders contain a `.typesafe` marker file "
"and will be type-checked with `mypy`:"
)
for folder in folders:
if len(excluded_folders) > 0:
print(
"The following folders contain a `.typeunsafe` marker file "
"and will *not* be type-checked with `mypy`:"
)
for folder in excluded_folders:
print(f" {folder}")

std_out, std_err, exit_code = mypy.api.run(folders)
args = [str(ROOT / "src")]
for folder in excluded_folders:
args.append("--exclude")
args.append(folder)

std_out, std_err, exit_code = mypy.api.run(args)

print(std_out, file=sys.stdout)
print(std_err, file=sys.stderr)
@@ -78,7 +84,7 @@ def run(self):
f"""
Mypy command

mypy {" ".join(folders)}
mypy {" ".join(args)}

returned a non-zero exit code. Fix the type errors listed above
and then run
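Note: the rewritten type-check step above inverts the old opt-in scheme. Instead of running mypy only on folders that carry a `.typesafe` marker, it now checks everything under `src` and skips folders that carry a `.typeunsafe` marker. A minimal sketch of an equivalent standalone call through mypy's Python API (the excluded path is hypothetical; mypy's `--exclude` option takes a regular expression matched against paths):

```python
# Sketch only: type-check src/, skipping one hypothetical untyped folder.
import mypy.api

stdout, stderr, exit_code = mypy.api.run(
    ["src", "--exclude", "src/gluonts/some_untyped_pkg"]
)
print(stdout)
```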
11 changes: 8 additions & 3 deletions src/gluonts/core/component.py
@@ -19,10 +19,15 @@
from typing import Any, Type, TypeVar

import numpy as np
from pydantic import BaseConfig, BaseModel, ValidationError, create_model

from gluonts.core import fqname_for
from gluonts.exceptions import GluonTSHyperparametersError
from gluonts.pydantic import (
BaseConfig,
BaseModel,
ValidationError,
create_model,
)


logger = logging.getLogger(__name__)
@@ -252,7 +257,7 @@ def validated(base_model=None):
>>> c = ComplexNumber(y=None)
Traceback (most recent call last):
...
pydantic.error_wrappers.ValidationError: 1 validation error for
pydantic.v1.error_wrappers.ValidationError: 1 validation error for
ComplexNumberModel
y
none is not an allowed value (type=type_error.none.not_allowed)
@@ -262,7 +267,7 @@
accessed through the ``Model`` attribute of the decorated initializer.

>>> ComplexNumber.__init__.Model
<class 'pydantic.main.ComplexNumberModel'>
<class 'pydantic.v1.main.ComplexNumberModel'>

The Pydantic model is synthesized automatically from the parameter
names and types of the decorated initializer. In the ``ComplexNumber``
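Note: the import changes in this file (and in the files below) route every pydantic symbol through a new `gluonts.pydantic` module, and the doctest output now points at the `pydantic.v1` namespace. That module itself is not part of this diff; a rough sketch of what such a compatibility shim might look like, assuming it supports both pydantic 1.x and the `pydantic.v1` compatibility layer shipped with pydantic 2.x:

```python
# Hypothetical sketch of a pydantic 1/2 compatibility shim; the real
# gluonts.pydantic module is not shown in this pull request.
try:
    # pydantic >= 2 re-exports the old API under `pydantic.v1`.
    from pydantic import v1 as pydantic
    from pydantic.v1 import (
        BaseConfig,
        BaseModel,
        ValidationError,
        create_model,
    )
    from pydantic.v1.dataclasses import dataclass
    from pydantic.v1.error_wrappers import display_errors
    from pydantic.v1.utils import deep_update
except ImportError:
    # pydantic 1.x exposes the same names at the top level.
    import pydantic
    from pydantic import BaseConfig, BaseModel, ValidationError, create_model
    from pydantic.dataclasses import dataclass
    from pydantic.error_wrappers import display_errors
    from pydantic.utils import deep_update
```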
3 changes: 1 addition & 2 deletions src/gluonts/core/serde/_base.py
@@ -20,9 +20,8 @@

from toolz.dicttoolz import valmap

from pydantic import BaseModel

from gluonts.core import fqname_for
from gluonts.pydantic import BaseModel

bad_type_msg = textwrap.dedent(
"""
6 changes: 2 additions & 4 deletions src/gluonts/core/serde/_dataclass.py
@@ -24,10 +24,8 @@
TypeVar,
)

import pydantic
import pydantic.dataclasses

from gluonts.itertools import select
from gluonts.pydantic import pydantic, dataclass as pydantic_dataclass

T = TypeVar("T")

@@ -152,7 +150,7 @@ class Config:
arbitrary_types_allowed = True

# make `cls` a dataclass
pydantic.dataclasses.dataclass(
pydantic_dataclass(
init=init,
repr=repr,
eq=eq,
3 changes: 1 addition & 2 deletions src/gluonts/core/settings.py
@@ -76,8 +76,7 @@ def fn(debug):
from operator import attrgetter
from typing import Any

import pydantic
from pydantic.utils import deep_update
from gluonts.pydantic import pydantic, deep_update


class ListElement:
Empty file removed src/gluonts/dataset/.typesafe
4 changes: 1 addition & 3 deletions src/gluonts/dataset/common.py
@@ -22,19 +22,17 @@
import pandas as pd
from pandas.tseries.frequencies import to_offset

import pydantic

from gluonts import json
from gluonts.itertools import Cached, Map
from gluonts.dataset.field_names import FieldName
from gluonts.dataset.schema import Translator
from gluonts.exceptions import GluonTSDataError
from gluonts.pydantic import pydantic


from . import Dataset, DatasetCollection, DataEntry, DataBatch # noqa
from . import jsonl, DatasetWriter


arrow: Optional[ModuleType]

try:
2 changes: 1 addition & 1 deletion src/gluonts/dataset/loader.py
@@ -15,7 +15,6 @@
from typing import Callable, Iterable, Optional

import numpy as np
from pydantic import BaseModel

from gluonts.dataset import DataBatch, Dataset
from gluonts.itertools import (
@@ -25,6 +24,7 @@
batcher,
rows_to_columns,
)
from gluonts.pydantic import BaseModel
from gluonts.transform import (
AdhocTransform,
Identity,
11 changes: 11 additions & 0 deletions src/gluonts/ev/aggregations.py
@@ -48,6 +48,8 @@ class Sum(Aggregation):
partial_result: Optional[Union[List[np.ndarray], np.ndarray]] = None

def step(self, values: np.ndarray) -> None:
assert self.axis is None or isinstance(self.axis, tuple)

summed_values = np.ma.sum(values, axis=self.axis)

if self.axis is None or 0 in self.axis:
@@ -57,9 +59,12 @@ def step(self, values: np.ndarray) -> None:
else:
if self.partial_result is None:
self.partial_result = []
assert isinstance(self.partial_result, list)
self.partial_result.append(summed_values)

def get(self) -> np.ndarray:
assert self.axis is None or isinstance(self.axis, tuple)

if self.axis is None or 0 in self.axis:
return np.ma.copy(self.partial_result)

@@ -85,6 +90,8 @@ class Mean(Aggregation):
n: Optional[Union[int, np.ndarray]] = None

def step(self, values: np.ndarray) -> None:
assert self.axis is None or isinstance(self.axis, tuple)

if self.axis is None or 0 in self.axis:
summed_values = np.ma.sum(values, axis=self.axis)
if self.partial_result is None:
@@ -101,10 +108,14 @@ def step(self, values: np.ndarray) -> None:
self.partial_result = []

mean_values = np.ma.mean(values, axis=self.axis)
assert isinstance(self.partial_result, list)
self.partial_result.append(mean_values)

def get(self) -> np.ndarray:
assert self.axis is None or isinstance(self.axis, tuple)

if self.axis is None or 0 in self.axis:
assert isinstance(self.partial_result, np.ndarray)
return self.partial_result / self.n

return np.ma.concatenate(self.partial_result)
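Note: the assertions added above exist mainly to narrow types for mypy: `axis` must be `None` or a tuple of ints, and `partial_result` is either a running array (when axis 0 is aggregated) or a list of per-batch results. A small usage sketch, assuming `Aggregation` exposes an `axis` field, in which per-row means are accumulated batch by batch and concatenated at the end:

```python
import numpy as np

from gluonts.ev.aggregations import Mean

# axis=(1,) keeps one value per row; batches are appended, then concatenated.
agg = Mean(axis=(1,))
agg.step(np.ma.masked_invalid(np.array([[1.0, 2.0], [3.0, np.nan]])))
agg.step(np.ma.masked_invalid(np.array([[5.0, 7.0]])))
print(agg.get())  # roughly [1.5, 3.0, 6.0]; the masked NaN is ignored
```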
5 changes: 4 additions & 1 deletion src/gluonts/ev/metrics.py
@@ -139,6 +139,9 @@ def __call__(self, axis: Optional[int] = None) -> Metric:


class BaseMetricDefinition:
def __call__(self, axis):
raise NotImplementedError()

def __add__(self, other) -> MetricDefinitionCollection:
if isinstance(other, MetricDefinitionCollection):
return other + self
@@ -154,7 +157,7 @@ def add(self, *others):

@dataclass
class MetricDefinitionCollection(BaseMetricDefinition):
metrics: List[MetricDefinition]
metrics: List[BaseMetricDefinition]

def __call__(self, axis: Optional[int] = None) -> MetricCollection:
return MetricCollection([metric(axis=axis) for metric in self.metrics])
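Note: the `__call__` stub added above spells out the contract that every metric definition is callable with an `axis` argument, and the collection's element type is corrected to `BaseMetricDefinition`. An illustrative sketch of how definitions combine via the overloaded `+` (the concrete metric names are assumptions, not part of this diff):

```python
# Illustrative only: combine two metric definitions into a collection,
# then instantiate all of them for a given axis.
from gluonts.ev.metrics import MSE, MAPE  # assumed to exist

combined = MSE() + MAPE()   # BaseMetricDefinition.__add__ -> MetricDefinitionCollection
metrics = combined(axis=1)  # MetricDefinitionCollection.__call__ -> MetricCollection
```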
1 change: 1 addition & 0 deletions src/gluonts/ev/stats.py
@@ -18,6 +18,7 @@

def num_masked_target_values(data: Dict[str, np.ndarray]) -> np.ndarray:
if np.ma.isMaskedArray(data["label"]):
assert isinstance(data["label"], np.ma.MaskedArray)
return data["label"].mask.astype(float)
else:
return np.zeros(data["label"].shape)
Empty file removed src/gluonts/evaluation/.typesafe
2 changes: 1 addition & 1 deletion src/gluonts/exceptions.py
@@ -13,7 +13,7 @@

from typing import Any

from pydantic.error_wrappers import ValidationError, display_errors
from gluonts.pydantic import ValidationError, display_errors


class GluonTSException(Exception):
2 changes: 1 addition & 1 deletion src/gluonts/ext/naive_2/_predictor.py
@@ -21,7 +21,7 @@
from gluonts.model.predictor import RepresentablePredictor


def seasonality_test(past_ts_data: np.array, season_length: int) -> bool:
def seasonality_test(past_ts_data: np.ndarray, season_length: int) -> bool:
"""
Test the time series for seasonal patterns by performing a 90% auto-
correlation test.
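Note: the signature change above is purely an annotation fix. `np.array` is the array-constructing function, not a type, so mypy rejects it as an annotation; `np.ndarray` is the actual array class. For example:

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0])     # np.array is a factory function
assert isinstance(x, np.ndarray)  # np.ndarray is the resulting type
```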
Empty file removed src/gluonts/ext/prophet/.typesafe
6 changes: 4 additions & 2 deletions src/gluonts/ext/r_forecast/_hierarchical_predictor.py
@@ -32,7 +32,7 @@
"erm",
]

HIERARCHICAL_SAMPLE_FORECAST_METHODS = [] # TODO: Add `depbu_mint`.
HIERARCHICAL_SAMPLE_FORECAST_METHODS: List[str] = [] # TODO: Add `depbu_mint`.

SUPPORTED_HIERARCHICAL_METHODS = (
HIERARCHICAL_POINT_FORECAST_METHODS + HIERARCHICAL_SAMPLE_FORECAST_METHODS
@@ -200,6 +200,8 @@ def _get_r_forecast(self, data: Dict) -> Dict:
else:
hier_ts = self._hts_pkg.gts(y_bottom_ts, groups=nodes)

assert isinstance(self.params["num_samples"], int)

forecast = self._r_method(hier_ts, r_params)

all_forecasts = list(forecast)
@@ -244,7 +246,7 @@ def _forecast_dict_to_obj(
forecast_dict: Dict,
forecast_start_date: pd.Timestamp,
item_id: Optional[str],
info: Dict,
info: Optional[Dict],
) -> SampleForecast:
samples = np.array(forecast_dict["samples"])

4 changes: 2 additions & 2 deletions src/gluonts/ext/r_forecast/_predictor.py
@@ -91,7 +91,7 @@ def __init__(
self,
freq: str,
prediction_length: int,
period: int = None,
period: Optional[int] = None,
trunc_length: Optional[int] = None,
save_info: bool = False,
r_file_prefix: str = "",
@@ -215,7 +215,7 @@ def _forecast_dict_to_obj(
forecast_dict: Dict,
forecast_start_date: pd.Timestamp,
item_id: Optional[str],
info: Dict,
info: Optional[Dict],
) -> Forecast:
"""
Returns object of type `gluonts.model.Forecast`.
8 changes: 4 additions & 4 deletions src/gluonts/ext/r_forecast/_univariate_predictor.py
@@ -11,7 +11,7 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from typing import Dict, Optional
from typing import Dict, Optional, Any

import numpy as np
import pandas as pd
@@ -79,7 +79,7 @@ def __init__(
freq: str,
prediction_length: int,
method_name: str = "ets",
period: int = None,
period: Optional[int] = None,
trunc_length: Optional[int] = None,
save_info: bool = False,
params: Dict = dict(),
@@ -100,7 +100,7 @@ def __init__(
self.method_name = method_name
self._r_method = self._robjects.r[method_name]

self.params = {
self.params: Dict[str, Any] = {
"prediction_length": self.prediction_length,
"frequency": self.period,
}
@@ -185,7 +185,7 @@ def _forecast_dict_to_obj(
forecast_dict: Dict,
forecast_start_date: pd.Timestamp,
item_id: Optional[str],
info: Dict,
info: Optional[Dict],
) -> QuantileForecast:
stats_dict = {"mean": forecast_dict["mean"]}
