Skip to content

Commit

Permalink
Merge branch 'dev' into precommit
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda authored Feb 2, 2024
2 parents d453687 + 9ea944b commit 9a83d1b
Show file tree
Hide file tree
Showing 85 changed files with 480 additions and 373 deletions.
16 changes: 10 additions & 6 deletions src/gluonts/core/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,12 +296,16 @@ def validator(init):
init_params = inspect.signature(init).parameters
init_fields = {
param.name: (
param.annotation
if param.annotation != inspect.Parameter.empty
else Any,
param.default
if param.default != inspect.Parameter.empty
else ...,
(
param.annotation
if param.annotation != inspect.Parameter.empty
else Any
),
(
param.default
if param.default != inspect.Parameter.empty
else ...
),
)
for param in init_params.values()
if param.name != "self"
Expand Down
9 changes: 3 additions & 6 deletions src/gluonts/dataset/arrow/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,13 @@ def infer(
return ArrowStreamFile(path)

@abc.abstractmethod
def metadata(self) -> Dict[str, str]:
...
def metadata(self) -> Dict[str, str]: ...

@abc.abstractmethod
def __iter__(self):
...
def __iter__(self): ...

@abc.abstractmethod
def __len__(self):
...
def __len__(self): ...


@dataclass
Expand Down
2 changes: 1 addition & 1 deletion src/gluonts/dataset/artificial/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def __init__(
self,
num_timeseries: int = 10,
num_steps: int = 30,
freq: str = "1H",
freq: str = "1h",
start: str = "2000-01-01 00:00:00",
# Generates constant dataset of 0s with explicit NaN missing values
is_nan: bool = False,
Expand Down
6 changes: 3 additions & 3 deletions src/gluonts/dataset/repository/_lstnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,9 +200,9 @@ def generate_lstnet_dataset(
meta = MetaData(
**metadata(
cardinality=ds_info.num_series,
freq=ds_info.freq
if ds_info.agg_freq is None
else ds_info.agg_freq,
freq=(
ds_info.freq if ds_info.agg_freq is None else ds_info.agg_freq
),
prediction_length=prediction_length or ds_info.prediction_length,
)
)
Expand Down
12 changes: 6 additions & 6 deletions src/gluonts/dataset/stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,12 +418,12 @@ def calculate_dataset_statistics(ts_dataset: Any) -> DatasetStatistics:
max_target_length=max_target_length,
min_target=min_target,
num_missing_values=num_missing_values,
feat_static_real=observed_feat_static_real
if observed_feat_static_real
else [],
feat_static_cat=observed_feat_static_cat
if observed_feat_static_cat
else [],
feat_static_real=(
observed_feat_static_real if observed_feat_static_real else []
),
feat_static_cat=(
observed_feat_static_cat if observed_feat_static_cat else []
),
num_past_feat_dynamic_real=num_past_feat_dynamic_real,
num_feat_dynamic_real=num_feat_dynamic_real,
num_feat_dynamic_cat=num_feat_dynamic_cat,
Expand Down
6 changes: 3 additions & 3 deletions src/gluonts/evaluation/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,9 +378,9 @@ def get_base_metrics(
return {
"item_id": forecast.item_id,
"forecast_start": forecast.start_date,
"MSE": mse(pred_target, mean_fcst)
if mean_fcst is not None
else None,
"MSE": (
mse(pred_target, mean_fcst) if mean_fcst is not None else None
),
"abs_error": abs_error(pred_target, median_fcst),
"abs_target_sum": abs_target_sum(pred_target),
"abs_target_mean": abs_target_mean(pred_target),
Expand Down
10 changes: 5 additions & 5 deletions src/gluonts/ext/rotbaum/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,11 +340,11 @@ def _get_and_cache_quantile_computation(
The quantile of the associated true value bin.
"""
if feature_vector_in_train not in self.quantile_dicts[quantile]:
self.quantile_dicts[quantile][
feature_vector_in_train
] = np.percentile(
self.id_to_bins[self.preds_to_id[feature_vector_in_train]],
quantile * 100,
self.quantile_dicts[quantile][feature_vector_in_train] = (
np.percentile(
self.id_to_bins[self.preds_to_id[feature_vector_in_train]],
quantile * 100,
)
)
return self.quantile_dicts[quantile][feature_vector_in_train]

Expand Down
16 changes: 10 additions & 6 deletions src/gluonts/ext/rotbaum/_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,13 +254,17 @@ def train(
target_data[:, train_QRX_only_using_timestep],
)
self.model_list = [
QRX(
xgboost_params=self.model_params,
min_bin_size=self.min_bin_size,
model=self.model_list[train_QRX_only_using_timestep].model,
(
QRX(
xgboost_params=self.model_params,
min_bin_size=self.min_bin_size,
model=self.model_list[
train_QRX_only_using_timestep
].model,
)
if i != train_QRX_only_using_timestep
else self.model_list[i]
)
if i != train_QRX_only_using_timestep
else self.model_list[i]
for i in range(n_models)
]
with concurrent.futures.ThreadPoolExecutor(
Expand Down
6 changes: 2 additions & 4 deletions src/gluonts/itertools.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,9 @@

@runtime_checkable
class SizedIterable(Protocol):
def __len__(self):
...
def __len__(self): ...

def __iter__(self):
...
def __iter__(self): ...


T = TypeVar("T")
Expand Down
24 changes: 15 additions & 9 deletions src/gluonts/model/forecast_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,11 @@ def __call__(
yield QuantileForecast(
output.T,
start_date=batch[FieldName.FORECAST_START][i],
item_id=batch[FieldName.ITEM_ID][i]
if FieldName.ITEM_ID in batch
else None,
item_id=(
batch[FieldName.ITEM_ID][i]
if FieldName.ITEM_ID in batch
else None
),
info=batch["info"][i] if "info" in batch else None,
forecast_keys=self.quantiles,
)
Expand Down Expand Up @@ -181,9 +183,11 @@ def __call__(
yield SampleForecast(
output,
start_date=batch[FieldName.FORECAST_START][i],
item_id=batch[FieldName.ITEM_ID][i]
if FieldName.ITEM_ID in batch
else None,
item_id=(
batch[FieldName.ITEM_ID][i]
if FieldName.ITEM_ID in batch
else None
),
info=batch["info"][i] if "info" in batch else None,
)
assert i + 1 == len(batch[FieldName.FORECAST_START])
Expand Down Expand Up @@ -221,9 +225,11 @@ def __call__(
yield make_distribution_forecast(
distr,
start_date=batch[FieldName.FORECAST_START][i],
item_id=batch[FieldName.ITEM_ID][i]
if FieldName.ITEM_ID in batch
else None,
item_id=(
batch[FieldName.ITEM_ID][i]
if FieldName.ITEM_ID in batch
else None
),
info=batch["info"][i] if "info" in batch else None,
)
assert i + 1 == len(batch[FieldName.FORECAST_START])
17 changes: 10 additions & 7 deletions src/gluonts/mx/batchify.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,16 @@ def as_in_context(batch: dict, ctx: mx.Context = None) -> DataBatch:
Move data into new context, should only be in main process.
"""
batch = {
k: v.as_in_context(ctx) if isinstance(v, mx.nd.NDArray)
# Workaround due to MXNet not being able to handle NDArrays with 0 in
# shape properly:
else (
stack(v, ctx=ctx, dtype=v.dtype, variable_length=False)
if isinstance(v[0], np.ndarray) and 0 in v[0].shape
else v
k: (
v.as_in_context(ctx)
if isinstance(v, mx.nd.NDArray)
# Workaround due to MXNet not being able to handle NDArrays with 0 in
# shape properly:
else (
stack(v, ctx=ctx, dtype=v.dtype, variable_length=False)
if isinstance(v[0], np.ndarray) and 0 in v[0].shape
else v
)
)
for k, v in batch.items()
}
Expand Down
8 changes: 5 additions & 3 deletions src/gluonts/mx/block/dropout.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,9 +231,11 @@ def mask(p, like):
# mask as output, instead of simply copy output to the first element
# in case that the base cell is ResidualCell
new_states = [
F.where(output_mask, next_states[0], states[0])
if p_outputs != 0.0
else next_states[0]
(
F.where(output_mask, next_states[0], states[0])
if p_outputs != 0.0
else next_states[0]
)
]
new_states.extend(
[
Expand Down
1 change: 1 addition & 0 deletions src/gluonts/mx/distribution/box_cox_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ class BoxCoxTransform(Bijection):
`tol_lambda_1`
F
"""

arg_names = ["box_cox.lambda_1", "box_cox.lambda_2"]

@validated()
Expand Down
2 changes: 2 additions & 0 deletions src/gluonts/mx/distribution/inflated_beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ class ZeroInflatedBeta(ZeroAndOneInflatedBeta):
`(*batch_shape, *event_shape)`.
F
"""

is_reparameterizable = False

@validated()
Expand Down Expand Up @@ -145,6 +146,7 @@ class OneInflatedBeta(ZeroAndOneInflatedBeta):
`(*batch_shape, *event_shape)`.
F
"""

is_reparameterizable = False

@validated()
Expand Down
16 changes: 10 additions & 6 deletions src/gluonts/mx/distribution/lds.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,9 +409,11 @@ def sample(
# (num_samples, batch_size, latent_dim, latent_dim)
# innovation_coeff_t: (num_samples, batch_size, 1, latent_dim)
emission_coeff_t, transition_coeff_t, innovation_coeff_t = (
_broadcast_param(coeff, axes=[0], sizes=[num_samples])
if num_samples is not None
else coeff
(
_broadcast_param(coeff, axes=[0], sizes=[num_samples])
if num_samples is not None
else coeff
)
for coeff in [
self.emission_coeff[t],
self.transition_coeff[t],
Expand Down Expand Up @@ -458,9 +460,11 @@ def sample(
if scale is None
else F.broadcast_mul(
samples,
scale.expand_dims(axis=1).expand_dims(axis=0)
if num_samples is not None
else scale.expand_dims(axis=1),
(
scale.expand_dims(axis=1).expand_dims(axis=0)
if num_samples is not None
else scale.expand_dims(axis=1)
),
)
)

Expand Down
4 changes: 1 addition & 3 deletions src/gluonts/mx/distribution/lowrank_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,7 @@ def hybrid_forward(self, F, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
D_vector = self.proj[1](x_plus_w).squeeze(axis=-1)

d_bias = (
0.0
if self.sigma_init == 0.0
else inv_softplus(self.sigma_init**2)
0.0 if self.sigma_init == 0.0 else inv_softplus(self.sigma_init**2)
)

D_positive = (
Expand Down
4 changes: 1 addition & 3 deletions src/gluonts/mx/distribution/lowrank_multivariate_gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,9 +408,7 @@ def domain_map(self, F, mu_vector, D_vector, W_vector=None):
"""

d_bias = (
inv_softplus(self.sigma_init**2)
if self.sigma_init > 0.0
else 0.0
inv_softplus(self.sigma_init**2) if self.sigma_init > 0.0 else 0.0
)

# sigma_minimum helps avoiding cholesky problems, we could also jitter
Expand Down
28 changes: 16 additions & 12 deletions src/gluonts/mx/model/deepar/_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,9 +572,11 @@ def unroll_encoder_imputation(
static_feat = F.concat(
embedded_cat,
feat_static_real,
F.log(scale)
if len(self.target_shape) == 0
else F.log(scale.squeeze(axis=1)),
(
F.log(scale)
if len(self.target_shape) == 0
else F.log(scale.squeeze(axis=1))
),
dim=1,
)

Expand Down Expand Up @@ -603,9 +605,9 @@ def unroll_encoder_imputation(
begin_state = self.rnn.begin_state(
func=F.zeros,
dtype=self.dtype,
batch_size=inputs.shape[0]
if isinstance(inputs, mx.nd.NDArray)
else 0,
batch_size=(
inputs.shape[0] if isinstance(inputs, mx.nd.NDArray) else 0
),
)

unroll_results = self.imputation_rnn_unroll(
Expand Down Expand Up @@ -726,9 +728,11 @@ def unroll_encoder_default(
static_feat = F.concat(
embedded_cat,
feat_static_real,
F.log(scale)
if len(self.target_shape) == 0
else F.log(scale.squeeze(axis=1)),
(
F.log(scale)
if len(self.target_shape) == 0
else F.log(scale.squeeze(axis=1))
),
dim=1,
)

Expand Down Expand Up @@ -757,9 +761,9 @@ def unroll_encoder_default(
begin_state = self.rnn.begin_state(
func=F.zeros,
dtype=self.dtype,
batch_size=inputs.shape[0]
if isinstance(inputs, mx.nd.NDArray)
else 0,
batch_size=(
inputs.shape[0] if isinstance(inputs, mx.nd.NDArray) else 0
),
)
state = begin_state
# This is a dummy computation to avoid deferred initialization error
Expand Down
3 changes: 3 additions & 0 deletions src/gluonts/mx/model/deepstate/_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,14 @@
# series in the dataset.
FREQ_LONGEST_PERIOD_DICT = {
"M": 12, # yearly seasonality
"ME": 12, # yearly seasonality
"W": 52, # yearly seasonality
"D": 31, # monthly seasonality
"B": 22, # monthly seasonality
"H": 168, # weekly seasonality
"h": 168, # weekly seasonality
"T": 1440, # daily seasonality
"min": 1440, # daily seasonality
}


Expand Down
Loading

0 comments on commit 9a83d1b

Please sign in to comment.