From 450860bc547f3452ceb35e28d538f8cb9d5352f4 Mon Sep 17 00:00:00 2001 From: Jirka Date: Tue, 20 Feb 2024 17:39:20 +0100 Subject: [PATCH 01/10] lint: run Black --- src/gluonts/dataset/arrow/file.py | 9 ++- src/gluonts/ext/rotbaum/_model.py | 10 +-- src/gluonts/itertools.py | 6 +- src/gluonts/mx/distribution/lowrank_gp.py | 4 +- .../lowrank_multivariate_gaussian.py | 4 +- src/gluonts/mx/model/deepvar/_estimator.py | 6 +- src/gluonts/mx/model/estimator.py | 6 +- src/gluonts/mx/model/seq2seq/_transform.py | 30 ++++----- src/gluonts/mx/model/tft/_estimator.py | 12 ++-- .../nursery/SCott/dataset_tools/synthetic.py | 12 ++-- src/gluonts/nursery/SCott/preprocess_data.py | 12 ++-- .../nursery/daf/tslib/dataset/timeseries.py | 21 ++++-- .../src/meta/data/batch.py | 6 +- .../src/meta/datasets/artificial.py | 12 ++-- .../pts/model/deepvar/deepvar_estimator.py | 6 +- .../pts/model/tft/tft_estimator.py | 12 ++-- .../transformer_tempflow_network.py | 6 +- .../pts/modules/gaussian_diffusion.py | 4 +- .../nursery/robust-mts-attack/utils.py | 30 ++++----- .../model/cop_deepar/_estimator.py | 9 ++- .../src/tsbench/recommender/_factory.py | 6 +- src/gluonts/torch/model/deepar/module.py | 6 +- src/gluonts/torch/model/tft/module.py | 65 +++++++++---------- src/gluonts/zebras/_period.py | 6 +- 24 files changed, 155 insertions(+), 145 deletions(-) diff --git a/src/gluonts/dataset/arrow/file.py b/src/gluonts/dataset/arrow/file.py index 7bdb6cf898..40f1a06638 100644 --- a/src/gluonts/dataset/arrow/file.py +++ b/src/gluonts/dataset/arrow/file.py @@ -51,13 +51,16 @@ def infer( return ArrowStreamFile(path) @abc.abstractmethod - def metadata(self) -> Dict[str, str]: ... + def metadata(self) -> Dict[str, str]: + ... @abc.abstractmethod - def __iter__(self): ... + def __iter__(self): + ... @abc.abstractmethod - def __len__(self): ... + def __len__(self): + ... @dataclass diff --git a/src/gluonts/ext/rotbaum/_model.py b/src/gluonts/ext/rotbaum/_model.py index 3c0270a49a..29bd1e3cf7 100644 --- a/src/gluonts/ext/rotbaum/_model.py +++ b/src/gluonts/ext/rotbaum/_model.py @@ -340,11 +340,11 @@ def _get_and_cache_quantile_computation( The quantile of the associated true value bin. """ if feature_vector_in_train not in self.quantile_dicts[quantile]: - self.quantile_dicts[quantile][feature_vector_in_train] = ( - np.percentile( - self.id_to_bins[self.preds_to_id[feature_vector_in_train]], - quantile * 100, - ) + self.quantile_dicts[quantile][ + feature_vector_in_train + ] = np.percentile( + self.id_to_bins[self.preds_to_id[feature_vector_in_train]], + quantile * 100, ) return self.quantile_dicts[quantile][feature_vector_in_train] diff --git a/src/gluonts/itertools.py b/src/gluonts/itertools.py index 198559d6f6..e90281a1d8 100644 --- a/src/gluonts/itertools.py +++ b/src/gluonts/itertools.py @@ -42,9 +42,11 @@ @runtime_checkable class SizedIterable(Protocol): - def __len__(self): ... + def __len__(self): + ... - def __iter__(self): ... + def __iter__(self): + ... T = TypeVar("T") diff --git a/src/gluonts/mx/distribution/lowrank_gp.py b/src/gluonts/mx/distribution/lowrank_gp.py index 2e6de964a8..443cc40acb 100644 --- a/src/gluonts/mx/distribution/lowrank_gp.py +++ b/src/gluonts/mx/distribution/lowrank_gp.py @@ -101,7 +101,9 @@ def hybrid_forward(self, F, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]: D_vector = self.proj[1](x_plus_w).squeeze(axis=-1) d_bias = ( - 0.0 if self.sigma_init == 0.0 else inv_softplus(self.sigma_init**2) + 0.0 + if self.sigma_init == 0.0 + else inv_softplus(self.sigma_init**2) ) D_positive = ( diff --git a/src/gluonts/mx/distribution/lowrank_multivariate_gaussian.py b/src/gluonts/mx/distribution/lowrank_multivariate_gaussian.py index f446799367..b8131178a5 100644 --- a/src/gluonts/mx/distribution/lowrank_multivariate_gaussian.py +++ b/src/gluonts/mx/distribution/lowrank_multivariate_gaussian.py @@ -408,7 +408,9 @@ def domain_map(self, F, mu_vector, D_vector, W_vector=None): """ d_bias = ( - inv_softplus(self.sigma_init**2) if self.sigma_init > 0.0 else 0.0 + inv_softplus(self.sigma_init**2) + if self.sigma_init > 0.0 + else 0.0 ) # sigma_minimum helps avoiding cholesky problems, we could also jitter diff --git a/src/gluonts/mx/model/deepvar/_estimator.py b/src/gluonts/mx/model/deepvar/_estimator.py index 9a95ab0c9a..ceea7b588c 100644 --- a/src/gluonts/mx/model/deepvar/_estimator.py +++ b/src/gluonts/mx/model/deepvar/_estimator.py @@ -306,9 +306,9 @@ def __init__( self.scaling = scaling if self.use_marginal_transformation: - self.output_transform: Optional[Callable] = ( - cdf_to_gaussian_forward_transform - ) + self.output_transform: Optional[ + Callable + ] = cdf_to_gaussian_forward_transform else: self.output_transform = None diff --git a/src/gluonts/mx/model/estimator.py b/src/gluonts/mx/model/estimator.py index 122e34fad3..1cf992a796 100644 --- a/src/gluonts/mx/model/estimator.py +++ b/src/gluonts/mx/model/estimator.py @@ -177,9 +177,9 @@ def train_model( transformation = self.create_transformation() with env._let(max_idle_transforms=max(len(training_data), 100)): - transformed_training_data: Union[TransformedDataset, Cached] = ( - transformation.apply(training_data) - ) + transformed_training_data: Union[ + TransformedDataset, Cached + ] = transformation.apply(training_data) if cache_data: transformed_training_data = Cached(transformed_training_data) diff --git a/src/gluonts/mx/model/seq2seq/_transform.py b/src/gluonts/mx/model/seq2seq/_transform.py index 7c443e82d7..4134e8b01b 100644 --- a/src/gluonts/mx/model/seq2seq/_transform.py +++ b/src/gluonts/mx/model/seq2seq/_transform.py @@ -180,21 +180,21 @@ def flatmap_transform( # (Fortran) ordering with strides = # (dtype, dtype*n_rows) stride = decoder_fields.strides - out[self._future(ts_field)][pad_length_dec:] = ( - as_strided( - decoder_fields, - shape=( - self.num_forking - pad_length_dec, - self.dec_len, - ts_len, - ), - # strides for 2D array expanded to 3D array of - # shape (dim1, dim2, dim3) =(1, n_rows, n_cols). - # For transposed data, strides = (dtype, dtype * - # dim1, dtype*dim1*dim2) = (dtype, dtype, - # dtype*n_rows). - strides=stride[0:1] + stride, - ) + out[self._future(ts_field)][ + pad_length_dec: + ] = as_strided( + decoder_fields, + shape=( + self.num_forking - pad_length_dec, + self.dec_len, + ts_len, + ), + # strides for 2D array expanded to 3D array of + # shape (dim1, dim2, dim3) =(1, n_rows, n_cols). + # For transposed data, strides = (dtype, dtype * + # dim1, dtype*dim1*dim2) = (dtype, dtype, + # dtype*n_rows). + strides=stride[0:1] + stride, ) # edge case for prediction_length = 1 diff --git a/src/gluonts/mx/model/tft/_estimator.py b/src/gluonts/mx/model/tft/_estimator.py index a4d8335719..db19544603 100644 --- a/src/gluonts/mx/model/tft/_estimator.py +++ b/src/gluonts/mx/model/tft/_estimator.py @@ -172,13 +172,13 @@ def __init__( self.past_dynamic_feature_dims = {} for name in self.past_dynamic_features: if name in self.dynamic_cardinalities: - self.past_dynamic_cardinalities[name] = ( - self.dynamic_cardinalities.pop(name) - ) + self.past_dynamic_cardinalities[ + name + ] = self.dynamic_cardinalities.pop(name) elif name in self.dynamic_feature_dims: - self.past_dynamic_feature_dims[name] = ( - self.dynamic_feature_dims.pop(name) - ) + self.past_dynamic_feature_dims[ + name + ] = self.dynamic_feature_dims.pop(name) else: raise ValueError( f"Feature name {name} is not provided in feature dicts" diff --git a/src/gluonts/nursery/SCott/dataset_tools/synthetic.py b/src/gluonts/nursery/SCott/dataset_tools/synthetic.py index 8e2dc383b6..a16034b83c 100644 --- a/src/gluonts/nursery/SCott/dataset_tools/synthetic.py +++ b/src/gluonts/nursery/SCott/dataset_tools/synthetic.py @@ -57,12 +57,12 @@ def get_mixed_pattern(unit_length=16, num_duplicates=1000): for j in range(num_duplicates): context = torch.arange(context_length, dtype=torch.float) for i in range(1, pattern_number): - context[unit_length * (i - 1) : unit_length * i] = ( - _get_mixed_pattern( - context[unit_length * (i - 1) : unit_length * i] - - unit_length * (i - 1), - pattern[(gid + i) % pattern_number], - ) + context[ + unit_length * (i - 1) : unit_length * i + ] = _get_mixed_pattern( + context[unit_length * (i - 1) : unit_length * i] + - unit_length * (i - 1), + pattern[(gid + i) % pattern_number], ) ts_sample = torch.cat( [ diff --git a/src/gluonts/nursery/SCott/preprocess_data.py b/src/gluonts/nursery/SCott/preprocess_data.py index 40b181820f..4f247eb229 100644 --- a/src/gluonts/nursery/SCott/preprocess_data.py +++ b/src/gluonts/nursery/SCott/preprocess_data.py @@ -57,12 +57,12 @@ def get_mixed_pattern(unit_length=16, num_duplicates=1000): for j in range(num_duplicates): context = torch.arange(context_length, dtype=torch.float) for i in range(1, pattern_number): - context[unit_length * (i - 1) : unit_length * i] = ( - _get_mixed_pattern( - context[unit_length * (i - 1) : unit_length * i] - - unit_length * (i - 1), - pattern[(gid + i) % pattern_number], - ) + context[ + unit_length * (i - 1) : unit_length * i + ] = _get_mixed_pattern( + context[unit_length * (i - 1) : unit_length * i] + - unit_length * (i - 1), + pattern[(gid + i) % pattern_number], ) ts_sample = torch.cat( [ diff --git a/src/gluonts/nursery/daf/tslib/dataset/timeseries.py b/src/gluonts/nursery/daf/tslib/dataset/timeseries.py index 4bd63964b8..0036ef6acf 100644 --- a/src/gluonts/nursery/daf/tslib/dataset/timeseries.py +++ b/src/gluonts/nursery/daf/tslib/dataset/timeseries.py @@ -335,10 +335,12 @@ def __len__(self): return len(self.target) @overload - def index_by_timestamp(self, index: pd.Timestamp) -> int: ... + def index_by_timestamp(self, index: pd.Timestamp) -> int: + ... @overload - def index_by_timestamp(self, index: List[pd.Timestamp]) -> List[int]: ... + def index_by_timestamp(self, index: List[pd.Timestamp]) -> List[int]: + ... def index_by_timestamp(self, index): return pd.Series( @@ -346,19 +348,24 @@ def index_by_timestamp(self, index): ).loc[index] @overload - def __getitem__(self, index: int) -> TimeSeriesInstant: ... + def __getitem__(self, index: int) -> TimeSeriesInstant: + ... @overload - def __getitem__(self, index: pd.Timestamp) -> TimeSeriesInstant: ... + def __getitem__(self, index: pd.Timestamp) -> TimeSeriesInstant: + ... @overload - def __getitem__(self, index: slice) -> TimeSeries: ... + def __getitem__(self, index: slice) -> TimeSeries: + ... @overload - def __getitem__(self, index: List[int]) -> TimeSeries: ... + def __getitem__(self, index: List[int]) -> TimeSeries: + ... @overload - def __getitem__(self, index: List[pd.Timestamp]) -> TimeSeries: ... + def __getitem__(self, index: List[pd.Timestamp]) -> TimeSeries: + ... def __getitem__(self, index): if isinstance(index, pd.Timestamp) or ( diff --git a/src/gluonts/nursery/few_shot_prediction/src/meta/data/batch.py b/src/gluonts/nursery/few_shot_prediction/src/meta/data/batch.py index 6512d7f5df..23b679b16a 100644 --- a/src/gluonts/nursery/few_shot_prediction/src/meta/data/batch.py +++ b/src/gluonts/nursery/few_shot_prediction/src/meta/data/batch.py @@ -34,9 +34,9 @@ class SeriesBatch: ) # shape [batch, num_sequences, max_sequence_length] lengths: torch.Tensor # shape [batch] split_sections: torch.Tensor # shape [batch] - scales: Optional[torch.Tensor] = ( - None # shape[batch, 2] contains mean and std the ts has been scaled with - ) + scales: Optional[ + torch.Tensor + ] = None # shape[batch, 2] contains mean and std the ts has been scaled with @classmethod def from_lists( diff --git a/src/gluonts/nursery/few_shot_prediction/src/meta/datasets/artificial.py b/src/gluonts/nursery/few_shot_prediction/src/meta/datasets/artificial.py index 5a2dab200b..19e1475c66 100644 --- a/src/gluonts/nursery/few_shot_prediction/src/meta/datasets/artificial.py +++ b/src/gluonts/nursery/few_shot_prediction/src/meta/datasets/artificial.py @@ -593,9 +593,9 @@ def generate_artificial_tuplets( np.arange(0, context_length - signal_length) ) si = np.random.choice(support_set_size) - support_set[si][marker_start : marker_start + signal_length] = ( - query[-signal_length:] - ) + support_set[si][ + marker_start : marker_start + signal_length + ] = query[-signal_length:] else: signal = np.concatenate( (np.ones((4,)), query[-prediction_length:]) @@ -647,9 +647,9 @@ def generate_artificial_tuplets( np.arange(0, context_length - signal_length) ) si = np.random.choice(support_set_size) - support_set[si][marker_start : marker_start + signal_length] = ( - query[-signal_length:] - ) + support_set[si][ + marker_start : marker_start + signal_length + ] = query[-signal_length:] # else: # signal = query[-prediction_length:] # signal_length = prediction_length diff --git a/src/gluonts/nursery/robust-mts-attack/pts/model/deepvar/deepvar_estimator.py b/src/gluonts/nursery/robust-mts-attack/pts/model/deepvar/deepvar_estimator.py index 4bea3ddf23..fec7d5a2ee 100644 --- a/src/gluonts/nursery/robust-mts-attack/pts/model/deepvar/deepvar_estimator.py +++ b/src/gluonts/nursery/robust-mts-attack/pts/model/deepvar/deepvar_estimator.py @@ -139,9 +139,9 @@ def __init__( self.scaling = scaling if self.use_marginal_transformation: - self.output_transform: Optional[Callable] = ( - cdf_to_gaussian_forward_transform - ) + self.output_transform: Optional[ + Callable + ] = cdf_to_gaussian_forward_transform else: self.output_transform = None diff --git a/src/gluonts/nursery/robust-mts-attack/pts/model/tft/tft_estimator.py b/src/gluonts/nursery/robust-mts-attack/pts/model/tft/tft_estimator.py index fdb47f7d87..af2805c0b3 100644 --- a/src/gluonts/nursery/robust-mts-attack/pts/model/tft/tft_estimator.py +++ b/src/gluonts/nursery/robust-mts-attack/pts/model/tft/tft_estimator.py @@ -106,13 +106,13 @@ def __init__( self.past_dynamic_feature_dims = {} for name in self.past_dynamic_features: if name in self.dynamic_cardinalities: - self.past_dynamic_cardinalities[name] = ( - self.dynamic_cardinalities.pop(name) - ) + self.past_dynamic_cardinalities[ + name + ] = self.dynamic_cardinalities.pop(name) elif name in self.dynamic_feature_dims: - self.past_dynamic_feature_dims[name] = ( - self.dynamic_feature_dims.pop(name) - ) + self.past_dynamic_feature_dims[ + name + ] = self.dynamic_feature_dims.pop(name) else: raise ValueError( f"Feature name {name} is not provided in feature dicts" diff --git a/src/gluonts/nursery/robust-mts-attack/pts/model/transformer_tempflow/transformer_tempflow_network.py b/src/gluonts/nursery/robust-mts-attack/pts/model/transformer_tempflow/transformer_tempflow_network.py index 55d60d1537..56b7b4d457 100644 --- a/src/gluonts/nursery/robust-mts-attack/pts/model/transformer_tempflow/transformer_tempflow_network.py +++ b/src/gluonts/nursery/robust-mts-attack/pts/model/transformer_tempflow/transformer_tempflow_network.py @@ -165,11 +165,7 @@ def create_network_input( future_time_feat: Optional[torch.Tensor], future_target_cdf: Optional[torch.Tensor], target_dimension_indicator: torch.Tensor, - ) -> Tuple[ - torch.Tensor, - torch.Tensor, - torch.Tensor, - ]: + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,]: """ Unrolls the RNN encoder over past and, if present, future data. Returns outputs and state of the encoder, plus the scale of diff --git a/src/gluonts/nursery/robust-mts-attack/pts/modules/gaussian_diffusion.py b/src/gluonts/nursery/robust-mts-attack/pts/modules/gaussian_diffusion.py index 73195e7957..2819be62a4 100644 --- a/src/gluonts/nursery/robust-mts-attack/pts/modules/gaussian_diffusion.py +++ b/src/gluonts/nursery/robust-mts-attack/pts/modules/gaussian_diffusion.py @@ -80,7 +80,9 @@ def __init__( if beta_schedule == "linear": betas = np.linspace(1e-4, beta_end, diff_steps) elif beta_schedule == "quad": - betas = np.linspace(1e-4**0.5, beta_end**0.5, diff_steps) ** 2 + betas = ( + np.linspace(1e-4**0.5, beta_end**0.5, diff_steps) ** 2 + ) elif beta_schedule == "const": betas = beta_end * np.ones(diff_steps) elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1 diff --git a/src/gluonts/nursery/robust-mts-attack/utils.py b/src/gluonts/nursery/robust-mts-attack/utils.py index 75a36ec169..3bafea8d4d 100644 --- a/src/gluonts/nursery/robust-mts-attack/utils.py +++ b/src/gluonts/nursery/robust-mts-attack/utils.py @@ -263,13 +263,13 @@ def calc_loss( if ( true_future_target[:, attack_idx][..., target_items] != 0 ).prod() == 0: - mape[attack_type][testset_idx : testset_idx + batch_size] = ( - np.abs( - forecasts[attack_type][i][:, :, attack_idx][ - ..., target_items - ].mean(1) - - true_future_target[:, attack_idx][..., target_items] - ) + mape[attack_type][ + testset_idx : testset_idx + batch_size + ] = np.abs( + forecasts[attack_type][i][:, :, attack_idx][ + ..., target_items + ].mean(1) + - true_future_target[:, attack_idx][..., target_items] ) mse[attack_type][testset_idx : testset_idx + batch_size] = ( forecasts[attack_type][i][:, :, attack_idx][ @@ -290,14 +290,14 @@ def calc_loss( j, testset_idx : testset_idx + batch_size ] = quantile_loss(true, pred, quantile) else: - mape[attack_type][testset_idx : testset_idx + batch_size] = ( - np.abs( - forecasts[attack_type][i][:, :, attack_idx][ - ..., target_items - ].mean(1) - / true_future_target[:, attack_idx][..., target_items] - - 1 - ) + mape[attack_type][ + testset_idx : testset_idx + batch_size + ] = np.abs( + forecasts[attack_type][i][:, :, attack_idx][ + ..., target_items + ].mean(1) + / true_future_target[:, attack_idx][..., target_items] + - 1 ) mse[attack_type][testset_idx : testset_idx + batch_size] = ( mape[attack_type][testset_idx : testset_idx + batch_size] diff --git a/src/gluonts/nursery/temporal_hierarchical_forecasting/model/cop_deepar/_estimator.py b/src/gluonts/nursery/temporal_hierarchical_forecasting/model/cop_deepar/_estimator.py index 8c8b297fa6..03bb055f00 100644 --- a/src/gluonts/nursery/temporal_hierarchical_forecasting/model/cop_deepar/_estimator.py +++ b/src/gluonts/nursery/temporal_hierarchical_forecasting/model/cop_deepar/_estimator.py @@ -238,11 +238,10 @@ def __init__( # adapt window_length if RollingMeanValueImputation is used if isinstance(imputation_method, RollingMeanValueImputation): - base_estimator_hps_agg["imputation_method"] = ( - RollingMeanValueImputation( - window_size=imputation_method.window_size - // agg_multiple - ) + base_estimator_hps_agg[ + "imputation_method" + ] = RollingMeanValueImputation( + window_size=imputation_method.window_size // agg_multiple ) # Hack to enforce correct serialization of lags_seq and history length diff --git a/src/gluonts/nursery/tsbench/src/tsbench/recommender/_factory.py b/src/gluonts/nursery/tsbench/src/tsbench/recommender/_factory.py index a84ba975ff..96cbb43b37 100644 --- a/src/gluonts/nursery/tsbench/src/tsbench/recommender/_factory.py +++ b/src/gluonts/nursery/tsbench/src/tsbench/recommender/_factory.py @@ -16,9 +16,9 @@ from ._base import Recommender RECOMMENDER_REGISTRY: Dict[str, Type[Recommender[ModelConfig]]] = {} -ENSEMBLE_RECOMMENDER_REGISTRY: Dict[str, Type[Recommender[EnsembleConfig]]] = ( - {} -) +ENSEMBLE_RECOMMENDER_REGISTRY: Dict[ + str, Type[Recommender[EnsembleConfig]] +] = {} R = TypeVar("R", bound=Type[Recommender[ModelConfig]]) E = TypeVar("E", bound=Type[Recommender[EnsembleConfig]]) diff --git a/src/gluonts/torch/model/deepar/module.py b/src/gluonts/torch/model/deepar/module.py index ebe4ef2479..406468079c 100644 --- a/src/gluonts/torch/model/deepar/module.py +++ b/src/gluonts/torch/model/deepar/module.py @@ -221,11 +221,7 @@ def prepare_rnn_input( past_observed_values: torch.Tensor, future_time_feat: torch.Tensor, future_target: Optional[torch.Tensor] = None, - ) -> Tuple[ - torch.Tensor, - torch.Tensor, - torch.Tensor, - ]: + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,]: context = past_target[..., -self.context_length :] observed_context = past_observed_values[..., -self.context_length :] diff --git a/src/gluonts/torch/model/tft/module.py b/src/gluonts/torch/model/tft/module.py index f113160cb7..f0dd7c00a1 100644 --- a/src/gluonts/torch/model/tft/module.py +++ b/src/gluonts/torch/model/tft/module.py @@ -102,66 +102,65 @@ def __init__( self.target_proj = nn.Linear(in_features=1, out_features=self.d_var) # Past-only dynamic features if self.d_past_feat_dynamic_real: - self.past_feat_dynamic_proj: Optional[FeatureProjector] = ( - FeatureProjector( - feature_dims=self.d_past_feat_dynamic_real, - embedding_dims=[self.d_var] - * len(self.d_past_feat_dynamic_real), - ) + self.past_feat_dynamic_proj: Optional[ + FeatureProjector + ] = FeatureProjector( + feature_dims=self.d_past_feat_dynamic_real, + embedding_dims=[self.d_var] + * len(self.d_past_feat_dynamic_real), ) else: self.past_feat_dynamic_proj = None if self.c_past_feat_dynamic_cat: - self.past_feat_dynamic_embed: Optional[FeatureEmbedder] = ( - FeatureEmbedder( - cardinalities=self.c_past_feat_dynamic_cat, - embedding_dims=[self.d_var] - * len(self.c_past_feat_dynamic_cat), - ) + self.past_feat_dynamic_embed: Optional[ + FeatureEmbedder + ] = FeatureEmbedder( + cardinalities=self.c_past_feat_dynamic_cat, + embedding_dims=[self.d_var] + * len(self.c_past_feat_dynamic_cat), ) else: self.past_feat_dynamic_embed = None # Known dynamic features if self.d_feat_dynamic_real: - self.feat_dynamic_proj: Optional[FeatureProjector] = ( - FeatureProjector( - feature_dims=self.d_feat_dynamic_real, - embedding_dims=[self.d_var] - * len(self.d_feat_dynamic_real), - ) + self.feat_dynamic_proj: Optional[ + FeatureProjector + ] = FeatureProjector( + feature_dims=self.d_feat_dynamic_real, + embedding_dims=[self.d_var] * len(self.d_feat_dynamic_real), ) else: self.feat_dynamic_proj = None if self.c_feat_dynamic_cat: - self.feat_dynamic_embed: Optional[FeatureEmbedder] = ( - FeatureEmbedder( - cardinalities=self.c_feat_dynamic_cat, - embedding_dims=[self.d_var] * len(self.c_feat_dynamic_cat), - ) + self.feat_dynamic_embed: Optional[ + FeatureEmbedder + ] = FeatureEmbedder( + cardinalities=self.c_feat_dynamic_cat, + embedding_dims=[self.d_var] * len(self.c_feat_dynamic_cat), ) else: self.feat_dynamic_embed = None # Static features if self.d_feat_static_real: - self.feat_static_proj: Optional[FeatureProjector] = ( - FeatureProjector( - feature_dims=self.d_feat_static_real, - embedding_dims=[self.d_var] * len(self.d_feat_static_real), - ) + self.feat_static_proj: Optional[ + FeatureProjector + ] = FeatureProjector( + feature_dims=self.d_feat_static_real, + embedding_dims=[self.d_var] * len(self.d_feat_static_real), ) else: self.feat_static_proj = None if self.c_feat_static_cat: - self.feat_static_embed: Optional[FeatureEmbedder] = ( - FeatureEmbedder( - cardinalities=self.c_feat_static_cat, - embedding_dims=[self.d_var] * len(self.c_feat_static_cat), - ) + self.feat_static_embed: Optional[ + FeatureEmbedder + ] = FeatureEmbedder( + cardinalities=self.c_feat_static_cat, + embedding_dims=[self.d_var] * len(self.c_feat_static_cat), ) else: self.feat_static_embed = None diff --git a/src/gluonts/zebras/_period.py b/src/gluonts/zebras/_period.py index 4cda3178a9..66f32e311b 100644 --- a/src/gluonts/zebras/_period.py +++ b/src/gluonts/zebras/_period.py @@ -330,10 +330,12 @@ def __len__(self): return len(self.data) @overload - def __getitem__(self, idx: int) -> Period: ... + def __getitem__(self, idx: int) -> Period: + ... @overload - def __getitem__(self, idx: slice) -> Periods: ... + def __getitem__(self, idx: slice) -> Periods: + ... def __getitem__(self, idx): if _is_number(idx): From d049a3c0c3b12b8374e3be531d5dee174cc0f0c1 Mon Sep 17 00:00:00 2001 From: Jirka Date: Tue, 20 Feb 2024 17:49:47 +0100 Subject: [PATCH 02/10] update --- .github/workflows/style_type_checks.yml | 6 +- src/gluonts/dataset/arrow/file.py | 9 +-- src/gluonts/ext/rotbaum/_model.py | 10 +-- src/gluonts/itertools.py | 6 +- src/gluonts/mx/distribution/lowrank_gp.py | 4 +- .../lowrank_multivariate_gaussian.py | 4 +- src/gluonts/mx/model/deepvar/_estimator.py | 6 +- src/gluonts/mx/model/estimator.py | 6 +- src/gluonts/mx/model/seq2seq/_transform.py | 30 ++++----- src/gluonts/mx/model/tft/_estimator.py | 12 ++-- .../nursery/SCott/dataset_tools/synthetic.py | 12 ++-- src/gluonts/nursery/SCott/preprocess_data.py | 12 ++-- .../nursery/daf/tslib/dataset/timeseries.py | 21 ++---- .../src/meta/data/batch.py | 6 +- .../src/meta/datasets/artificial.py | 12 ++-- .../pts/model/deepvar/deepvar_estimator.py | 6 +- .../pts/model/tft/tft_estimator.py | 12 ++-- .../transformer_tempflow_network.py | 6 +- .../pts/modules/gaussian_diffusion.py | 4 +- .../nursery/robust-mts-attack/utils.py | 30 ++++----- .../model/cop_deepar/_estimator.py | 9 +-- .../src/tsbench/recommender/_factory.py | 6 +- src/gluonts/torch/model/deepar/module.py | 6 +- src/gluonts/torch/model/tft/module.py | 65 ++++++++++--------- src/gluonts/zebras/_period.py | 6 +- 25 files changed, 147 insertions(+), 159 deletions(-) diff --git a/.github/workflows/style_type_checks.yml b/.github/workflows/style_type_checks.yml index 8ad3c0fa68..2e89970f5f 100644 --- a/.github/workflows/style_type_checks.yml +++ b/.github/workflows/style_type_checks.yml @@ -18,10 +18,8 @@ jobs: - name: Install dependencies run: | pip install . - pip install click black mypy - pip install types-python-dateutil - pip install types-waitress - pip install types-PyYAML + pip install click "black==24.01" "mypy==1.8.0" \ + types-python-dateutil types-waitress types-PyYAML - name: Style and type checks run: | just black diff --git a/src/gluonts/dataset/arrow/file.py b/src/gluonts/dataset/arrow/file.py index 40f1a06638..7bdb6cf898 100644 --- a/src/gluonts/dataset/arrow/file.py +++ b/src/gluonts/dataset/arrow/file.py @@ -51,16 +51,13 @@ def infer( return ArrowStreamFile(path) @abc.abstractmethod - def metadata(self) -> Dict[str, str]: - ... + def metadata(self) -> Dict[str, str]: ... @abc.abstractmethod - def __iter__(self): - ... + def __iter__(self): ... @abc.abstractmethod - def __len__(self): - ... + def __len__(self): ... @dataclass diff --git a/src/gluonts/ext/rotbaum/_model.py b/src/gluonts/ext/rotbaum/_model.py index 29bd1e3cf7..3c0270a49a 100644 --- a/src/gluonts/ext/rotbaum/_model.py +++ b/src/gluonts/ext/rotbaum/_model.py @@ -340,11 +340,11 @@ def _get_and_cache_quantile_computation( The quantile of the associated true value bin. """ if feature_vector_in_train not in self.quantile_dicts[quantile]: - self.quantile_dicts[quantile][ - feature_vector_in_train - ] = np.percentile( - self.id_to_bins[self.preds_to_id[feature_vector_in_train]], - quantile * 100, + self.quantile_dicts[quantile][feature_vector_in_train] = ( + np.percentile( + self.id_to_bins[self.preds_to_id[feature_vector_in_train]], + quantile * 100, + ) ) return self.quantile_dicts[quantile][feature_vector_in_train] diff --git a/src/gluonts/itertools.py b/src/gluonts/itertools.py index e90281a1d8..198559d6f6 100644 --- a/src/gluonts/itertools.py +++ b/src/gluonts/itertools.py @@ -42,11 +42,9 @@ @runtime_checkable class SizedIterable(Protocol): - def __len__(self): - ... + def __len__(self): ... - def __iter__(self): - ... + def __iter__(self): ... T = TypeVar("T") diff --git a/src/gluonts/mx/distribution/lowrank_gp.py b/src/gluonts/mx/distribution/lowrank_gp.py index 443cc40acb..2e6de964a8 100644 --- a/src/gluonts/mx/distribution/lowrank_gp.py +++ b/src/gluonts/mx/distribution/lowrank_gp.py @@ -101,9 +101,7 @@ def hybrid_forward(self, F, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]: D_vector = self.proj[1](x_plus_w).squeeze(axis=-1) d_bias = ( - 0.0 - if self.sigma_init == 0.0 - else inv_softplus(self.sigma_init**2) + 0.0 if self.sigma_init == 0.0 else inv_softplus(self.sigma_init**2) ) D_positive = ( diff --git a/src/gluonts/mx/distribution/lowrank_multivariate_gaussian.py b/src/gluonts/mx/distribution/lowrank_multivariate_gaussian.py index b8131178a5..f446799367 100644 --- a/src/gluonts/mx/distribution/lowrank_multivariate_gaussian.py +++ b/src/gluonts/mx/distribution/lowrank_multivariate_gaussian.py @@ -408,9 +408,7 @@ def domain_map(self, F, mu_vector, D_vector, W_vector=None): """ d_bias = ( - inv_softplus(self.sigma_init**2) - if self.sigma_init > 0.0 - else 0.0 + inv_softplus(self.sigma_init**2) if self.sigma_init > 0.0 else 0.0 ) # sigma_minimum helps avoiding cholesky problems, we could also jitter diff --git a/src/gluonts/mx/model/deepvar/_estimator.py b/src/gluonts/mx/model/deepvar/_estimator.py index ceea7b588c..9a95ab0c9a 100644 --- a/src/gluonts/mx/model/deepvar/_estimator.py +++ b/src/gluonts/mx/model/deepvar/_estimator.py @@ -306,9 +306,9 @@ def __init__( self.scaling = scaling if self.use_marginal_transformation: - self.output_transform: Optional[ - Callable - ] = cdf_to_gaussian_forward_transform + self.output_transform: Optional[Callable] = ( + cdf_to_gaussian_forward_transform + ) else: self.output_transform = None diff --git a/src/gluonts/mx/model/estimator.py b/src/gluonts/mx/model/estimator.py index 1cf992a796..122e34fad3 100644 --- a/src/gluonts/mx/model/estimator.py +++ b/src/gluonts/mx/model/estimator.py @@ -177,9 +177,9 @@ def train_model( transformation = self.create_transformation() with env._let(max_idle_transforms=max(len(training_data), 100)): - transformed_training_data: Union[ - TransformedDataset, Cached - ] = transformation.apply(training_data) + transformed_training_data: Union[TransformedDataset, Cached] = ( + transformation.apply(training_data) + ) if cache_data: transformed_training_data = Cached(transformed_training_data) diff --git a/src/gluonts/mx/model/seq2seq/_transform.py b/src/gluonts/mx/model/seq2seq/_transform.py index 4134e8b01b..7c443e82d7 100644 --- a/src/gluonts/mx/model/seq2seq/_transform.py +++ b/src/gluonts/mx/model/seq2seq/_transform.py @@ -180,21 +180,21 @@ def flatmap_transform( # (Fortran) ordering with strides = # (dtype, dtype*n_rows) stride = decoder_fields.strides - out[self._future(ts_field)][ - pad_length_dec: - ] = as_strided( - decoder_fields, - shape=( - self.num_forking - pad_length_dec, - self.dec_len, - ts_len, - ), - # strides for 2D array expanded to 3D array of - # shape (dim1, dim2, dim3) =(1, n_rows, n_cols). - # For transposed data, strides = (dtype, dtype * - # dim1, dtype*dim1*dim2) = (dtype, dtype, - # dtype*n_rows). - strides=stride[0:1] + stride, + out[self._future(ts_field)][pad_length_dec:] = ( + as_strided( + decoder_fields, + shape=( + self.num_forking - pad_length_dec, + self.dec_len, + ts_len, + ), + # strides for 2D array expanded to 3D array of + # shape (dim1, dim2, dim3) =(1, n_rows, n_cols). + # For transposed data, strides = (dtype, dtype * + # dim1, dtype*dim1*dim2) = (dtype, dtype, + # dtype*n_rows). + strides=stride[0:1] + stride, + ) ) # edge case for prediction_length = 1 diff --git a/src/gluonts/mx/model/tft/_estimator.py b/src/gluonts/mx/model/tft/_estimator.py index db19544603..a4d8335719 100644 --- a/src/gluonts/mx/model/tft/_estimator.py +++ b/src/gluonts/mx/model/tft/_estimator.py @@ -172,13 +172,13 @@ def __init__( self.past_dynamic_feature_dims = {} for name in self.past_dynamic_features: if name in self.dynamic_cardinalities: - self.past_dynamic_cardinalities[ - name - ] = self.dynamic_cardinalities.pop(name) + self.past_dynamic_cardinalities[name] = ( + self.dynamic_cardinalities.pop(name) + ) elif name in self.dynamic_feature_dims: - self.past_dynamic_feature_dims[ - name - ] = self.dynamic_feature_dims.pop(name) + self.past_dynamic_feature_dims[name] = ( + self.dynamic_feature_dims.pop(name) + ) else: raise ValueError( f"Feature name {name} is not provided in feature dicts" diff --git a/src/gluonts/nursery/SCott/dataset_tools/synthetic.py b/src/gluonts/nursery/SCott/dataset_tools/synthetic.py index a16034b83c..8e2dc383b6 100644 --- a/src/gluonts/nursery/SCott/dataset_tools/synthetic.py +++ b/src/gluonts/nursery/SCott/dataset_tools/synthetic.py @@ -57,12 +57,12 @@ def get_mixed_pattern(unit_length=16, num_duplicates=1000): for j in range(num_duplicates): context = torch.arange(context_length, dtype=torch.float) for i in range(1, pattern_number): - context[ - unit_length * (i - 1) : unit_length * i - ] = _get_mixed_pattern( - context[unit_length * (i - 1) : unit_length * i] - - unit_length * (i - 1), - pattern[(gid + i) % pattern_number], + context[unit_length * (i - 1) : unit_length * i] = ( + _get_mixed_pattern( + context[unit_length * (i - 1) : unit_length * i] + - unit_length * (i - 1), + pattern[(gid + i) % pattern_number], + ) ) ts_sample = torch.cat( [ diff --git a/src/gluonts/nursery/SCott/preprocess_data.py b/src/gluonts/nursery/SCott/preprocess_data.py index 4f247eb229..40b181820f 100644 --- a/src/gluonts/nursery/SCott/preprocess_data.py +++ b/src/gluonts/nursery/SCott/preprocess_data.py @@ -57,12 +57,12 @@ def get_mixed_pattern(unit_length=16, num_duplicates=1000): for j in range(num_duplicates): context = torch.arange(context_length, dtype=torch.float) for i in range(1, pattern_number): - context[ - unit_length * (i - 1) : unit_length * i - ] = _get_mixed_pattern( - context[unit_length * (i - 1) : unit_length * i] - - unit_length * (i - 1), - pattern[(gid + i) % pattern_number], + context[unit_length * (i - 1) : unit_length * i] = ( + _get_mixed_pattern( + context[unit_length * (i - 1) : unit_length * i] + - unit_length * (i - 1), + pattern[(gid + i) % pattern_number], + ) ) ts_sample = torch.cat( [ diff --git a/src/gluonts/nursery/daf/tslib/dataset/timeseries.py b/src/gluonts/nursery/daf/tslib/dataset/timeseries.py index 0036ef6acf..4bd63964b8 100644 --- a/src/gluonts/nursery/daf/tslib/dataset/timeseries.py +++ b/src/gluonts/nursery/daf/tslib/dataset/timeseries.py @@ -335,12 +335,10 @@ def __len__(self): return len(self.target) @overload - def index_by_timestamp(self, index: pd.Timestamp) -> int: - ... + def index_by_timestamp(self, index: pd.Timestamp) -> int: ... @overload - def index_by_timestamp(self, index: List[pd.Timestamp]) -> List[int]: - ... + def index_by_timestamp(self, index: List[pd.Timestamp]) -> List[int]: ... def index_by_timestamp(self, index): return pd.Series( @@ -348,24 +346,19 @@ def index_by_timestamp(self, index): ).loc[index] @overload - def __getitem__(self, index: int) -> TimeSeriesInstant: - ... + def __getitem__(self, index: int) -> TimeSeriesInstant: ... @overload - def __getitem__(self, index: pd.Timestamp) -> TimeSeriesInstant: - ... + def __getitem__(self, index: pd.Timestamp) -> TimeSeriesInstant: ... @overload - def __getitem__(self, index: slice) -> TimeSeries: - ... + def __getitem__(self, index: slice) -> TimeSeries: ... @overload - def __getitem__(self, index: List[int]) -> TimeSeries: - ... + def __getitem__(self, index: List[int]) -> TimeSeries: ... @overload - def __getitem__(self, index: List[pd.Timestamp]) -> TimeSeries: - ... + def __getitem__(self, index: List[pd.Timestamp]) -> TimeSeries: ... def __getitem__(self, index): if isinstance(index, pd.Timestamp) or ( diff --git a/src/gluonts/nursery/few_shot_prediction/src/meta/data/batch.py b/src/gluonts/nursery/few_shot_prediction/src/meta/data/batch.py index 23b679b16a..6512d7f5df 100644 --- a/src/gluonts/nursery/few_shot_prediction/src/meta/data/batch.py +++ b/src/gluonts/nursery/few_shot_prediction/src/meta/data/batch.py @@ -34,9 +34,9 @@ class SeriesBatch: ) # shape [batch, num_sequences, max_sequence_length] lengths: torch.Tensor # shape [batch] split_sections: torch.Tensor # shape [batch] - scales: Optional[ - torch.Tensor - ] = None # shape[batch, 2] contains mean and std the ts has been scaled with + scales: Optional[torch.Tensor] = ( + None # shape[batch, 2] contains mean and std the ts has been scaled with + ) @classmethod def from_lists( diff --git a/src/gluonts/nursery/few_shot_prediction/src/meta/datasets/artificial.py b/src/gluonts/nursery/few_shot_prediction/src/meta/datasets/artificial.py index 19e1475c66..5a2dab200b 100644 --- a/src/gluonts/nursery/few_shot_prediction/src/meta/datasets/artificial.py +++ b/src/gluonts/nursery/few_shot_prediction/src/meta/datasets/artificial.py @@ -593,9 +593,9 @@ def generate_artificial_tuplets( np.arange(0, context_length - signal_length) ) si = np.random.choice(support_set_size) - support_set[si][ - marker_start : marker_start + signal_length - ] = query[-signal_length:] + support_set[si][marker_start : marker_start + signal_length] = ( + query[-signal_length:] + ) else: signal = np.concatenate( (np.ones((4,)), query[-prediction_length:]) @@ -647,9 +647,9 @@ def generate_artificial_tuplets( np.arange(0, context_length - signal_length) ) si = np.random.choice(support_set_size) - support_set[si][ - marker_start : marker_start + signal_length - ] = query[-signal_length:] + support_set[si][marker_start : marker_start + signal_length] = ( + query[-signal_length:] + ) # else: # signal = query[-prediction_length:] # signal_length = prediction_length diff --git a/src/gluonts/nursery/robust-mts-attack/pts/model/deepvar/deepvar_estimator.py b/src/gluonts/nursery/robust-mts-attack/pts/model/deepvar/deepvar_estimator.py index fec7d5a2ee..4bea3ddf23 100644 --- a/src/gluonts/nursery/robust-mts-attack/pts/model/deepvar/deepvar_estimator.py +++ b/src/gluonts/nursery/robust-mts-attack/pts/model/deepvar/deepvar_estimator.py @@ -139,9 +139,9 @@ def __init__( self.scaling = scaling if self.use_marginal_transformation: - self.output_transform: Optional[ - Callable - ] = cdf_to_gaussian_forward_transform + self.output_transform: Optional[Callable] = ( + cdf_to_gaussian_forward_transform + ) else: self.output_transform = None diff --git a/src/gluonts/nursery/robust-mts-attack/pts/model/tft/tft_estimator.py b/src/gluonts/nursery/robust-mts-attack/pts/model/tft/tft_estimator.py index af2805c0b3..fdb47f7d87 100644 --- a/src/gluonts/nursery/robust-mts-attack/pts/model/tft/tft_estimator.py +++ b/src/gluonts/nursery/robust-mts-attack/pts/model/tft/tft_estimator.py @@ -106,13 +106,13 @@ def __init__( self.past_dynamic_feature_dims = {} for name in self.past_dynamic_features: if name in self.dynamic_cardinalities: - self.past_dynamic_cardinalities[ - name - ] = self.dynamic_cardinalities.pop(name) + self.past_dynamic_cardinalities[name] = ( + self.dynamic_cardinalities.pop(name) + ) elif name in self.dynamic_feature_dims: - self.past_dynamic_feature_dims[ - name - ] = self.dynamic_feature_dims.pop(name) + self.past_dynamic_feature_dims[name] = ( + self.dynamic_feature_dims.pop(name) + ) else: raise ValueError( f"Feature name {name} is not provided in feature dicts" diff --git a/src/gluonts/nursery/robust-mts-attack/pts/model/transformer_tempflow/transformer_tempflow_network.py b/src/gluonts/nursery/robust-mts-attack/pts/model/transformer_tempflow/transformer_tempflow_network.py index 56b7b4d457..55d60d1537 100644 --- a/src/gluonts/nursery/robust-mts-attack/pts/model/transformer_tempflow/transformer_tempflow_network.py +++ b/src/gluonts/nursery/robust-mts-attack/pts/model/transformer_tempflow/transformer_tempflow_network.py @@ -165,7 +165,11 @@ def create_network_input( future_time_feat: Optional[torch.Tensor], future_target_cdf: Optional[torch.Tensor], target_dimension_indicator: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,]: + ) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + ]: """ Unrolls the RNN encoder over past and, if present, future data. Returns outputs and state of the encoder, plus the scale of diff --git a/src/gluonts/nursery/robust-mts-attack/pts/modules/gaussian_diffusion.py b/src/gluonts/nursery/robust-mts-attack/pts/modules/gaussian_diffusion.py index 2819be62a4..73195e7957 100644 --- a/src/gluonts/nursery/robust-mts-attack/pts/modules/gaussian_diffusion.py +++ b/src/gluonts/nursery/robust-mts-attack/pts/modules/gaussian_diffusion.py @@ -80,9 +80,7 @@ def __init__( if beta_schedule == "linear": betas = np.linspace(1e-4, beta_end, diff_steps) elif beta_schedule == "quad": - betas = ( - np.linspace(1e-4**0.5, beta_end**0.5, diff_steps) ** 2 - ) + betas = np.linspace(1e-4**0.5, beta_end**0.5, diff_steps) ** 2 elif beta_schedule == "const": betas = beta_end * np.ones(diff_steps) elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1 diff --git a/src/gluonts/nursery/robust-mts-attack/utils.py b/src/gluonts/nursery/robust-mts-attack/utils.py index 3bafea8d4d..75a36ec169 100644 --- a/src/gluonts/nursery/robust-mts-attack/utils.py +++ b/src/gluonts/nursery/robust-mts-attack/utils.py @@ -263,13 +263,13 @@ def calc_loss( if ( true_future_target[:, attack_idx][..., target_items] != 0 ).prod() == 0: - mape[attack_type][ - testset_idx : testset_idx + batch_size - ] = np.abs( - forecasts[attack_type][i][:, :, attack_idx][ - ..., target_items - ].mean(1) - - true_future_target[:, attack_idx][..., target_items] + mape[attack_type][testset_idx : testset_idx + batch_size] = ( + np.abs( + forecasts[attack_type][i][:, :, attack_idx][ + ..., target_items + ].mean(1) + - true_future_target[:, attack_idx][..., target_items] + ) ) mse[attack_type][testset_idx : testset_idx + batch_size] = ( forecasts[attack_type][i][:, :, attack_idx][ @@ -290,14 +290,14 @@ def calc_loss( j, testset_idx : testset_idx + batch_size ] = quantile_loss(true, pred, quantile) else: - mape[attack_type][ - testset_idx : testset_idx + batch_size - ] = np.abs( - forecasts[attack_type][i][:, :, attack_idx][ - ..., target_items - ].mean(1) - / true_future_target[:, attack_idx][..., target_items] - - 1 + mape[attack_type][testset_idx : testset_idx + batch_size] = ( + np.abs( + forecasts[attack_type][i][:, :, attack_idx][ + ..., target_items + ].mean(1) + / true_future_target[:, attack_idx][..., target_items] + - 1 + ) ) mse[attack_type][testset_idx : testset_idx + batch_size] = ( mape[attack_type][testset_idx : testset_idx + batch_size] diff --git a/src/gluonts/nursery/temporal_hierarchical_forecasting/model/cop_deepar/_estimator.py b/src/gluonts/nursery/temporal_hierarchical_forecasting/model/cop_deepar/_estimator.py index 03bb055f00..8c8b297fa6 100644 --- a/src/gluonts/nursery/temporal_hierarchical_forecasting/model/cop_deepar/_estimator.py +++ b/src/gluonts/nursery/temporal_hierarchical_forecasting/model/cop_deepar/_estimator.py @@ -238,10 +238,11 @@ def __init__( # adapt window_length if RollingMeanValueImputation is used if isinstance(imputation_method, RollingMeanValueImputation): - base_estimator_hps_agg[ - "imputation_method" - ] = RollingMeanValueImputation( - window_size=imputation_method.window_size // agg_multiple + base_estimator_hps_agg["imputation_method"] = ( + RollingMeanValueImputation( + window_size=imputation_method.window_size + // agg_multiple + ) ) # Hack to enforce correct serialization of lags_seq and history length diff --git a/src/gluonts/nursery/tsbench/src/tsbench/recommender/_factory.py b/src/gluonts/nursery/tsbench/src/tsbench/recommender/_factory.py index 96cbb43b37..a84ba975ff 100644 --- a/src/gluonts/nursery/tsbench/src/tsbench/recommender/_factory.py +++ b/src/gluonts/nursery/tsbench/src/tsbench/recommender/_factory.py @@ -16,9 +16,9 @@ from ._base import Recommender RECOMMENDER_REGISTRY: Dict[str, Type[Recommender[ModelConfig]]] = {} -ENSEMBLE_RECOMMENDER_REGISTRY: Dict[ - str, Type[Recommender[EnsembleConfig]] -] = {} +ENSEMBLE_RECOMMENDER_REGISTRY: Dict[str, Type[Recommender[EnsembleConfig]]] = ( + {} +) R = TypeVar("R", bound=Type[Recommender[ModelConfig]]) E = TypeVar("E", bound=Type[Recommender[EnsembleConfig]]) diff --git a/src/gluonts/torch/model/deepar/module.py b/src/gluonts/torch/model/deepar/module.py index 406468079c..ebe4ef2479 100644 --- a/src/gluonts/torch/model/deepar/module.py +++ b/src/gluonts/torch/model/deepar/module.py @@ -221,7 +221,11 @@ def prepare_rnn_input( past_observed_values: torch.Tensor, future_time_feat: torch.Tensor, future_target: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,]: + ) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + ]: context = past_target[..., -self.context_length :] observed_context = past_observed_values[..., -self.context_length :] diff --git a/src/gluonts/torch/model/tft/module.py b/src/gluonts/torch/model/tft/module.py index f0dd7c00a1..f113160cb7 100644 --- a/src/gluonts/torch/model/tft/module.py +++ b/src/gluonts/torch/model/tft/module.py @@ -102,65 +102,66 @@ def __init__( self.target_proj = nn.Linear(in_features=1, out_features=self.d_var) # Past-only dynamic features if self.d_past_feat_dynamic_real: - self.past_feat_dynamic_proj: Optional[ - FeatureProjector - ] = FeatureProjector( - feature_dims=self.d_past_feat_dynamic_real, - embedding_dims=[self.d_var] - * len(self.d_past_feat_dynamic_real), + self.past_feat_dynamic_proj: Optional[FeatureProjector] = ( + FeatureProjector( + feature_dims=self.d_past_feat_dynamic_real, + embedding_dims=[self.d_var] + * len(self.d_past_feat_dynamic_real), + ) ) else: self.past_feat_dynamic_proj = None if self.c_past_feat_dynamic_cat: - self.past_feat_dynamic_embed: Optional[ - FeatureEmbedder - ] = FeatureEmbedder( - cardinalities=self.c_past_feat_dynamic_cat, - embedding_dims=[self.d_var] - * len(self.c_past_feat_dynamic_cat), + self.past_feat_dynamic_embed: Optional[FeatureEmbedder] = ( + FeatureEmbedder( + cardinalities=self.c_past_feat_dynamic_cat, + embedding_dims=[self.d_var] + * len(self.c_past_feat_dynamic_cat), + ) ) else: self.past_feat_dynamic_embed = None # Known dynamic features if self.d_feat_dynamic_real: - self.feat_dynamic_proj: Optional[ - FeatureProjector - ] = FeatureProjector( - feature_dims=self.d_feat_dynamic_real, - embedding_dims=[self.d_var] * len(self.d_feat_dynamic_real), + self.feat_dynamic_proj: Optional[FeatureProjector] = ( + FeatureProjector( + feature_dims=self.d_feat_dynamic_real, + embedding_dims=[self.d_var] + * len(self.d_feat_dynamic_real), + ) ) else: self.feat_dynamic_proj = None if self.c_feat_dynamic_cat: - self.feat_dynamic_embed: Optional[ - FeatureEmbedder - ] = FeatureEmbedder( - cardinalities=self.c_feat_dynamic_cat, - embedding_dims=[self.d_var] * len(self.c_feat_dynamic_cat), + self.feat_dynamic_embed: Optional[FeatureEmbedder] = ( + FeatureEmbedder( + cardinalities=self.c_feat_dynamic_cat, + embedding_dims=[self.d_var] * len(self.c_feat_dynamic_cat), + ) ) else: self.feat_dynamic_embed = None # Static features if self.d_feat_static_real: - self.feat_static_proj: Optional[ - FeatureProjector - ] = FeatureProjector( - feature_dims=self.d_feat_static_real, - embedding_dims=[self.d_var] * len(self.d_feat_static_real), + self.feat_static_proj: Optional[FeatureProjector] = ( + FeatureProjector( + feature_dims=self.d_feat_static_real, + embedding_dims=[self.d_var] * len(self.d_feat_static_real), + ) ) else: self.feat_static_proj = None if self.c_feat_static_cat: - self.feat_static_embed: Optional[ - FeatureEmbedder - ] = FeatureEmbedder( - cardinalities=self.c_feat_static_cat, - embedding_dims=[self.d_var] * len(self.c_feat_static_cat), + self.feat_static_embed: Optional[FeatureEmbedder] = ( + FeatureEmbedder( + cardinalities=self.c_feat_static_cat, + embedding_dims=[self.d_var] * len(self.c_feat_static_cat), + ) ) else: self.feat_static_embed = None diff --git a/src/gluonts/zebras/_period.py b/src/gluonts/zebras/_period.py index 66f32e311b..4cda3178a9 100644 --- a/src/gluonts/zebras/_period.py +++ b/src/gluonts/zebras/_period.py @@ -330,12 +330,10 @@ def __len__(self): return len(self.data) @overload - def __getitem__(self, idx: int) -> Period: - ... + def __getitem__(self, idx: int) -> Period: ... @overload - def __getitem__(self, idx: slice) -> Periods: - ... + def __getitem__(self, idx: slice) -> Periods: ... def __getitem__(self, idx): if _is_number(idx): From 44b2f687cf9d0b3c59632c0baa8e4079e56d18f3 Mon Sep 17 00:00:00 2001 From: Jirka Date: Tue, 20 Feb 2024 18:22:30 +0100 Subject: [PATCH 03/10] black[jupyter] --- .github/workflows/style_type_checks.yml | 2 +- Justfile | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/style_type_checks.yml b/.github/workflows/style_type_checks.yml index 2e89970f5f..34a073cda1 100644 --- a/.github/workflows/style_type_checks.yml +++ b/.github/workflows/style_type_checks.yml @@ -18,7 +18,7 @@ jobs: - name: Install dependencies run: | pip install . - pip install click "black==24.01" "mypy==1.8.0" \ + pip install click "black[jupyter]==24.01" "mypy==1.8.0" \ types-python-dateutil types-waitress types-PyYAML - name: Style and type checks run: | diff --git a/Justfile b/Justfile index 1b4243f147..a90ec87b92 100644 --- a/Justfile +++ b/Justfile @@ -34,7 +34,7 @@ release: python setup.py sdist black: - black --check src test examples + black --check --color --preview src test examples mypy: python setup.py type_check From ab6b351be99238594c412e82c947831744326284 Mon Sep 17 00:00:00 2001 From: Jirka Date: Tue, 20 Feb 2024 18:55:30 +0100 Subject: [PATCH 04/10] split --- .github/workflows/style_type_checks.yml | 15 +++++---------- pyproject.toml | 1 + 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/.github/workflows/style_type_checks.yml b/.github/workflows/style_type_checks.yml index 34a073cda1..08c1a3a73f 100644 --- a/.github/workflows/style_type_checks.yml +++ b/.github/workflows/style_type_checks.yml @@ -9,20 +9,15 @@ jobs: steps: - uses: actions/checkout@v3 - uses: extractions/setup-just@v1 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - - name: Set up Python 3.8 - uses: actions/setup-python@v4 - with: - python-version: '3.8' + - uses: actions/setup-python@v4 - name: Install dependencies run: | pip install . pip install click "black[jupyter]==24.01" "mypy==1.8.0" \ types-python-dateutil types-waitress types-PyYAML - - name: Style and type checks - run: | - just black - just mypy + - name: Style check + run: just black + - name: Type check + run: just mypy - name: Check license headers run: just license diff --git a/pyproject.toml b/pyproject.toml index 8a19a453e2..89c5d91765 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,4 +1,5 @@ [tool.black] +target-version = ['py38'] line-length = 79 [tool.pytest.ini_options] From 6e66f1ba544d0609a498eb12a6803eafc2cfe352 Mon Sep 17 00:00:00 2001 From: Jirka Date: Tue, 20 Feb 2024 19:09:03 +0100 Subject: [PATCH 05/10] run black --- examples/anomaly_detection.py | 1 + examples/benchmark_m4.py | 22 +- examples/warm_start.py | 1 + src/gluonts/core/component.py | 12 +- src/gluonts/dataset/arrow/dec.py | 12 +- src/gluonts/dataset/arrow/file.py | 10 +- src/gluonts/dataset/artificial/recipe.py | 10 +- src/gluonts/dataset/pandas.py | 20 +- src/gluonts/dataset/repository/_lstnet.py | 28 +- src/gluonts/dataset/repository/_m3.py | 30 +- .../dataset/repository/_tsf_datasets.py | 36 +- src/gluonts/dataset/schema/translate.py | 14 +- src/gluonts/ev/metrics.py | 9 +- src/gluonts/evaluation/_base.py | 44 +- src/gluonts/ext/hierarchicalforecast.py | 48 +- src/gluonts/ext/rotbaum/_model.py | 12 +- src/gluonts/ext/rotbaum/_predictor.py | 50 +- src/gluonts/ext/rotbaum/_preprocess.py | 70 +- src/gluonts/model/evaluation.py | 16 +- src/gluonts/model/forecast.py | 40 +- src/gluonts/model/forecast_generator.py | 18 +- src/gluonts/model/npts/_predictor.py | 12 +- src/gluonts/model/trivial/mean.py | 9 +- src/gluonts/mx/block/dropout.py | 12 +- src/gluonts/mx/block/regularization.py | 4 +- src/gluonts/mx/block/scaler.py | 2 +- src/gluonts/mx/block/sndense.py | 2 +- src/gluonts/mx/distribution/iresnet.py | 11 +- src/gluonts/mx/distribution/lds.py | 2 +- .../mx/model/deep_factor/_estimator.py | 28 +- src/gluonts/mx/model/deepvar/_estimator.py | 92 +- src/gluonts/mx/model/gpvar/_estimator.py | 88 +- src/gluonts/mx/model/n_beats/_ensemble.py | 10 +- src/gluonts/mx/model/renewal/_estimator.py | 30 +- src/gluonts/mx/model/renewal/_transform.py | 14 +- .../mx/model/seq2seq/_forking_estimator.py | 18 +- .../mx/model/seq2seq/_seq2seq_estimator.py | 44 +- src/gluonts/mx/model/tft/_estimator.py | 124 +- .../mx/model/tpp/deeptpp/_estimator.py | 26 +- src/gluonts/mx/model/tpp/forecast.py | 20 +- src/gluonts/mx/model/wavenet/_estimator.py | 54 +- .../mx/representation/binning_helpers.py | 12 +- .../SCott/dataset_tools/algo_clustering.py | 66 +- .../SCott/dataset_tools/electricity.py | 48 +- .../SCott/dataset_tools/exchange_rate.py | 36 +- .../SCott/dataset_tools/group_raw_data.py | 192 +- .../nursery/SCott/dataset_tools/synthetic.py | 28 +- .../nursery/SCott/dataset_tools/traffic.py | 72 +- .../nursery/SCott/model/ar/ar_estimator.py | 30 +- .../SCott/model/lstm/lstm_estimator.py | 30 +- src/gluonts/nursery/SCott/preprocess_data.py | 172 +- src/gluonts/nursery/daf/engine/parallel.py | 30 +- src/gluonts/nursery/daf/estimator/modules.py | 10 +- .../nursery/daf/tslib/dataset/timeseries.py | 130 +- src/gluonts/nursery/daf/tslib/metrics/dict.py | 10 +- .../nursery/daf/tslib/nn/attention/posemb.py | 10 +- .../daf/tslib/nn/attention/selfattn.py | 2 +- .../src/meta/datasets/artificial.py | 26 +- .../src/meta/datasets/cheat.py | 108 +- .../src/meta/datasets/datasets.py | 12 +- .../src/meta/datasets/super.py | 20 +- .../src/meta/models/module.py | 9 +- .../few_shot_prediction/src/scripts/data.py | 14 +- .../few_shot_prediction/src/scripts/train.py | 60 +- .../multivariate/datasets/dataset.py | 30 +- .../pts/dataset/repository/_m5.py | 16 +- .../robust-mts-attack/pts/feature/holiday.py | 38 +- .../causal_deepar/causal_deepar_network.py | 38 +- .../pts/model/deepar/deepar_network.py | 24 +- .../pts/model/deepvar/deepvar_estimator.py | 10 +- .../pts/model/deepvar/deepvar_network.py | 14 +- .../pts/model/n_beats/n_beats_ensemble.py | 10 +- .../pts/model/n_beats/n_beats_estimator.py | 30 +- .../pts/model/tempflow/tempflow_estimator.py | 80 +- .../pts/model/tempflow/tempflow_network.py | 14 +- .../pts/model/tft/tft_estimator.py | 210 +- .../pts/model/tft/tft_modules.py | 24 +- .../pts/model/time_grad/epsilon_theta.py | 18 +- .../model/time_grad/time_grad_estimator.py | 80 +- .../pts/model/time_grad/time_grad_network.py | 14 +- .../model/transformer/transformer_network.py | 24 +- .../transformer_tempflow_estimator.py | 90 +- .../transformer_tempflow_network.py | 14 +- .../robust-mts-attack/pts/modules/feature.py | 10 +- .../pts/modules/gaussian_diffusion.py | 2 +- .../nursery/robust-mts-attack/read_pickle.py | 30 +- .../nursery/robust-mts-attack/utils.py | 10 +- src/gluonts/nursery/san/_estimator.py | 146 +- .../model/cop_deepar/_estimator.py | 10 +- .../model/cop_deepar/_network.py | 36 +- .../model/cop_deepar/gnn.py | 2 +- .../utils/utils.py | 10 +- .../analysis/scripts/ensemble_recommender.py | 2 +- .../src/cli/analysis/scripts/recommender.py | 2 +- .../src/cli/analysis/scripts/surrogate.py | 2 +- .../nursery/tsbench/src/cli/utils/config.py | 31 +- .../src/tsbench/config/dataset/datasets.py | 88 +- .../src/tsbench/evaluations/aws/analytics.py | 12 +- .../src/tsbench/evaluations/tracking/_info.py | 40 +- .../tsbench/evaluations/tracking/ensemble.py | 10 +- .../src/tsbench/forecasts/evaluation.py | 20 +- .../tsbench/src/tsbench/recommender/greedy.py | 56 +- .../src/tsbench/surrogate/nonparametric.py | 38 +- .../tsbench/surrogate/transformers/config.py | 56 +- .../surrogate/transformers/performance.py | 11 +- src/gluonts/shell/sagemaker/dyn.py | 46 +- src/gluonts/time_feature/holiday.py | 20 +- src/gluonts/torch/model/estimator.py | 12 +- .../torch/model/i_transformer/estimator.py | 2 +- src/gluonts/torch/model/lag_tst/estimator.py | 2 +- .../torch/model/patch_tst/estimator.py | 2 +- src/gluonts/torch/model/patch_tst/module.py | 10 +- src/gluonts/torch/model/tft/layers.py | 20 +- src/gluonts/torch/model/wavenet/estimator.py | 104 +- src/gluonts/transform/feature.py | 36 +- src/gluonts/zebras/_period.py | 9 +- src/gluonts/zebras/_time_frame.py | 40 +- src/gluonts/zebras/schema.py | 12 +- test/conftest.py | 12 +- test/dataset/test_data_loader.py | 28 +- test/dataset/test_dataset_mutability.py | 22 +- test/dataset/test_multivariate_grouper.py | 26 +- test/dataset/test_pandas.py | 120 +- test/dataset/test_split.py | 10 +- test/ev/test_aggregations.py | 44 +- ...t_metrics_compared_to_previous_approach.py | 25 +- test/evaluation/test_evaluator.py | 248 +- test/ext/prophet/test_prophet.py | 40 +- .../r_forecast/test_r_multi_seasonality.py | 95 +- test/ext/rotbaum/test_rotbaum_smoke.py | 44 +- test/ext/statsforecast/test_statsforecast.py | 14 +- test/model/npts/test_npts.py | 41 +- test/mx/block/test_scaler.py | 384 +- .../distribution/test_distribution_methods.py | 20 +- .../test_distribution_output_shapes.py | 10 +- .../test_distribution_sampling.py | 26 +- test/mx/distribution/test_nan_mixture.py | 10 +- test/mx/distribution/test_piecewise_linear.py | 12 +- test/mx/kernels/test_periodic_kernel.py | 12 +- test/mx/kernels/test_rbf_kernel.py | 12 +- .../generate_hierarchical_dataset.py | 12 +- .../test_train_prediction_with_hts.py | 10 +- test/mx/model/gp_forecaster/data.py | 10916 ++++++++-------- test/mx/model/renewal/test_predictor.py | 70 +- .../seq2seq/test_forking_sequence_splitter.py | 102 +- .../mx/model/simple_feedforward/test_serde.py | 10 +- test/mx/model/tpp/common.py | 12 +- test/mx/representation/test_bin.py | 384 +- test/mx/representation/test_grb.py | 716 +- test/mx/representation/test_hyb.py | 586 +- test/mx/representation/test_lab.py | 520 +- test/mx/representation/test_mean.py | 60 +- test/mx/representation/test_rep.py | 60 +- test/mx/test_transform_equals.py | 92 +- .../test_precision_recall.py | 38 +- .../test_autogluon_tabular.py | 72 +- .../sagemaker_sdk/test_entry_point_scripts.py | 2 +- test/shell/test_nested_params.py | 12 +- test/time_feature/test_agg_lags.py | 64 +- test/time_feature/test_holiday.py | 178 +- .../test_discrete_distribution.py | 30 +- .../test_torch_piecewise_linear.py | 12 +- test/torch/model/test_mqf2_modules.py | 22 +- test/torch/model/test_tft.py | 24 +- test/torch/test_scaler.py | 162 +- test/transform/test_transform.py | 218 +- 166 files changed, 9359 insertions(+), 10229 deletions(-) diff --git a/examples/anomaly_detection.py b/examples/anomaly_detection.py index bf74d5f95f..094404486f 100644 --- a/examples/anomaly_detection.py +++ b/examples/anomaly_detection.py @@ -14,6 +14,7 @@ """ This example shows how to do anomaly detection with DeepAR. The model is first trained and then time-points with the largest negative log-likelihood are plotted. + """ import numpy as np from itertools import islice diff --git a/examples/benchmark_m4.py b/examples/benchmark_m4.py index 5368ee8411..e9b17717de 100644 --- a/examples/benchmark_m4.py +++ b/examples/benchmark_m4.py @@ -95,17 +95,15 @@ def evaluate(dataset_name, estimator): df = pd.DataFrame(results) - sub_df = df[ - [ - "dataset", - "estimator", - "RMSE", - "mean_wQuantileLoss", - "MASE", - "sMAPE", - "OWA", - "MSIS", - ] - ] + sub_df = df[[ + "dataset", + "estimator", + "RMSE", + "mean_wQuantileLoss", + "MASE", + "sMAPE", + "OWA", + "MSIS", + ]] print(sub_df.to_string()) diff --git a/examples/warm_start.py b/examples/warm_start.py index f8ca97eafd..fd425615a8 100644 --- a/examples/warm_start.py +++ b/examples/warm_start.py @@ -13,6 +13,7 @@ """ This example show how to intialize the network with parameters from a model that was previously trained. + """ from gluonts.dataset.repository import get_dataset, dataset_recipes diff --git a/src/gluonts/core/component.py b/src/gluonts/core/component.py index c5f18d011a..264f46b1eb 100644 --- a/src/gluonts/core/component.py +++ b/src/gluonts/core/component.py @@ -355,13 +355,11 @@ def init_wrapper(*args, **kwargs): # __init_args__ is not already set in order to avoid overriding a # value set by a subclass initializer in super().__init__ calls if not getattr(self, "__init_args__", {}): - self.__init_args__ = OrderedDict( - { - name: arg - for name, arg in sorted(all_args.items()) - if not skip_encoding(arg) - } - ) + self.__init_args__ = OrderedDict({ + name: arg + for name, arg in sorted(all_args.items()) + if not skip_encoding(arg) + }) self.__class__.__getnewargs_ex__ = validated_getnewargs_ex self.__class__.__repr__ = validated_repr diff --git a/src/gluonts/dataset/arrow/dec.py b/src/gluonts/dataset/arrow/dec.py index 148d5311c7..aab36898d0 100644 --- a/src/gluonts/dataset/arrow/dec.py +++ b/src/gluonts/dataset/arrow/dec.py @@ -23,13 +23,11 @@ class ArrowDecoder: @classmethod def from_schema(cls, schema): - return cls( - [ - (column.name[: -len("._np_shape")], column.name) - for column in schema - if column.name.endswith("._np_shape") - ] - ) + return cls([ + (column.name[: -len("._np_shape")], column.name) + for column in schema + if column.name.endswith("._np_shape") + ]) def decode(self, batch, row_number: int): return next(self.decode_batch(batch.slice(row_number, row_number + 1))) diff --git a/src/gluonts/dataset/arrow/file.py b/src/gluonts/dataset/arrow/file.py index 7bdb6cf898..552dad0f08 100644 --- a/src/gluonts/dataset/arrow/file.py +++ b/src/gluonts/dataset/arrow/file.py @@ -229,12 +229,10 @@ def __post_init__(self): self.decoder = ArrowDecoder.from_schema(self.reader.schema_arrow) if not self._row_group_sizes: - self._row_group_sizes = np.cumsum( - [ - self.reader.metadata.row_group(row_group).num_rows - for row_group in range(self.reader.metadata.num_row_groups) - ] - ) + self._row_group_sizes = np.cumsum([ + self.reader.metadata.row_group(row_group).num_rows + for row_group in range(self.reader.metadata.num_row_groups) + ]) def location_for(self, idx): if idx == 0: diff --git a/src/gluonts/dataset/artificial/recipe.py b/src/gluonts/dataset/artificial/recipe.py index 36cdbede42..f5ede6a253 100644 --- a/src/gluonts/dataset/artificial/recipe.py +++ b/src/gluonts/dataset/artificial/recipe.py @@ -714,12 +714,10 @@ def __call__(self, x, field_name, global_state, **kwargs): probs = [self.prob_fun(x, length=c) for c in self.cardinalities] global_state[field_name] = probs probs = global_state[field_name] - cats = np.array( - [ - np.random.choice(np.arange(len(probs[i])), p=probs[i]) - for i in range(len(probs)) - ] - ) + cats = np.array([ + np.random.choice(np.arange(len(probs[i])), p=probs[i]) + for i in range(len(probs)) + ]) return cats diff --git a/src/gluonts/dataset/pandas.py b/src/gluonts/dataset/pandas.py index dcdb1d5456..e8dacf86e2 100644 --- a/src/gluonts/dataset/pandas.py +++ b/src/gluonts/dataset/pandas.py @@ -221,17 +221,15 @@ def __len__(self) -> int: return len(self._data_entries) def __repr__(self) -> str: - info = ", ".join( - [ - f"size={len(self)}", - f"freq={self.freq}", - f"num_feat_dynamic_real={self.num_feat_dynamic_real}", - f"num_past_feat_dynamic_real={self.num_past_feat_dynamic_real}", - f"num_feat_static_real={self.num_feat_static_real}", - f"num_feat_static_cat={self.num_feat_static_cat}", - f"static_cardinalities={self.static_cardinalities}", - ] - ) + info = ", ".join([ + f"size={len(self)}", + f"freq={self.freq}", + f"num_feat_dynamic_real={self.num_feat_dynamic_real}", + f"num_past_feat_dynamic_real={self.num_past_feat_dynamic_real}", + f"num_feat_static_real={self.num_feat_static_real}", + f"num_feat_static_cat={self.num_feat_static_cat}", + f"static_cardinalities={self.static_cardinalities}", + ]) return f"PandasDataset<{info}>" @classmethod diff --git a/src/gluonts/dataset/repository/_lstnet.py b/src/gluonts/dataset/repository/_lstnet.py index e933666c77..2841bbb92d 100644 --- a/src/gluonts/dataset/repository/_lstnet.py +++ b/src/gluonts/dataset/repository/_lstnet.py @@ -161,14 +161,12 @@ def generate_lstnet_dataset( for cat, ts in enumerate(timeseries): sliced_ts = ts[:training_end] if len(sliced_ts) > 0: - train_ts.append( - { - "target": sliced_ts.values, - "start": sliced_ts.index[0], - "feat_static_cat": [cat], - "item_id": cat, - } - ) + train_ts.append({ + "target": sliced_ts.values, + "start": sliced_ts.index[0], + "feat_static_cat": [cat], + "item_id": cat, + }) assert len(train_ts) == ds_info.num_series @@ -186,14 +184,12 @@ def generate_lstnet_dataset( prediction_start_date + ds_info.prediction_length ) sliced_ts = ts[:prediction_end_date] - test_ts.append( - { - "target": sliced_ts.values, - "start": sliced_ts.index[0], - "feat_static_cat": [cat], - "item_id": cat, - } - ) + test_ts.append({ + "target": sliced_ts.values, + "start": sliced_ts.index[0], + "feat_static_cat": [cat], + "item_id": cat, + }) assert len(test_ts) == ds_info.num_series * ds_info.rolling_evaluations diff --git a/src/gluonts/dataset/repository/_m3.py b/src/gluonts/dataset/repository/_m3.py index 84206d0f15..5c0b3ba86c 100644 --- a/src/gluonts/dataset/repository/_m3.py +++ b/src/gluonts/dataset/repository/_m3.py @@ -163,23 +163,19 @@ def normalize_category(c: str): start = str(pd.Period(time_stamp, freq=subset.freq)) cat = [i, cat_map[category]] - train_data.append( - { - "target": target[: -subset.prediction_length], - "start": start, - "feat_static_cat": cat, - "item_id": series, - } - ) - - test_data.append( - { - "target": target, - "start": start, - "feat_static_cat": cat, - "item_id": series, - } - ) + train_data.append({ + "target": target[: -subset.prediction_length], + "start": start, + "feat_static_cat": cat, + "item_id": series, + }) + + test_data.append({ + "target": target, + "start": start, + "feat_static_cat": cat, + "item_id": series, + }) meta = MetaData( **metadata( diff --git a/src/gluonts/dataset/repository/_tsf_datasets.py b/src/gluonts/dataset/repository/_tsf_datasets.py index ba073cdf4c..d9c27c9ecc 100644 --- a/src/gluonts/dataset/repository/_tsf_datasets.py +++ b/src/gluonts/dataset/repository/_tsf_datasets.py @@ -201,27 +201,23 @@ def convert_data( # timestamps # - `item_id` is added for all datasets ... many datasets provide # the "series_name" - test_data.append( - { - "target": data_entry["target"], - "start": str( - data_entry.get("start_timestamp", default_start_timestamp) - ), - "item_id": data_entry.get("series_name", i), - "feat_static_cat": [i], - } - ) + test_data.append({ + "target": data_entry["target"], + "start": str( + data_entry.get("start_timestamp", default_start_timestamp) + ), + "item_id": data_entry.get("series_name", i), + "feat_static_cat": [i], + }) - train_data.append( - { - "target": data_entry["target"][:-train_offset], - "start": str( - data_entry.get("start_timestamp", default_start_timestamp) - ), - "item_id": data_entry.get("series_name", i), - "feat_static_cat": [i], - } - ) + train_data.append({ + "target": data_entry["target"][:-train_offset], + "start": str( + data_entry.get("start_timestamp", default_start_timestamp) + ), + "item_id": data_entry.get("series_name", i), + "feat_static_cat": [i], + }) return train_data, test_data diff --git a/src/gluonts/dataset/schema/translate.py b/src/gluonts/dataset/schema/translate.py index 5ea7c41955..09f0e4e602 100644 --- a/src/gluonts/dataset/schema/translate.py +++ b/src/gluonts/dataset/schema/translate.py @@ -141,14 +141,12 @@ class TokenStream: @classmethod def from_str(cls, s): - stream = cls( - [ - Token(name, value, match) - for match in re.finditer(cls.RX, s) - for name, value in valfilter(bool, match.groupdict()).items() - if name != "WHITESPACE" - ] - ) + stream = cls([ + Token(name, value, match) + for match in re.finditer(cls.RX, s) + for name, value in valfilter(bool, match.groupdict()).items() + if name != "WHITESPACE" + ]) for token in stream: if token.name == "INVALID": diff --git a/src/gluonts/ev/metrics.py b/src/gluonts/ev/metrics.py index 4d3e55b335..00f74bb2f6 100644 --- a/src/gluonts/ev/metrics.py +++ b/src/gluonts/ev/metrics.py @@ -124,12 +124,9 @@ def update(self, data: Mapping[str, np.ndarray]) -> Self: return self def get(self) -> np.ndarray: - return self.post_process( - **{ - name: evaluator.get() - for name, evaluator in self.metrics.items() - } - ) + return self.post_process(**{ + name: evaluator.get() for name, evaluator in self.metrics.items() + }) @runtime_checkable diff --git a/src/gluonts/evaluation/_base.py b/src/gluonts/evaluation/_base.py index b623bf3d75..8280c5a062 100644 --- a/src/gluonts/evaluation/_base.py +++ b/src/gluonts/evaluation/_base.py @@ -293,13 +293,11 @@ def __call__( # Thus we set dtype=np.float64 to convert masked values back to NaNs # which are handled correctly by pandas Dataframes during # aggregation. - metrics_per_ts = metrics_per_ts.astype( - { - col: np.float64 - for col in metrics_per_ts.columns - if col not in ["item_id", "forecast_start"] - } - ) + metrics_per_ts = metrics_per_ts.astype({ + col: np.float64 + for col in metrics_per_ts.columns + if col not in ["item_id", "forecast_start"] + }) return self.get_aggregate_metrics(metrics_per_ts) @@ -536,26 +534,18 @@ def get_aggregate_metrics( totals[f"QuantileLoss[{quantile}]"] / totals["abs_target_sum"] ) - totals["mean_absolute_QuantileLoss"] = np.array( - [ - totals[f"QuantileLoss[{quantile}]"] - for quantile in self.quantiles - ] - ).mean() - - totals["mean_wQuantileLoss"] = np.array( - [ - totals[f"wQuantileLoss[{quantile}]"] - for quantile in self.quantiles - ] - ).mean() - - totals["MAE_Coverage"] = np.mean( - [ - np.abs(totals[f"Coverage[{quantile}]"] - np.array([q.value])) - for q in self.quantiles - ] - ) + totals["mean_absolute_QuantileLoss"] = np.array([ + totals[f"QuantileLoss[{quantile}]"] for quantile in self.quantiles + ]).mean() + + totals["mean_wQuantileLoss"] = np.array([ + totals[f"wQuantileLoss[{quantile}]"] for quantile in self.quantiles + ]).mean() + + totals["MAE_Coverage"] = np.mean([ + np.abs(totals[f"Coverage[{quantile}]"] - np.array([q.value])) + for q in self.quantiles + ]) # Compute OWA if required if self.calculate_owa: diff --git a/src/gluonts/ext/hierarchicalforecast.py b/src/gluonts/ext/hierarchicalforecast.py index 44ed9d7b54..0bb4ab4ce5 100644 --- a/src/gluonts/ext/hierarchicalforecast.py +++ b/src/gluonts/ext/hierarchicalforecast.py @@ -96,13 +96,11 @@ def unpivot(df: pd.DataFrame) -> pd.DataFrame: """ n, k = df.shape - return pd.DataFrame( - { - "unique_id": np.asarray(df.columns).repeat(n), - "ds": np.tile(np.asarray(df.index), k), - "y": df.to_numpy().ravel("F"), - } - ) + return pd.DataFrame({ + "unique_id": np.asarray(df.columns).repeat(n), + "ds": np.tile(np.asarray(df.index), k), + "y": df.to_numpy().ravel("F"), + }) def format_reconciled_forecasts( @@ -243,16 +241,14 @@ def __init__( def predict_item(self, entry: DataEntry) -> QuantileForecast: kwargs = {} - if self.config.intervals is not None and all( - [ - proportion not in _build_fn_name(self.hrec.reconcilers[0]) - for proportion in [ - "forecast_proportions", - "average_proportions", - "proportion_averages", - ] + if self.config.intervals is not None and all([ + proportion not in _build_fn_name(self.hrec.reconcilers[0]) + for proportion in [ + "forecast_proportions", + "average_proportions", + "proportion_averages", ] - ): + ]): kwargs["level"] = self.config.intervals Y_df = format_data_entry(entry, self.S) @@ -300,17 +296,15 @@ def predict_item(self, entry: DataEntry) -> QuantileForecast: fcst_col_names = self.config.statsforecast_keys # prepare for QuantileForecast format - forecast_arrays = np.array( - [ - format_reconciled_forecasts( - df=Y_hat_df_rec, - fcst_col_name=fcst_col_names[e], - prediction_length=self.prediction_length, - S=self.S, - ) - for e, k in enumerate(self.config.statsforecast_keys) - ] - ) + forecast_arrays = np.array([ + format_reconciled_forecasts( + df=Y_hat_df_rec, + fcst_col_name=fcst_col_names[e], + prediction_length=self.prediction_length, + S=self.S, + ) + for e, k in enumerate(self.config.statsforecast_keys) + ]) return QuantileForecast( forecast_arrays=forecast_arrays, diff --git a/src/gluonts/ext/rotbaum/_model.py b/src/gluonts/ext/rotbaum/_model.py index 3c0270a49a..fcca9250c3 100644 --- a/src/gluonts/ext/rotbaum/_model.py +++ b/src/gluonts/ext/rotbaum/_model.py @@ -142,7 +142,7 @@ def fit( # XGBoost, but True if one uses lightgbm. model_is_already_trained: bool = False, # True if there is no need to # train self.model - **kwargs + **kwargs, ): """ Fits self.model and partitions R^n into cells. @@ -180,12 +180,10 @@ def fit( if not model_is_already_trained: self.model.fit(x_train, y_train, **kwargs) y_train_pred = self.model.predict(x_train) - df = pd.DataFrame( - { - "y_true": y_train, - "y_pred": y_train_pred, - } - ).reset_index(drop=True) + df = pd.DataFrame({ + "y_true": y_train, + "y_pred": y_train_pred, + }).reset_index(drop=True) self.sorted_train_preds = sorted(df["y_pred"].unique()) cell_values_dict = self.preprocess_df( df, min_bin_size=self.min_bin_size diff --git a/src/gluonts/ext/rotbaum/_predictor.py b/src/gluonts/ext/rotbaum/_predictor.py index 6631e8dde0..ab324fa69c 100644 --- a/src/gluonts/ext/rotbaum/_predictor.py +++ b/src/gluonts/ext/rotbaum/_predictor.py @@ -411,19 +411,15 @@ def explain( assert self.model_list is not None - importances = np.array( + importances = np.array([ [ - [ - self.model_list[time_stamp] - .models[quantile] - .booster_.feature_importance( - importance_type=importance_type - ) - for time_stamp in range(self.prediction_length) - ] - for quantile in self.quantiles + self.model_list[time_stamp] + .models[quantile] + .booster_.feature_importance(importance_type=importance_type) + for time_stamp in range(self.prediction_length) ] - ).transpose((2, 1, 0)) + for quantile in self.quantiles + ]).transpose((2, 1, 0)) # The shape is: (features, pred_length, quantiles) importances = importances.mean(axis=2) # Average over quantiles # The shape of importances is: (features, pred_length) @@ -463,17 +459,15 @@ def explain( ) for i in range(num_feat_static_cat): - coordinate_map["feat_static_cat"].append( - ( - dynamic_length - + num_feat_static_real - + static_cat_features_so_far, - dynamic_length - + num_feat_static_real - + static_cat_features_so_far - + cardinality[i], - ) - ) + coordinate_map["feat_static_cat"].append(( + dynamic_length + + num_feat_static_real + + static_cat_features_so_far, + dynamic_length + + num_feat_static_real + + static_cat_features_so_far + + cardinality[i], + )) static_cat_features_so_far += cardinality[i] coordinate_map["past_feat_dynamic_real"] = [ @@ -523,13 +517,11 @@ def explain( ) logger.info(f"shape of importance matrix is: {importances.shape}") assert ( - sum( - [ - sum([coor[1] - coor[0] for coor in coordinate_map[key]]) - for key in coordinate_map - if key != "target" - ] - ) + sum([ + sum([coor[1] - coor[0] for coor in coordinate_map[key]]) + for key in coordinate_map + if key != "target" + ]) + coordinate_map["target"][1] - coordinate_map["target"][0] ) == importances.shape[ diff --git a/src/gluonts/ext/rotbaum/_preprocess.py b/src/gluonts/ext/rotbaum/_preprocess.py index 6baee5be68..6ab57d569d 100644 --- a/src/gluonts/ext/rotbaum/_preprocess.py +++ b/src/gluonts/ext/rotbaum/_preprocess.py @@ -49,7 +49,7 @@ def __init__( max_n_datapts: int = 400000, seed: Optional[int] = None, num_samples: Optional[int] = None, - **kwargs + **kwargs, ): """ Parameters @@ -168,15 +168,13 @@ def preprocess_from_single_ts(self, time_series: Dict) -> Tuple: feature_data.append( list(featurized_data) + [forecast_horizon_index] ) - target_data.append( - [ - time_series["target"][ - starting_index - + self.context_window_size - + forecast_horizon_index - ] + target_data.append([ + time_series["target"][ + starting_index + + self.context_window_size + + forecast_horizon_index ] - ) + ]) else: featurized_data = self.make_features( altered_time_series, starting_index @@ -296,7 +294,7 @@ def __init__( one_hot_encode: bool = True, subtract_mean: bool = True, count_nans: bool = False, - **kwargs + **kwargs, ): if one_hot_encode: assert cardinality != "ignore" or ( @@ -313,7 +311,7 @@ def __init__( stratify_targets=stratify_targets, n_ignore_last=n_ignore_last, num_samples=num_samples, - **kwargs + **kwargs, ) self.use_feat_static_real = use_feat_static_real @@ -481,41 +479,37 @@ def make_features(self, time_series: Dict, starting_index: int) -> List: past_feat_dynamic_real = ( list( - chain( - *[ - prefix + list(ent[0]) + list(ent[1].values()) - for ent in [ - self._pre_transform( - ts if prefix else ts[starting_index:end_index], - self.subtract_mean, - self.count_nans, - ) - for ts in time_series["past_feat_dynamic_real"] - ] + chain(*[ + prefix + list(ent[0]) + list(ent[1].values()) + for ent in [ + self._pre_transform( + ts if prefix else ts[starting_index:end_index], + self.subtract_mean, + self.count_nans, + ) + for ts in time_series["past_feat_dynamic_real"] ] - ) + ]) ) if self.use_past_feat_dynamic_real else [] ) feat_dynamic_real = ( list( - chain( - *[ - list(ent[0]) + list(ent[1].values()) - for ent in [ - self._pre_transform( - ts[ - starting_index : end_index - + self.forecast_horizon - ], - self.subtract_mean, - self.count_nans, - ) - for ts in time_series["feat_dynamic_real"] - ] + chain(*[ + list(ent[0]) + list(ent[1].values()) + for ent in [ + self._pre_transform( + ts[ + starting_index : end_index + + self.forecast_horizon + ], + self.subtract_mean, + self.count_nans, + ) + for ts in time_series["feat_dynamic_real"] ] - ) + ]) ) if self.use_feat_dynamic_real else [] diff --git a/src/gluonts/model/evaluation.py b/src/gluonts/model/evaluation.py index 15638dea84..597cee9703 100644 --- a/src/gluonts/model/evaluation.py +++ b/src/gluonts/model/evaluation.py @@ -104,7 +104,7 @@ def evaluate_forecasts_raw( batch_size: int = 100, mask_invalid_label: bool = True, allow_nan_forecast: bool = False, - seasonality: Optional[int] = None + seasonality: Optional[int] = None, ) -> dict: """ Evaluate ``forecasts`` by comparing them with ``test_data``, according @@ -147,12 +147,10 @@ def evaluate_forecasts_raw( input_batches, label_batches, forecast_batches ): if 0 not in axis: - index_data.extend( - [ - (forecast.item_id, forecast.start_date) - for forecast in forecast_batch - ] - ) + index_data.extend([ + (forecast.item_id, forecast.start_date) + for forecast in forecast_batch + ]) data_batch = _get_data_batch( input_batch, @@ -189,7 +187,7 @@ def evaluate_forecasts( batch_size: int = 100, mask_invalid_label: bool = True, allow_nan_forecast: bool = False, - seasonality: Optional[int] = None + seasonality: Optional[int] = None, ) -> pd.DataFrame: """ Evaluate ``forecasts`` by comparing them with ``test_data``, according @@ -243,7 +241,7 @@ def evaluate_model( batch_size: int = 100, mask_invalid_label: bool = True, allow_nan_forecast: bool = False, - seasonality: Optional[int] = None + seasonality: Optional[int] = None, ) -> pd.DataFrame: """ Evaluate ``model`` when applied to ``test_data``, according diff --git a/src/gluonts/model/forecast.py b/src/gluonts/model/forecast.py index c69ae385fc..73c0a6093f 100644 --- a/src/gluonts/model/forecast.py +++ b/src/gluonts/model/forecast.py @@ -532,23 +532,19 @@ def dim(self) -> int: return self._dim def __repr__(self): - return ", ".join( - [ - f"SampleForecast({self.samples!r})", - f"{self.start_date!r}", - f"item_id={self.item_id!r}", - f"info={self.info!r})", - ] - ) + return ", ".join([ + f"SampleForecast({self.samples!r})", + f"{self.start_date!r}", + f"item_id={self.item_id!r}", + f"info={self.info!r})", + ]) def to_quantile_forecast(self, quantiles: List[str]) -> "QuantileForecast": return QuantileForecast( - forecast_arrays=np.array( - [ - self.quantile(q) if q != "mean" else self.mean - for q in quantiles - ] - ), + forecast_arrays=np.array([ + self.quantile(q) if q != "mean" else self.mean + for q in quantiles + ]), start_date=self.start_date, forecast_keys=quantiles, item_id=self.item_id, @@ -690,12 +686,10 @@ def dim(self) -> int: return self._dim def __repr__(self): - return ", ".join( - [ - f"QuantileForecast({self.forecast_array!r})", - f"start_date={self.start_date!r}", - f"forecast_keys={self.forecast_keys!r}", - f"item_id={self.item_id!r}", - f"info={self.info!r})", - ] - ) + return ", ".join([ + f"QuantileForecast({self.forecast_array!r})", + f"start_date={self.start_date!r}", + f"forecast_keys={self.forecast_keys!r}", + f"item_id={self.item_id!r}", + f"info={self.info!r})", + ]) diff --git a/src/gluonts/model/forecast_generator.py b/src/gluonts/model/forecast_generator.py index d66a1d361d..6408a72e6c 100644 --- a/src/gluonts/model/forecast_generator.py +++ b/src/gluonts/model/forecast_generator.py @@ -94,7 +94,7 @@ def __call__( input_names: List[str], output_transform: Optional[OutputTransform], num_samples: Optional[int], - **kwargs + **kwargs, ) -> Iterator[Forecast]: raise NotImplementedError() @@ -111,7 +111,7 @@ def __call__( input_names: List[str], output_transform: Optional[OutputTransform], num_samples: Optional[int], - **kwargs + **kwargs, ) -> Iterator[Forecast]: for batch in inference_data_loader: inputs = select(input_names, batch, ignore_missing=True) @@ -155,7 +155,7 @@ def __call__( input_names: List[str], output_transform: Optional[OutputTransform], num_samples: Optional[int], - **kwargs + **kwargs, ) -> Iterator[Forecast]: for batch in inference_data_loader: inputs = select(input_names, batch, ignore_missing=True) @@ -171,12 +171,10 @@ def __call__( outputs = output_transform(batch, outputs) collected_samples.append(outputs) num_collected_samples += outputs[0].shape[0] - outputs = np.stack( - [ - np.concatenate(s)[:num_samples] - for s in zip(*collected_samples) - ] - ) + outputs = np.stack([ + np.concatenate(s)[:num_samples] + for s in zip(*collected_samples) + ]) assert len(outputs[0]) == num_samples i = -1 for i, output in enumerate(outputs): @@ -205,7 +203,7 @@ def __call__( input_names: List[str], output_transform: Optional[OutputTransform], num_samples: Optional[int], - **kwargs + **kwargs, ) -> Iterator[Forecast]: for batch in inference_data_loader: inputs = select(input_names, batch, ignore_missing=True) diff --git a/src/gluonts/model/npts/_predictor.py b/src/gluonts/model/npts/_predictor.py index baf77fdd10..7bb088a796 100644 --- a/src/gluonts/model/npts/_predictor.py +++ b/src/gluonts/model/npts/_predictor.py @@ -205,14 +205,10 @@ def predict( custom_features: Optional[np.ndarray] if "feat_dynamic_real" in data.keys(): - custom_features = np.array( - [ - dynamic_feature[ - -train_length - self.prediction_length : - ] - for dynamic_feature in data["feat_dynamic_real"] - ] - ) + custom_features = np.array([ + dynamic_feature[-train_length - self.prediction_length :] + for dynamic_feature in data["feat_dynamic_real"] + ]) else: custom_features = None diff --git a/src/gluonts/model/trivial/mean.py b/src/gluonts/model/trivial/mean.py index ff9ce17b2c..027c66e5c5 100644 --- a/src/gluonts/model/trivial/mean.py +++ b/src/gluonts/model/trivial/mean.py @@ -157,12 +157,9 @@ def train( training_data: Dataset, validation_dataset: Optional[Dataset] = None, ) -> ConstantPredictor: - contexts = np.array( - [ - item["target"][-self.prediction_length :] - for item in training_data - ] - ) + contexts = np.array([ + item["target"][-self.prediction_length :] for item in training_data + ]) samples = np.broadcast_to( array=contexts.mean(axis=0), diff --git a/src/gluonts/mx/block/dropout.py b/src/gluonts/mx/block/dropout.py index 2522820a2c..de192777f1 100644 --- a/src/gluonts/mx/block/dropout.py +++ b/src/gluonts/mx/block/dropout.py @@ -230,13 +230,11 @@ def mask(p, like): # only for RNN, the first element of states is output. Use the same # mask as output, instead of simply copy output to the first element # in case that the base cell is ResidualCell - new_states = [ - ( - F.where(output_mask, next_states[0], states[0]) - if p_outputs != 0.0 - else next_states[0] - ) - ] + new_states = [( + F.where(output_mask, next_states[0], states[0]) + if p_outputs != 0.0 + else next_states[0] + )] new_states.extend( [ F.where(mask(p_states, new_s), new_s, old_s) diff --git a/src/gluonts/mx/block/regularization.py b/src/gluonts/mx/block/regularization.py index aeb8719ed3..2716d81b57 100644 --- a/src/gluonts/mx/block/regularization.py +++ b/src/gluonts/mx/block/regularization.py @@ -49,7 +49,7 @@ def __init__( weight: Optional[float] = None, batch_axis: int = 1, time_axis: int = 0, - **kwargs + **kwargs, ): super().__init__(weight, batch_axis, **kwargs) self._alpha = alpha @@ -121,7 +121,7 @@ def __init__( weight: Optional[float] = None, batch_axis: int = 1, time_axis: int = 0, - **kwargs + **kwargs, ): super().__init__(weight, batch_axis, **kwargs) self._beta = beta diff --git a/src/gluonts/mx/block/scaler.py b/src/gluonts/mx/block/scaler.py index 11eb0ba9b8..66b49b45f7 100644 --- a/src/gluonts/mx/block/scaler.py +++ b/src/gluonts/mx/block/scaler.py @@ -121,7 +121,7 @@ def __init__( minimum_scale: float = 1e-10, default_scale: Optional[float] = None, *args, - **kwargs + **kwargs, ): super().__init__(*args, **kwargs) self.minimum_scale = minimum_scale diff --git a/src/gluonts/mx/block/sndense.py b/src/gluonts/mx/block/sndense.py index 906c7fd713..0e46fa4c14 100644 --- a/src/gluonts/mx/block/sndense.py +++ b/src/gluonts/mx/block/sndense.py @@ -51,7 +51,7 @@ def __init__( dtype="float32", num_power_iter: int = 1, ctx: Optional[mx.Context] = None, - **kwargs + **kwargs, ): super().__init__(**kwargs) self._coeff = coeff diff --git a/src/gluonts/mx/distribution/iresnet.py b/src/gluonts/mx/distribution/iresnet.py index bd92f6dcba..78f56fca9b 100644 --- a/src/gluonts/mx/distribution/iresnet.py +++ b/src/gluonts/mx/distribution/iresnet.py @@ -65,7 +65,7 @@ def __init__( coeff: float = 0.9, use_caching: bool = True, *args, - **kwargs + **kwargs, ): super().__init__(*args, **kwargs) assert len(event_shape) == 1 @@ -199,9 +199,6 @@ def iresnet(num_blocks: int, **block_kwargs) -> ComposedBijectionHybridBlock: ------- """ - return ComposedBijectionHybridBlock( - [ - InvertibleResnetHybridBlock(**block_kwargs) - for _ in range(num_blocks) - ] - ) + return ComposedBijectionHybridBlock([ + InvertibleResnetHybridBlock(**block_kwargs) for _ in range(num_blocks) + ]) diff --git a/src/gluonts/mx/distribution/lds.py b/src/gluonts/mx/distribution/lds.py index 31d3ae225a..d7d7dfa295 100644 --- a/src/gluonts/mx/distribution/lds.py +++ b/src/gluonts/mx/distribution/lds.py @@ -65,7 +65,7 @@ def _safe_split(x, num_outputs, axis, squeeze_axis, *args, **kwargs): num_outputs=num_outputs, squeeze_axis=squeeze_axis, *args, - **kwargs + **kwargs, ) return [x.squeeze(axis=axis)] if squeeze_axis else [x] diff --git a/src/gluonts/mx/model/deep_factor/_estimator.py b/src/gluonts/mx/model/deep_factor/_estimator.py index e867f17e8e..222726c9cd 100644 --- a/src/gluonts/mx/model/deep_factor/_estimator.py +++ b/src/gluonts/mx/model/deep_factor/_estimator.py @@ -168,22 +168,18 @@ def __init__( ) def create_transformation(self) -> Transformation: - return Chain( - [ - AsNumpyArray(field=FieldName.TARGET, expected_ndim=1), - AddTimeFeatures( - start_field=FieldName.START, - target_field=FieldName.TARGET, - output_field=FieldName.FEAT_TIME, - time_features=time_features_from_frequency_str(self.freq), - pred_length=self.prediction_length, - ), - SetFieldIfNotPresent( - field=FieldName.FEAT_STATIC_CAT, value=[0.0] - ), - AsNumpyArray(field=FieldName.FEAT_STATIC_CAT, expected_ndim=1), - ] - ) + return Chain([ + AsNumpyArray(field=FieldName.TARGET, expected_ndim=1), + AddTimeFeatures( + start_field=FieldName.START, + target_field=FieldName.TARGET, + output_field=FieldName.FEAT_TIME, + time_features=time_features_from_frequency_str(self.freq), + pred_length=self.prediction_length, + ), + SetFieldIfNotPresent(field=FieldName.FEAT_STATIC_CAT, value=[0.0]), + AsNumpyArray(field=FieldName.FEAT_STATIC_CAT, expected_ndim=1), + ]) def _create_instance_splitter(self, mode: str): return transform.InstanceSplitter( diff --git a/src/gluonts/mx/model/deepvar/_estimator.py b/src/gluonts/mx/model/deepvar/_estimator.py index 9a95ab0c9a..5e68ec1dfa 100644 --- a/src/gluonts/mx/model/deepvar/_estimator.py +++ b/src/gluonts/mx/model/deepvar/_estimator.py @@ -331,48 +331,44 @@ def __init__( ) def create_transformation(self) -> Transformation: - return Chain( - [ - AsNumpyArray( - field=FieldName.TARGET, - expected_ndim=1 + len(self.distr_output.event_shape), - ), - # maps the target to (1, T) - # if the target data is uni dimensional - ExpandDimArray( - field=FieldName.TARGET, - axis=0 if self.distr_output.event_shape[0] == 1 else None, - ), - AddObservedValuesIndicator( - target_field=FieldName.TARGET, - output_field=FieldName.OBSERVED_VALUES, - ), - AddTimeFeatures( - start_field=FieldName.START, - target_field=FieldName.TARGET, - output_field=FieldName.FEAT_TIME, - time_features=self.time_features, - pred_length=self.prediction_length, - ), - VstackFeatures( - output_field=FieldName.FEAT_TIME, - input_fields=[FieldName.FEAT_TIME] - + ( - [FieldName.FEAT_DYNAMIC_REAL] - if self.use_feat_dynamic_real - else [] - ), - ), - SetFieldIfNotPresent( - field=FieldName.FEAT_STATIC_CAT, value=[0.0] - ), - TargetDimIndicator( - field_name="target_dimension_indicator", - target_field=FieldName.TARGET, + return Chain([ + AsNumpyArray( + field=FieldName.TARGET, + expected_ndim=1 + len(self.distr_output.event_shape), + ), + # maps the target to (1, T) + # if the target data is uni dimensional + ExpandDimArray( + field=FieldName.TARGET, + axis=0 if self.distr_output.event_shape[0] == 1 else None, + ), + AddObservedValuesIndicator( + target_field=FieldName.TARGET, + output_field=FieldName.OBSERVED_VALUES, + ), + AddTimeFeatures( + start_field=FieldName.START, + target_field=FieldName.TARGET, + output_field=FieldName.FEAT_TIME, + time_features=self.time_features, + pred_length=self.prediction_length, + ), + VstackFeatures( + output_field=FieldName.FEAT_TIME, + input_fields=[FieldName.FEAT_TIME] + + ( + [FieldName.FEAT_DYNAMIC_REAL] + if self.use_feat_dynamic_real + else [] ), - AsNumpyArray(field=FieldName.FEAT_STATIC_CAT, expected_ndim=1), - ] - ) + ), + SetFieldIfNotPresent(field=FieldName.FEAT_STATIC_CAT, value=[0.0]), + TargetDimIndicator( + field_name="target_dimension_indicator", + target_field=FieldName.TARGET, + ), + AsNumpyArray(field=FieldName.FEAT_STATIC_CAT, expected_ndim=1), + ]) def _create_instance_splitter(self, mode: str): assert mode in ["training", "validation", "test"] @@ -403,14 +399,12 @@ def _create_instance_splitter(self, mode: str): target_dim=self.target_dim, ) if self.use_marginal_transformation - else RenameFields( - { - f"past_{FieldName.TARGET}": f"past_{FieldName.TARGET}_cdf", - f"future_{FieldName.TARGET}": ( - f"future_{FieldName.TARGET}_cdf" - ), - } - ) + else RenameFields({ + f"past_{FieldName.TARGET}": f"past_{FieldName.TARGET}_cdf", + f"future_{FieldName.TARGET}": ( + f"future_{FieldName.TARGET}_cdf" + ), + }) ) def create_training_data_loader( diff --git a/src/gluonts/mx/model/gpvar/_estimator.py b/src/gluonts/mx/model/gpvar/_estimator.py index 29aea6ac31..c71f0d1a92 100644 --- a/src/gluonts/mx/model/gpvar/_estimator.py +++ b/src/gluonts/mx/model/gpvar/_estimator.py @@ -244,43 +244,39 @@ def __init__( ) def create_transformation(self) -> Transformation: - return Chain( - [ - AsNumpyArray( - field=FieldName.TARGET, - expected_ndim=1 + len(self.distr_output.event_shape), - ), - # maps the target to (1, T) if the target data is uni - # dimensional - ExpandDimArray( - field=FieldName.TARGET, - axis=0 if self.distr_output.event_shape[0] == 1 else None, - ), - AddObservedValuesIndicator( - target_field=FieldName.TARGET, - output_field=FieldName.OBSERVED_VALUES, - ), - AddTimeFeatures( - start_field=FieldName.START, - target_field=FieldName.TARGET, - output_field=FieldName.FEAT_TIME, - time_features=self.time_features, - pred_length=self.prediction_length, - ), - VstackFeatures( - output_field=FieldName.FEAT_TIME, - input_fields=[FieldName.FEAT_TIME], - ), - SetFieldIfNotPresent( - field=FieldName.FEAT_STATIC_CAT, value=[0.0] - ), - TargetDimIndicator( - field_name=FieldName.TARGET_DIM_INDICATOR, - target_field=FieldName.TARGET, - ), - AsNumpyArray(field=FieldName.FEAT_STATIC_CAT, expected_ndim=1), - ] - ) + return Chain([ + AsNumpyArray( + field=FieldName.TARGET, + expected_ndim=1 + len(self.distr_output.event_shape), + ), + # maps the target to (1, T) if the target data is uni + # dimensional + ExpandDimArray( + field=FieldName.TARGET, + axis=0 if self.distr_output.event_shape[0] == 1 else None, + ), + AddObservedValuesIndicator( + target_field=FieldName.TARGET, + output_field=FieldName.OBSERVED_VALUES, + ), + AddTimeFeatures( + start_field=FieldName.START, + target_field=FieldName.TARGET, + output_field=FieldName.FEAT_TIME, + time_features=self.time_features, + pred_length=self.prediction_length, + ), + VstackFeatures( + output_field=FieldName.FEAT_TIME, + input_fields=[FieldName.FEAT_TIME], + ), + SetFieldIfNotPresent(field=FieldName.FEAT_STATIC_CAT, value=[0.0]), + TargetDimIndicator( + field_name=FieldName.TARGET_DIM_INDICATOR, + target_field=FieldName.TARGET, + ), + AsNumpyArray(field=FieldName.FEAT_STATIC_CAT, expected_ndim=1), + ]) def _create_instance_splitter(self, mode: str): assert mode in ["training", "validation", "test"] @@ -313,16 +309,14 @@ def _create_instance_splitter(self, mode: str): target_dim=self.target_dim, ) if self.use_marginal_transformation - else RenameFields( - { - f"past_{FieldName.TARGET}": ( - f"past_{FieldName.TARGET}_cdf" - ), - f"future_{FieldName.TARGET}": ( - f"future_{FieldName.TARGET}_cdf" - ), - } - ) + else RenameFields({ + f"past_{FieldName.TARGET}": ( + f"past_{FieldName.TARGET}_cdf" + ), + f"future_{FieldName.TARGET}": ( + f"future_{FieldName.TARGET}_cdf" + ), + }) ) + SampleTargetDim( field_name=FieldName.TARGET_DIM_INDICATOR, diff --git a/src/gluonts/mx/model/n_beats/_ensemble.py b/src/gluonts/mx/model/n_beats/_ensemble.py index 4351163251..838a0bc15c 100644 --- a/src/gluonts/mx/model/n_beats/_ensemble.py +++ b/src/gluonts/mx/model/n_beats/_ensemble.py @@ -353,12 +353,10 @@ def __init__( self.freq = freq self.prediction_length = prediction_length - assert meta_loss_function is None or all( - [ - loss_function in VALID_LOSS_FUNCTIONS - for loss_function in meta_loss_function - ] - ), ( + assert meta_loss_function is None or all([ + loss_function in VALID_LOSS_FUNCTIONS + for loss_function in meta_loss_function + ]), ( "Each loss function has to be one of the following:" f" {VALID_LOSS_FUNCTIONS}." ) diff --git a/src/gluonts/mx/model/renewal/_estimator.py b/src/gluonts/mx/model/renewal/_estimator.py index c944850f91..1443d73ffb 100644 --- a/src/gluonts/mx/model/renewal/_estimator.py +++ b/src/gluonts/mx/model/renewal/_estimator.py @@ -181,22 +181,20 @@ def _create_instance_splitter(self, mode: str): @staticmethod def _create_post_split_transform(): - return Chain( - [ - CountTrailingZeros( - new_field="time_remaining", - target_field="past_target", - as_array=True, - ), - ToIntervalSizeFormat( - target_field="past_target", discard_first=True - ), - RenameFields({"future_target": "sparse_future"}), - AsNumpyArray(field="past_target", expected_ndim=2), - SwapAxes(input_fields=["past_target"], axes=(0, 1)), - AddAxisLength(target_field="past_target", axis=0), - ] - ) + return Chain([ + CountTrailingZeros( + new_field="time_remaining", + target_field="past_target", + as_array=True, + ), + ToIntervalSizeFormat( + target_field="past_target", discard_first=True + ), + RenameFields({"future_target": "sparse_future"}), + AsNumpyArray(field="past_target", expected_ndim=2), + SwapAxes(input_fields=["past_target"], axes=(0, 1)), + AddAxisLength(target_field="past_target", axis=0), + ]) def _stack_fn(self) -> Callable: return partial( diff --git a/src/gluonts/mx/model/renewal/_transform.py b/src/gluonts/mx/model/renewal/_transform.py index 884ec31051..986ec20c34 100644 --- a/src/gluonts/mx/model/renewal/_transform.py +++ b/src/gluonts/mx/model/renewal/_transform.py @@ -51,13 +51,9 @@ def __init__( def transform(self, data: DataEntry) -> DataEntry: target = data[self.target_field] - data[self.output_field] = np.array( - [ - ( - len(target) - if isinstance(target, list) - else target.shape[self.axis] - ) - ] - ) + data[self.output_field] = np.array([( + len(target) + if isinstance(target, list) + else target.shape[self.axis] + )]) return data diff --git a/src/gluonts/mx/model/seq2seq/_forking_estimator.py b/src/gluonts/mx/model/seq2seq/_forking_estimator.py index 7bb38915f2..910758f9bf 100644 --- a/src/gluonts/mx/model/seq2seq/_forking_estimator.py +++ b/src/gluonts/mx/model/seq2seq/_forking_estimator.py @@ -287,16 +287,14 @@ def create_transformation(self) -> Transformation: if not self.use_feat_static_cat: remove_field_names.append(FieldName.FEAT_STATIC_CAT) - chain.extend( - [ - RemoveFields(field_names=remove_field_names), - AddObservedValuesIndicator( - target_field=FieldName.TARGET, - output_field=FieldName.OBSERVED_VALUES, - dtype=self.dtype, - ), - ] - ) + chain.extend([ + RemoveFields(field_names=remove_field_names), + AddObservedValuesIndicator( + target_field=FieldName.TARGET, + output_field=FieldName.OBSERVED_VALUES, + dtype=self.dtype, + ), + ]) # --- TRANSFORMATION CHAIN FOR DYNAMIC FEATURES --- diff --git a/src/gluonts/mx/model/seq2seq/_seq2seq_estimator.py b/src/gluonts/mx/model/seq2seq/_seq2seq_estimator.py index bd73cbc57e..a4fc483115 100644 --- a/src/gluonts/mx/model/seq2seq/_seq2seq_estimator.py +++ b/src/gluonts/mx/model/seq2seq/_seq2seq_estimator.py @@ -123,30 +123,26 @@ def __init__( self.num_parallel_samples = num_parallel_samples def create_transformation(self) -> transform.Transformation: - return transform.Chain( - [ - transform.AsNumpyArray( - field=FieldName.TARGET, expected_ndim=1 - ), - transform.AddTimeFeatures( - start_field=FieldName.START, - target_field=FieldName.TARGET, - output_field=FieldName.FEAT_TIME, - time_features=time_features_from_frequency_str(self.freq), - pred_length=self.prediction_length, - ), - transform.VstackFeatures( - output_field=FieldName.FEAT_DYNAMIC_REAL, - input_fields=[FieldName.FEAT_TIME], - ), - transform.SetFieldIfNotPresent( - field=FieldName.FEAT_STATIC_CAT, value=[0.0] - ), - transform.AsNumpyArray( - field=FieldName.FEAT_STATIC_CAT, expected_ndim=1 - ), - ] - ) + return transform.Chain([ + transform.AsNumpyArray(field=FieldName.TARGET, expected_ndim=1), + transform.AddTimeFeatures( + start_field=FieldName.START, + target_field=FieldName.TARGET, + output_field=FieldName.FEAT_TIME, + time_features=time_features_from_frequency_str(self.freq), + pred_length=self.prediction_length, + ), + transform.VstackFeatures( + output_field=FieldName.FEAT_DYNAMIC_REAL, + input_fields=[FieldName.FEAT_TIME], + ), + transform.SetFieldIfNotPresent( + field=FieldName.FEAT_STATIC_CAT, value=[0.0] + ), + transform.AsNumpyArray( + field=FieldName.FEAT_STATIC_CAT, expected_ndim=1 + ), + ]) def _create_instance_splitter(self, mode: str): assert mode in ["training", "validation", "test"] diff --git a/src/gluonts/mx/model/tft/_estimator.py b/src/gluonts/mx/model/tft/_estimator.py index a4d8335719..c6fbe6abc6 100644 --- a/src/gluonts/mx/model/tft/_estimator.py +++ b/src/gluonts/mx/model/tft/_estimator.py @@ -200,12 +200,10 @@ def __init__( def create_transformation(self) -> Transformation: transforms = ( [AsNumpyArray(field=FieldName.TARGET, expected_ndim=1)] - + ( - [ - AsNumpyArray(field=name, expected_ndim=1) - for name in self.static_cardinalities.keys() - ] - ) + + ([ + AsNumpyArray(field=name, expected_ndim=1) + for name in self.static_cardinalities.keys() + ]) + [ AsNumpyArray(field=name, expected_ndim=1) for name in chain( @@ -241,17 +239,13 @@ def create_transformation(self) -> Transformation: ) ) else: - transforms.extend( - [ - SetField( - output_field=FieldName.FEAT_STATIC_CAT, - value=[0.0], - ), - AsNumpyArray( - field=FieldName.FEAT_STATIC_CAT, expected_ndim=1 - ), - ] - ) + transforms.extend([ + SetField( + output_field=FieldName.FEAT_STATIC_CAT, + value=[0.0], + ), + AsNumpyArray(field=FieldName.FEAT_STATIC_CAT, expected_ndim=1), + ]) if self.static_feature_dims: transforms.append( @@ -262,17 +256,15 @@ def create_transformation(self) -> Transformation: ) ) else: - transforms.extend( - [ - SetField( - output_field=FieldName.FEAT_STATIC_REAL, - value=[0.0], - ), - AsNumpyArray( - field=FieldName.FEAT_STATIC_REAL, expected_ndim=1 - ), - ] - ) + transforms.extend([ + SetField( + output_field=FieldName.FEAT_STATIC_REAL, + value=[0.0], + ), + AsNumpyArray( + field=FieldName.FEAT_STATIC_REAL, expected_ndim=1 + ), + ]) if self.dynamic_cardinalities: transforms.append( @@ -282,22 +274,20 @@ def create_transformation(self) -> Transformation: ) ) else: - transforms.extend( - [ - SetField( - output_field=FieldName.FEAT_DYNAMIC_CAT, - value=[[0.0]], - ), - AsNumpyArray( - field=FieldName.FEAT_DYNAMIC_CAT, - expected_ndim=2, - ), - BroadcastTo( - field=FieldName.FEAT_DYNAMIC_CAT, - ext_length=self.prediction_length, - ), - ] - ) + transforms.extend([ + SetField( + output_field=FieldName.FEAT_DYNAMIC_CAT, + value=[[0.0]], + ), + AsNumpyArray( + field=FieldName.FEAT_DYNAMIC_CAT, + expected_ndim=2, + ), + BroadcastTo( + field=FieldName.FEAT_DYNAMIC_CAT, + ext_length=self.prediction_length, + ), + ]) input_fields = [FieldName.FEAT_TIME] if self.dynamic_feature_dims: @@ -317,19 +307,17 @@ def create_transformation(self) -> Transformation: ) ) else: - transforms.extend( - [ - SetField( - output_field=FieldName.PAST_FEAT_DYNAMIC + "_cat", - value=[[0.0]], - ), - AsNumpyArray( - field=FieldName.PAST_FEAT_DYNAMIC + "_cat", - expected_ndim=2, - ), - BroadcastTo(field=FieldName.PAST_FEAT_DYNAMIC + "_cat"), - ] - ) + transforms.extend([ + SetField( + output_field=FieldName.PAST_FEAT_DYNAMIC + "_cat", + value=[[0.0]], + ), + AsNumpyArray( + field=FieldName.PAST_FEAT_DYNAMIC + "_cat", + expected_ndim=2, + ), + BroadcastTo(field=FieldName.PAST_FEAT_DYNAMIC + "_cat"), + ]) if self.past_dynamic_feature_dims: transforms.append( @@ -339,18 +327,16 @@ def create_transformation(self) -> Transformation: ) ) else: - transforms.extend( - [ - SetField( - output_field=FieldName.PAST_FEAT_DYNAMIC_REAL, - value=[[0.0]], - ), - AsNumpyArray( - field=FieldName.PAST_FEAT_DYNAMIC_REAL, expected_ndim=2 - ), - BroadcastTo(field=FieldName.PAST_FEAT_DYNAMIC_REAL), - ] - ) + transforms.extend([ + SetField( + output_field=FieldName.PAST_FEAT_DYNAMIC_REAL, + value=[[0.0]], + ), + AsNumpyArray( + field=FieldName.PAST_FEAT_DYNAMIC_REAL, expected_ndim=2 + ), + BroadcastTo(field=FieldName.PAST_FEAT_DYNAMIC_REAL), + ]) return Chain(transforms) diff --git a/src/gluonts/mx/model/tpp/deeptpp/_estimator.py b/src/gluonts/mx/model/tpp/deeptpp/_estimator.py index e6390e899c..4a16504b3a 100644 --- a/src/gluonts/mx/model/tpp/deeptpp/_estimator.py +++ b/src/gluonts/mx/model/tpp/deeptpp/_estimator.py @@ -177,21 +177,17 @@ def _create_instance_splitter(self, mode: str): assert isinstance(instance_sampler, ContinuousTimePointSampler) - return Chain( - [ - ContinuousTimeInstanceSplitter( - past_interval_length=self.context_interval_length, - future_interval_length=self.prediction_interval_length, - instance_sampler=instance_sampler, - ), - RenameFields( - { - "past_target": "target", - "past_valid_length": "valid_length", - } - ), - ] - ) + return Chain([ + ContinuousTimeInstanceSplitter( + past_interval_length=self.context_interval_length, + future_interval_length=self.prediction_interval_length, + instance_sampler=instance_sampler, + ), + RenameFields({ + "past_target": "target", + "past_valid_length": "valid_length", + }), + ]) def create_training_data_loader( self, diff --git a/src/gluonts/mx/model/tpp/forecast.py b/src/gluonts/mx/model/tpp/forecast.py index 4082a12987..facde8b9dc 100644 --- a/src/gluonts/mx/model/tpp/forecast.py +++ b/src/gluonts/mx/model/tpp/forecast.py @@ -132,17 +132,15 @@ def index(self) -> pd.PeriodIndex: ) def __repr__(self): - return ", ".join( - [ - f"PointProcessSampleForecast({self.samples!r})", - f"{self.valid_length!r}", - f"{self.start_date!r}", - f"{self.end_date!r}", - f"{self.freq!r}", - f"item_id={self.item_id!r}", - f"info={self.info!r})", - ] - ) + return ", ".join([ + f"PointProcessSampleForecast({self.samples!r})", + f"{self.valid_length!r}", + f"{self.start_date!r}", + f"{self.end_date!r}", + f"{self.freq!r}", + f"item_id={self.item_id!r}", + f"info={self.info!r})", + ]) def quantile(self, q: Union[float, str]) -> np.ndarray: raise NotImplementedError( diff --git a/src/gluonts/mx/model/wavenet/_estimator.py b/src/gluonts/mx/model/wavenet/_estimator.py index 112ab3fdc8..20e14e8c49 100644 --- a/src/gluonts/mx/model/wavenet/_estimator.py +++ b/src/gluonts/mx/model/wavenet/_estimator.py @@ -265,35 +265,31 @@ def __init__( ) def create_transformation(self) -> transform.Transformation: - return Chain( - [ - AsNumpyArray(field=FieldName.TARGET, expected_ndim=1), - AddObservedValuesIndicator( - target_field=FieldName.TARGET, - output_field=FieldName.OBSERVED_VALUES, - ), - AddTimeFeatures( - start_field=FieldName.START, - target_field=FieldName.TARGET, - output_field=FieldName.FEAT_TIME, - time_features=time_features_from_frequency_str(self.freq), - pred_length=self.prediction_length, - ), - AddAgeFeature( - target_field=FieldName.TARGET, - output_field=FieldName.FEAT_AGE, - pred_length=self.prediction_length, - ), - VstackFeatures( - output_field=FieldName.FEAT_TIME, - input_fields=[FieldName.FEAT_TIME, FieldName.FEAT_AGE], - ), - SetFieldIfNotPresent( - field=FieldName.FEAT_STATIC_CAT, value=[0.0] - ), - AsNumpyArray(field=FieldName.FEAT_STATIC_CAT, expected_ndim=1), - ] - ) + return Chain([ + AsNumpyArray(field=FieldName.TARGET, expected_ndim=1), + AddObservedValuesIndicator( + target_field=FieldName.TARGET, + output_field=FieldName.OBSERVED_VALUES, + ), + AddTimeFeatures( + start_field=FieldName.START, + target_field=FieldName.TARGET, + output_field=FieldName.FEAT_TIME, + time_features=time_features_from_frequency_str(self.freq), + pred_length=self.prediction_length, + ), + AddAgeFeature( + target_field=FieldName.TARGET, + output_field=FieldName.FEAT_AGE, + pred_length=self.prediction_length, + ), + VstackFeatures( + output_field=FieldName.FEAT_TIME, + input_fields=[FieldName.FEAT_TIME, FieldName.FEAT_AGE], + ), + SetFieldIfNotPresent(field=FieldName.FEAT_STATIC_CAT, value=[0.0]), + AsNumpyArray(field=FieldName.FEAT_STATIC_CAT, expected_ndim=1), + ]) def _create_instance_splitter(self, mode: str): assert mode in ["training", "validation", "test"] diff --git a/src/gluonts/mx/representation/binning_helpers.py b/src/gluonts/mx/representation/binning_helpers.py index 31d37417ae..b9c555032b 100644 --- a/src/gluonts/mx/representation/binning_helpers.py +++ b/src/gluonts/mx/representation/binning_helpers.py @@ -31,13 +31,11 @@ def ensure_binning_monotonicity(bin_centers: np.ndarray): def bin_edges_from_bin_centers(bin_centers: np.ndarray): lower_edge = -np.inf upper_edge = np.inf - bin_edges = np.concatenate( - [ - [lower_edge], - (bin_centers[1:] + bin_centers[:-1]) / 2.0, - [upper_edge], - ] - ) + bin_edges = np.concatenate([ + [lower_edge], + (bin_centers[1:] + bin_centers[:-1]) / 2.0, + [upper_edge], + ]) return bin_edges diff --git a/src/gluonts/nursery/SCott/dataset_tools/algo_clustering.py b/src/gluonts/nursery/SCott/dataset_tools/algo_clustering.py index 0441113443..ebcf547400 100644 --- a/src/gluonts/nursery/SCott/dataset_tools/algo_clustering.py +++ b/src/gluonts/nursery/SCott/dataset_tools/algo_clustering.py @@ -124,19 +124,15 @@ def KMeans_inside_dataset( 0, len(target) - len_sample, prediction_length ): ts_slice = target[ts_sample_start : ts_sample_start + len_sample] - feature = torch.cat( - ( - feature, - torch.Tensor( - [ - ts_slice.mean(), - ts_slice.var(), - index % 7, - index // 90, - ] - ), - ) - ) + feature = torch.cat(( + feature, + torch.Tensor([ + ts_slice.mean(), + ts_slice.var(), + index % 7, + index // 90, + ]), + )) index += 1 feature = feature.reshape(index, 4) feature = _get_pre_features(feature).contiguous() @@ -154,20 +150,16 @@ def KMeans_inside_dataset( ): ts_slice = target[ts_sample_start : ts_sample_start + len_sample] gid = cl[sample_id] - dataset_group[gid].append( - { - "target": ts_slice, - "start": unsplit_start, - "feat_static_cat": train_entry["feat_static_cat"], - } - ) - whole_data.append( - { - "target": ts_slice, - "start": unsplit_start, - "feat_static_cat": train_entry["feat_static_cat"], - } - ) + dataset_group[gid].append({ + "target": ts_slice, + "start": unsplit_start, + "feat_static_cat": train_entry["feat_static_cat"], + }) + whole_data.append({ + "target": ts_slice, + "start": unsplit_start, + "feat_static_cat": train_entry["feat_static_cat"], + }) unsplit_start += pd.Timedelta(hours=prediction_length) sample_id += 1 print(len(whole_data)) @@ -228,18 +220,14 @@ def KMeans_m5_dataset( # import pdb;pdb.set_trace() gid = cl[sample_id] unsplit_start = pd.Timestamp("1990-01-01") - dataset_group[gid].append( - { - "target": ts_slice, - "start": unsplit_start, - } # , 'feat_static_cat': train_entry['feat_static_cat']} - ) - whole_data.append( - { - "target": ts_slice, - "start": unsplit_start, - } # , 'feat_static_cat': train_entry['feat_static_cat']} - ) + dataset_group[gid].append({ + "target": ts_slice, + "start": unsplit_start, + }) # , 'feat_static_cat': train_entry['feat_static_cat']} + whole_data.append({ + "target": ts_slice, + "start": unsplit_start, + }) # , 'feat_static_cat': train_entry['feat_static_cat']} sample_id += 1 print(len(whole_data)) ret["group_ratio"] = [len(i) / len(whole_data) for i in dataset_group] diff --git a/src/gluonts/nursery/SCott/dataset_tools/electricity.py b/src/gluonts/nursery/SCott/dataset_tools/electricity.py index 6e75d5d95c..782ea49437 100644 --- a/src/gluonts/nursery/SCott/dataset_tools/electricity.py +++ b/src/gluonts/nursery/SCott/dataset_tools/electricity.py @@ -102,13 +102,11 @@ def group_electricity_cv( ts_slice = unsplit_ts[ ts_sample_start : ts_sample_start + len_sample ] - test_full_data.append( - { - "target": ts_slice, - "start": unsplit_start, - "feat_static_cat": test_entry["feat_static_cat"], - } - ) + test_full_data.append({ + "target": ts_slice, + "start": unsplit_start, + "feat_static_cat": test_entry["feat_static_cat"], + }) print("total number of training examples: ", len(train_full_data)) ret["group_ratio"] = [len(i) / len(train_full_data) for i in dataset_group] @@ -169,20 +167,16 @@ def group_electricity_mb( ts_slice = unsplit_ts[ ts_sample_start : ts_sample_start + len_sample ] - train_full_data.append( - { - "target": ts_slice, - "start": t, - "feat_static_cat": train_entry["feat_static_cat"], - } - ) - dataset_group[gid].append( - { - "target": ts_slice, - "start": t, - "feat_static_cat": train_entry["feat_static_cat"], - } - ) + train_full_data.append({ + "target": ts_slice, + "start": t, + "feat_static_cat": train_entry["feat_static_cat"], + }) + dataset_group[gid].append({ + "target": ts_slice, + "start": t, + "feat_static_cat": train_entry["feat_static_cat"], + }) unsplit_start += pd.Timedelta(hours=prediction_length) # get ready the test data @@ -196,13 +190,11 @@ def group_electricity_mb( ts_slice = unsplit_ts[ ts_sample_start : ts_sample_start + len_sample ] - test_full_data.append( - { - "target": ts_slice, - "start": unsplit_start, - "feat_static_cat": test_entry["feat_static_cat"], - } - ) + test_full_data.append({ + "target": ts_slice, + "start": unsplit_start, + "feat_static_cat": test_entry["feat_static_cat"], + }) print("total number of training examples: ", len(train_full_data)) ret["group_ratio"] = [len(i) / len(train_full_data) for i in dataset_group] diff --git a/src/gluonts/nursery/SCott/dataset_tools/exchange_rate.py b/src/gluonts/nursery/SCott/dataset_tools/exchange_rate.py index cc5e609296..76f56e1a04 100644 --- a/src/gluonts/nursery/SCott/dataset_tools/exchange_rate.py +++ b/src/gluonts/nursery/SCott/dataset_tools/exchange_rate.py @@ -58,20 +58,16 @@ def group_exchangerate_cv( ts_slice = unsplit_ts[ ts_sample_start : ts_sample_start + len_sample ] - train_full_data.append( - { - "target": ts_slice, - "start": unsplit_start, - "feat_static_cat": train_entry["feat_static_cat"], - } - ) - dataset_group[gid].append( - { - "target": ts_slice, - "start": unsplit_start, - "feat_static_cat": train_entry["feat_static_cat"], - } - ) + train_full_data.append({ + "target": ts_slice, + "start": unsplit_start, + "feat_static_cat": train_entry["feat_static_cat"], + }) + dataset_group[gid].append({ + "target": ts_slice, + "start": unsplit_start, + "feat_static_cat": train_entry["feat_static_cat"], + }) unsplit_start += pd.Timedelta("1D") * prediction_length # get ready the test data for i in range(int(num_ts * 0.2)): @@ -84,13 +80,11 @@ def group_exchangerate_cv( ts_slice = unsplit_ts[ ts_sample_start : ts_sample_start + len_sample ] - test_full_data.append( - { - "target": ts_slice, - "start": unsplit_start, - "feat_static_cat": test_entry["feat_static_cat"], - } - ) + test_full_data.append({ + "target": ts_slice, + "start": unsplit_start, + "feat_static_cat": test_entry["feat_static_cat"], + }) print("total number of training examples: ", len(train_full_data)) ret["group_ratio"] = [len(i) / len(train_full_data) for i in dataset_group] print("ratio for each group: ", ret["group_ratio"]) diff --git a/src/gluonts/nursery/SCott/dataset_tools/group_raw_data.py b/src/gluonts/nursery/SCott/dataset_tools/group_raw_data.py index c89a580eac..84bf0402b0 100644 --- a/src/gluonts/nursery/SCott/dataset_tools/group_raw_data.py +++ b/src/gluonts/nursery/SCott/dataset_tools/group_raw_data.py @@ -71,20 +71,16 @@ def get_m4_by_freq( continue nu = 1 + sum(ts_slice) / len_sample ts_slice = [i / nu for i in ts_slice] - whole_data.append( - { - "target": ts_slice, - "start": unsplit_start, - "feat_static_cat": train_entry["feat_static_cat"], - } - ) - dataset_group[i].append( - { - "target": ts_slice, - "start": unsplit_start, - "feat_static_cat": train_entry["feat_static_cat"], - } - ) + whole_data.append({ + "target": ts_slice, + "start": unsplit_start, + "feat_static_cat": train_entry["feat_static_cat"], + }) + dataset_group[i].append({ + "target": ts_slice, + "start": unsplit_start, + "feat_static_cat": train_entry["feat_static_cat"], + }) # unsplit_start += pd.Timedelta(hours=prediction_length*hours_factor[i]) unsplit_start += pd.Timedelta(hours=prediction_length) # for j in range(len(dataset_group)): @@ -140,18 +136,14 @@ def get_temperature_data( nu = 1 + sum(ts_slice) / len(ts_slice) ts_slice /= nu if torch.sum(torch.isnan(ts_slice)).item() == 0: - dataset_group[gid].append( - { - "target": ts_slice, - "start": pd.Timestamp(datetime[index]), - } - ) - whole_data.append( - { - "target": ts_slice, - "start": pd.Timestamp(datetime[index]), - } - ) + dataset_group[gid].append({ + "target": ts_slice, + "start": pd.Timestamp(datetime[index]), + }) + whole_data.append({ + "target": ts_slice, + "start": pd.Timestamp(datetime[index]), + }) if num_samples == samples_per_ts: break random.shuffle(whole_data) @@ -239,14 +231,12 @@ def get_group_data_by_var(name, num_groups, len_sample=9): ) ) continue - dataset_group[group_id].append( - { - "target": unsplit_ts[ - ts_sample_start : ts_sample_start + len_sample - ], - "start": unsplit_start, - } - ) + dataset_group[group_id].append({ + "target": unsplit_ts[ + ts_sample_start : ts_sample_start + len_sample + ], + "start": unsplit_start, + }) unsplit_start += pd.Timedelta(hours=1) import pdb @@ -305,18 +295,14 @@ def get_group_data_by_duplicate(name, num_duplicates, num_groups): {"target": train_entry["target"], "start": train_entry["start"]} ) for j in range(num_duplicates): - dataset_group[i % num_groups].append( - { - "target": train_entry["target"], - "start": train_entry["start"], - } - ) - whole_data_list.append( - { - "target": train_entry["target"], - "start": train_entry["start"], - } - ) + dataset_group[i % num_groups].append({ + "target": train_entry["target"], + "start": train_entry["start"], + }) + whole_data_list.append({ + "target": train_entry["target"], + "start": train_entry["start"], + }) random.shuffle(whole_data_list) random.shuffle(no_duplicate_whole_data_list) ret.append( @@ -342,12 +328,10 @@ def get_whole_data_by_duplicate(name, num_duplicates): {"target": train_entry["target"], "start": train_entry["start"]} ) for j in range(num_duplicates): - dataset_group.append( - { - "target": train_entry["target"], - "start": train_entry["start"], - } - ) + dataset_group.append({ + "target": train_entry["target"], + "start": train_entry["start"], + }) random.shuffle(dataset_group) random.shuffle(no_duplicate_whole_data_list) ret.append( @@ -366,12 +350,10 @@ def get_group_data(name): train_entry = next(it) dataset_group.append( ListDataset( - [ - { - "target": train_entry["target"], - "start": train_entry["start"], - } - ], + [{ + "target": train_entry["target"], + "start": train_entry["start"], + }], freq=dataset.metadata.freq, ) ) @@ -420,24 +402,20 @@ def get_synthetic_data(model_name=None, num_groups=8, mean_boundary=1): prediction = net.get_distr(ts_slice).sample((5000,)) prediction = sum(prediction) / len(prediction) ts = torch.cat([ts, prediction], dim=1) - whole_data_list.append( - { - "target": ts.view( - len(ts[0]), - )[context_length:], - "start": start, - } - ) + whole_data_list.append({ + "target": ts.view( + len(ts[0]), + )[context_length:], + "start": start, + }) dataset_group.append( ListDataset( - [ - { - "target": ts.view( - len(ts[0]), - )[context_length:], - "start": start, - } - ], + [{ + "target": ts.view( + len(ts[0]), + )[context_length:], + "start": start, + }], freq="1H", ) ) @@ -608,22 +586,18 @@ def get_synthetic_data_linear_simple( ) for j in range(num_duplicates): ts += torch.normal(0, 0.01, size=ts.shape) - whole_data_list.append( - { - "target": ts.view( - len(ts[0]), - ), - "start": start, - } - ) - pattern_group.append( - { - "target": ts.view( - len(ts[0]), - ), - "start": start, - } - ) + whole_data_list.append({ + "target": ts.view( + len(ts[0]), + ), + "start": start, + }) + pattern_group.append({ + "target": ts.view( + len(ts[0]), + ), + "start": start, + }) dataset_group.append(ListDataset(pattern_group, freq="1D")) random.shuffle(whole_data_list) @@ -658,32 +632,26 @@ def get_synthetic_data_sin( 1, num_time_steps ) ts += torch.FloatTensor((gid + 1) * base).view(1, num_time_steps) - no_duplicate_whole_data_list.append( - { + no_duplicate_whole_data_list.append({ + "target": ts.view( + len(ts[0]), + ), + "start": start, + }) + for j in range(num_duplicates): + ts += torch.normal(0, 0.1, size=ts.shape) + whole_data_list.append({ "target": ts.view( len(ts[0]), ), "start": start, - } - ) - for j in range(num_duplicates): - ts += torch.normal(0, 0.1, size=ts.shape) - whole_data_list.append( - { - "target": ts.view( - len(ts[0]), - ), - "start": start, - } - ) - pattern_group.append( - { - "target": ts.view( - len(ts[0]), - ), - "start": start, - } - ) + }) + pattern_group.append({ + "target": ts.view( + len(ts[0]), + ), + "start": start, + }) dataset_group.append(ListDataset(pattern_group, freq="1D")) random.shuffle(whole_data_list) diff --git a/src/gluonts/nursery/SCott/dataset_tools/synthetic.py b/src/gluonts/nursery/SCott/dataset_tools/synthetic.py index 8e2dc383b6..b176ef3658 100644 --- a/src/gluonts/nursery/SCott/dataset_tools/synthetic.py +++ b/src/gluonts/nursery/SCott/dataset_tools/synthetic.py @@ -64,24 +64,20 @@ def get_mixed_pattern(unit_length=16, num_duplicates=1000): pattern[(gid + i) % pattern_number], ) ) - ts_sample = torch.cat( - [ - context, - _get_mixed_pattern( - torch.arange(prediction_length, dtype=torch.float), - pattern[gid], - ), - ] - ) + ts_sample = torch.cat([ + context, + _get_mixed_pattern( + torch.arange(prediction_length, dtype=torch.float), + pattern[gid], + ), + ]) whole_data.append({"target": ts_sample, "start": start}) if j % 5 == 0: - val_data.append( - { - "target": ts_sample - + torch.normal(0, 1, ts_sample.shape), - "start": start, - } - ) + val_data.append({ + "target": ts_sample + + torch.normal(0, 1, ts_sample.shape), + "start": start, + }) dataset_group[m * 4 + gid].append( {"target": ts_sample, "start": start} ) diff --git a/src/gluonts/nursery/SCott/dataset_tools/traffic.py b/src/gluonts/nursery/SCott/dataset_tools/traffic.py index 8582ba87b2..4ec45b14fd 100644 --- a/src/gluonts/nursery/SCott/dataset_tools/traffic.py +++ b/src/gluonts/nursery/SCott/dataset_tools/traffic.py @@ -70,20 +70,16 @@ def group_traffic_cv( ts_slice = unsplit_ts[ ts_sample_start : ts_sample_start + len_sample ] - train_full_data.append( - { - "target": ts_slice, - "start": t, - "feat_static_cat": train_entry["feat_static_cat"], - } - ) - dataset_group[gid].append( - { - "target": ts_slice, - "start": t, - "feat_static_cat": train_entry["feat_static_cat"], - } - ) + train_full_data.append({ + "target": ts_slice, + "start": t, + "feat_static_cat": train_entry["feat_static_cat"], + }) + dataset_group[gid].append({ + "target": ts_slice, + "start": t, + "feat_static_cat": train_entry["feat_static_cat"], + }) unsplit_start += pd.Timedelta(hours=prediction_length) # get ready the test data @@ -97,13 +93,11 @@ def group_traffic_cv( ts_slice = unsplit_ts[ ts_sample_start : ts_sample_start + len_sample ] - test_full_data.append( - { - "target": ts_slice, - "start": unsplit_start, - "feat_static_cat": test_entry["feat_static_cat"], - } - ) + test_full_data.append({ + "target": ts_slice, + "start": unsplit_start, + "feat_static_cat": test_entry["feat_static_cat"], + }) print("total number of training examples: ", len(train_full_data)) ret["group_ratio"] = [len(i) / len(train_full_data) for i in dataset_group] @@ -164,20 +158,16 @@ def group_traffic_mb( ts_slice = unsplit_ts[ ts_sample_start : ts_sample_start + len_sample ] - train_full_data.append( - { - "target": ts_slice, - "start": t, - "feat_static_cat": train_entry["feat_static_cat"], - } - ) - dataset_group[gid].append( - { - "target": ts_slice, - "start": t, - "feat_static_cat": train_entry["feat_static_cat"], - } - ) + train_full_data.append({ + "target": ts_slice, + "start": t, + "feat_static_cat": train_entry["feat_static_cat"], + }) + dataset_group[gid].append({ + "target": ts_slice, + "start": t, + "feat_static_cat": train_entry["feat_static_cat"], + }) unsplit_start += pd.Timedelta(hours=prediction_length) # get ready the test data @@ -191,13 +181,11 @@ def group_traffic_mb( ts_slice = unsplit_ts[ ts_sample_start : ts_sample_start + len_sample ] - test_full_data.append( - { - "target": ts_slice, - "start": unsplit_start, - "feat_static_cat": test_entry["feat_static_cat"], - } - ) + test_full_data.append({ + "target": ts_slice, + "start": unsplit_start, + "feat_static_cat": test_entry["feat_static_cat"], + }) print("total number of training examples: ", len(train_full_data)) ret["group_ratio"] = [len(i) / len(train_full_data) for i in dataset_group] diff --git a/src/gluonts/nursery/SCott/model/ar/ar_estimator.py b/src/gluonts/nursery/SCott/model/ar/ar_estimator.py index 62d78b0aae..116e2f26f6 100644 --- a/src/gluonts/nursery/SCott/model/ar/ar_estimator.py +++ b/src/gluonts/nursery/SCott/model/ar/ar_estimator.py @@ -63,22 +63,20 @@ def __init__( # transformation that includes time features, age feature, observed values # indicator, etc. def create_transformation(self, is_full_batch=False) -> Transformation: - return Chain( - [ - InstanceSplitter( - target_field=FieldName.TARGET, - is_pad_field=FieldName.IS_PAD, - start_field=FieldName.START, - forecast_start_field=FieldName.FORECAST_START, - # train_sampler=ExpectedNumInstanceSampler(num_instances=1), - train_sampler=CustomUniformSampler(), - past_length=self.context_length, - future_length=self.prediction_length, - is_full_batch=is_full_batch, - time_series_fields=[], # [FieldName.FEAT_DYNAMIC_REAL] - ) - ] - ) + return Chain([ + InstanceSplitter( + target_field=FieldName.TARGET, + is_pad_field=FieldName.IS_PAD, + start_field=FieldName.START, + forecast_start_field=FieldName.FORECAST_START, + # train_sampler=ExpectedNumInstanceSampler(num_instances=1), + train_sampler=CustomUniformSampler(), + past_length=self.context_length, + future_length=self.prediction_length, + is_full_batch=is_full_batch, + time_series_fields=[], # [FieldName.FEAT_DYNAMIC_REAL] + ) + ]) # defines the network, we get to see one batch to initialize it. # the network should return at least one tensor that is used as a loss to minimize in the training loop. diff --git a/src/gluonts/nursery/SCott/model/lstm/lstm_estimator.py b/src/gluonts/nursery/SCott/model/lstm/lstm_estimator.py index 2b55e12da0..3519b8c995 100644 --- a/src/gluonts/nursery/SCott/model/lstm/lstm_estimator.py +++ b/src/gluonts/nursery/SCott/model/lstm/lstm_estimator.py @@ -66,22 +66,20 @@ def __init__( # transformation that includes time features, age feature, observed values # indicator, etc. def create_transformation(self, is_full_batch=False) -> Transformation: - return Chain( - [ - InstanceSplitter( - target_field=FieldName.TARGET, - is_pad_field=FieldName.IS_PAD, - start_field=FieldName.START, - forecast_start_field=FieldName.FORECAST_START, - # train_sampler=ExpectedNumInstanceSampler(num_instances=1), - train_sampler=CustomUniformSampler(), - past_length=self.context_length, - future_length=self.prediction_length, - is_full_batch=is_full_batch, - time_series_fields=[], # [FieldName.FEAT_DYNAMIC_REAL] - ) - ] - ) + return Chain([ + InstanceSplitter( + target_field=FieldName.TARGET, + is_pad_field=FieldName.IS_PAD, + start_field=FieldName.START, + forecast_start_field=FieldName.FORECAST_START, + # train_sampler=ExpectedNumInstanceSampler(num_instances=1), + train_sampler=CustomUniformSampler(), + past_length=self.context_length, + future_length=self.prediction_length, + is_full_batch=is_full_batch, + time_series_fields=[], # [FieldName.FEAT_DYNAMIC_REAL] + ) + ]) # defines the network, we get to see one batch to initialize it. # the network should return at least one tensor that is used as a loss to minimize in the training loop. diff --git a/src/gluonts/nursery/SCott/preprocess_data.py b/src/gluonts/nursery/SCott/preprocess_data.py index 40b181820f..2b6981edea 100644 --- a/src/gluonts/nursery/SCott/preprocess_data.py +++ b/src/gluonts/nursery/SCott/preprocess_data.py @@ -64,24 +64,20 @@ def get_mixed_pattern(unit_length=16, num_duplicates=1000): pattern[(gid + i) % pattern_number], ) ) - ts_sample = torch.cat( - [ - context, - _get_mixed_pattern( - torch.arange(prediction_length, dtype=torch.float), - pattern[gid], - ), - ] - ) + ts_sample = torch.cat([ + context, + _get_mixed_pattern( + torch.arange(prediction_length, dtype=torch.float), + pattern[gid], + ), + ]) whole_data.append({"target": ts_sample, "start": start}) if j % 5 == 0: - val_data.append( - { - "target": ts_sample - + torch.normal(0, 1, ts_sample.shape), - "start": start, - } - ) + val_data.append({ + "target": ts_sample + + torch.normal(0, 1, ts_sample.shape), + "start": start, + }) dataset_group[m * 4 + gid].append( {"target": ts_sample, "start": start} ) @@ -157,20 +153,16 @@ def group_electricity_cv( ts_slice = unsplit_ts[ ts_sample_start : ts_sample_start + len_sample ] - train_full_data.append( - { - "target": ts_slice, - "start": t, - "feat_static_cat": np.array([gid]), - } - ) - dataset_group[gid].append( - { - "target": ts_slice, - "start": t, - "feat_static_cat": np.array([gid]), - } - ) + train_full_data.append({ + "target": ts_slice, + "start": t, + "feat_static_cat": np.array([gid]), + }) + dataset_group[gid].append({ + "target": ts_slice, + "start": t, + "feat_static_cat": np.array([gid]), + }) unsplit_start += pd.Timedelta(hours=prediction_length) # get ready the test data @@ -184,13 +176,11 @@ def group_electricity_cv( ts_slice = unsplit_ts[ ts_sample_start : ts_sample_start + len_sample ] - test_full_data.append( - { - "target": ts_slice, - "start": unsplit_start, - "feat_static_cat": test_entry["feat_static_cat"], - } - ) + test_full_data.append({ + "target": ts_slice, + "start": unsplit_start, + "feat_static_cat": test_entry["feat_static_cat"], + }) print( "Generating the electricity training data, the total number of" @@ -248,20 +238,16 @@ def group_electricity_cv( ts_slice = unsplit_ts[ ts_sample_start : ts_sample_start + len_sample ] - train_full_data.append( - { - "target": ts_slice, - "start": t, - "feat_static_cat": train_entry["feat_static_cat"], - } - ) - dataset_group[gid].append( - { - "target": ts_slice, - "start": t, - "feat_static_cat": train_entry["feat_static_cat"], - } - ) + train_full_data.append({ + "target": ts_slice, + "start": t, + "feat_static_cat": train_entry["feat_static_cat"], + }) + dataset_group[gid].append({ + "target": ts_slice, + "start": t, + "feat_static_cat": train_entry["feat_static_cat"], + }) unsplit_start += pd.Timedelta(hours=prediction_length) # get ready the test data @@ -275,13 +261,11 @@ def group_electricity_cv( ts_slice = unsplit_ts[ ts_sample_start : ts_sample_start + len_sample ] - test_full_data.append( - { - "target": ts_slice, - "start": unsplit_start, - "feat_static_cat": test_entry["feat_static_cat"], - } - ) + test_full_data.append({ + "target": ts_slice, + "start": unsplit_start, + "feat_static_cat": test_entry["feat_static_cat"], + }) print("total number of training examples: ", len(train_full_data)) ret["group_ratio"] = [len(i) / len(train_full_data) for i in dataset_group] @@ -338,20 +322,16 @@ def group_exchangerate_cv( ts_slice = unsplit_ts[ ts_sample_start : ts_sample_start + len_sample ] - train_full_data.append( - { - "target": ts_slice, - "start": unsplit_start, - "feat_static_cat": train_entry["feat_static_cat"], - } - ) - dataset_group[gid].append( - { - "target": ts_slice, - "start": unsplit_start, - "feat_static_cat": train_entry["feat_static_cat"], - } - ) + train_full_data.append({ + "target": ts_slice, + "start": unsplit_start, + "feat_static_cat": train_entry["feat_static_cat"], + }) + dataset_group[gid].append({ + "target": ts_slice, + "start": unsplit_start, + "feat_static_cat": train_entry["feat_static_cat"], + }) unsplit_start += pd.Timedelta("1D") * prediction_length # get ready the test data for i in range(int(num_ts * 0.2)): @@ -364,13 +344,11 @@ def group_exchangerate_cv( ts_slice = unsplit_ts[ ts_sample_start : ts_sample_start + len_sample ] - test_full_data.append( - { - "target": ts_slice, - "start": unsplit_start, - "feat_static_cat": test_entry["feat_static_cat"], - } - ) + test_full_data.append({ + "target": ts_slice, + "start": unsplit_start, + "feat_static_cat": test_entry["feat_static_cat"], + }) print( "Generating the exchange rate training data, the total number of" " training examples:", @@ -441,20 +419,16 @@ def group_traffic_cv( ts_slice = unsplit_ts[ ts_sample_start : ts_sample_start + len_sample ] - train_full_data.append( - { - "target": ts_slice, - "start": t, - "feat_static_cat": train_entry["feat_static_cat"], - } - ) - dataset_group[gid].append( - { - "target": ts_slice, - "start": t, - "feat_static_cat": train_entry["feat_static_cat"], - } - ) + train_full_data.append({ + "target": ts_slice, + "start": t, + "feat_static_cat": train_entry["feat_static_cat"], + }) + dataset_group[gid].append({ + "target": ts_slice, + "start": t, + "feat_static_cat": train_entry["feat_static_cat"], + }) unsplit_start += pd.Timedelta(hours=prediction_length) # get ready the test data @@ -468,13 +442,11 @@ def group_traffic_cv( ts_slice = unsplit_ts[ ts_sample_start : ts_sample_start + len_sample ] - test_full_data.append( - { - "target": ts_slice, - "start": unsplit_start, - "feat_static_cat": test_entry["feat_static_cat"], - } - ) + test_full_data.append({ + "target": ts_slice, + "start": unsplit_start, + "feat_static_cat": test_entry["feat_static_cat"], + }) print( "Generating the traffic training data, the total number of training" diff --git a/src/gluonts/nursery/daf/engine/parallel.py b/src/gluonts/nursery/daf/engine/parallel.py index ae3682efa6..d95bce51a7 100644 --- a/src/gluonts/nursery/daf/engine/parallel.py +++ b/src/gluonts/nursery/daf/engine/parallel.py @@ -300,22 +300,20 @@ def __init__( **kwargs, ) del self.optimizer - self.optimizer = optimizer( - [ - { - "params": chain( - self.model.src.generative_parameters(), - self.model.tgt.generative_parameters(), - ) - }, - { - "params": chain( - self.model.src.discriminative_parameters(), - self.model.tgt.discriminative_parameters(), - ) - }, - ] - ) + self.optimizer = optimizer([ + { + "params": chain( + self.model.src.generative_parameters(), + self.model.tgt.generative_parameters(), + ) + }, + { + "params": chain( + self.model.src.discriminative_parameters(), + self.model.tgt.discriminative_parameters(), + ) + }, + ]) def _train(self, *data): self.model.generative() diff --git a/src/gluonts/nursery/daf/estimator/modules.py b/src/gluonts/nursery/daf/estimator/modules.py index 5541b36a3e..3f92ac7192 100644 --- a/src/gluonts/nursery/daf/estimator/modules.py +++ b/src/gluonts/nursery/daf/estimator/modules.py @@ -103,12 +103,10 @@ def n_layer(self) -> int: @property def tie_layers(self) -> bool: return (self.n_layer == 1) or ( - all( - [ - (a.encoder is b.encoder) and (a.decoder is b.decoder) - for a, b in product(self.blocks[:1], self.blocks[1:]) - ] - ) + all([ + (a.encoder is b.encoder) and (a.decoder is b.decoder) + for a, b in product(self.blocks[:1], self.blocks[1:]) + ]) ) def register_loss_func(self, func: LossFunction) -> None: diff --git a/src/gluonts/nursery/daf/tslib/dataset/timeseries.py b/src/gluonts/nursery/daf/tslib/dataset/timeseries.py index 4bd63964b8..11270eb9fb 100644 --- a/src/gluonts/nursery/daf/tslib/dataset/timeseries.py +++ b/src/gluonts/nursery/daf/tslib/dataset/timeseries.py @@ -108,25 +108,23 @@ def d_data(self) -> int: return self.target.shape[1] def __eq__(self, other) -> bool: - return all( - [ - isinstance(other, TimeSeriesInstant), - np.array_equal(self.target, other.target), - np.array_equal(self.timestamp, other.timestamp), - self.series_name == other.series_name, - np.array_equal(self.target_names, other.target_names), - _dict_equal( - self.categorical_features, - other.categorical_features, - np.array_equal, - ), - _dict_equal( - self.numerical_features, - other.numerical_features, - np.array_equal, - ), - ] - ) + return all([ + isinstance(other, TimeSeriesInstant), + np.array_equal(self.target, other.target), + np.array_equal(self.timestamp, other.timestamp), + self.series_name == other.series_name, + np.array_equal(self.target_names, other.target_names), + _dict_equal( + self.categorical_features, + other.categorical_features, + np.array_equal, + ), + _dict_equal( + self.numerical_features, + other.numerical_features, + np.array_equal, + ), + ]) def __repr__(self): string = f"time = {self.timestamp:%Y-%m-%d %H:%M:%S}\n" @@ -291,45 +289,43 @@ def dynamic_numerical_features(self) -> Dict: } def __eq__(self, other): - return all( - [ - isinstance(other, TimeSeries), - np.array_equal(self.target, other.target), - np.array_equal(self.time_index, other.time_index), - self.series_name == other.series_name, - np.array_equal(self.target_names, other.target_names), - _dict_equal( - self.static_categorical_features, - other.static_categorical_features, - np.array_equal, - ), - _dict_equal( - self.static_numerical_features, - other.static_numerical_features, - np.array_equal, - ), - _dict_equal( - self.revealed_categorical_features, - other.revealed_categorical_features, - np.array_equal, - ), - _dict_equal( - self.revealed_numerical_features, - other.revealed_numerical_features, - np.array_equal, - ), - _dict_equal( - self.observed_categorical_features, - other.observed_categorical_features, - np.array_equal, - ), - _dict_equal( - self.observed_numerical_features, - other.observed_numerical_features, - np.array_equal, - ), - ] - ) + return all([ + isinstance(other, TimeSeries), + np.array_equal(self.target, other.target), + np.array_equal(self.time_index, other.time_index), + self.series_name == other.series_name, + np.array_equal(self.target_names, other.target_names), + _dict_equal( + self.static_categorical_features, + other.static_categorical_features, + np.array_equal, + ), + _dict_equal( + self.static_numerical_features, + other.static_numerical_features, + np.array_equal, + ), + _dict_equal( + self.revealed_categorical_features, + other.revealed_categorical_features, + np.array_equal, + ), + _dict_equal( + self.revealed_numerical_features, + other.revealed_numerical_features, + np.array_equal, + ), + _dict_equal( + self.observed_categorical_features, + other.observed_categorical_features, + np.array_equal, + ), + _dict_equal( + self.observed_numerical_features, + other.observed_numerical_features, + np.array_equal, + ), + ]) def __len__(self): return len(self.target) @@ -562,16 +558,14 @@ def _check_consistency( self, instances: List[TimeSeries] ) -> List[TimeSeries]: def _consistent(ts1: TimeSeries, ts2: TimeSeries) -> bool: - return all( - [ - np.array_equal(ts1.target_names, ts2.target_names), - ts1._static_features == ts2._static_features, - ts1._revealed_features == ts2._revealed_features, - ts1._observed_features == ts2._observed_features, - ts1._categorical_features == ts2._categorical_features, - ts1._numerical_features == ts2._numerical_features, - ] - ) + return all([ + np.array_equal(ts1.target_names, ts2.target_names), + ts1._static_features == ts2._static_features, + ts1._revealed_features == ts2._revealed_features, + ts1._observed_features == ts2._observed_features, + ts1._categorical_features == ts2._categorical_features, + ts1._numerical_features == ts2._numerical_features, + ]) cats = defaultdict(list) nums = defaultdict(list) diff --git a/src/gluonts/nursery/daf/tslib/metrics/dict.py b/src/gluonts/nursery/daf/tslib/metrics/dict.py index 8ac9fa8c3f..53d0e3ea5b 100644 --- a/src/gluonts/nursery/daf/tslib/metrics/dict.py +++ b/src/gluonts/nursery/daf/tslib/metrics/dict.py @@ -193,12 +193,10 @@ def _add_spaces(str_, n_spaces=4): main_str = "\n".join( [f"{name}: {repr(meter)}" for name, meter in self._meters.items()] ) - child_str = "\n".join( - [ - f"{name}:\n{_add_spaces(repr(meterdict))}" - for name, meterdict in self._meterdicts.items() - ] - ) + child_str = "\n".join([ + f"{name}:\n{_add_spaces(repr(meterdict))}" + for name, meterdict in self._meterdicts.items() + ]) if child_str: main_str = "\n".join([main_str, child_str]) return main_str diff --git a/src/gluonts/nursery/daf/tslib/nn/attention/posemb.py b/src/gluonts/nursery/daf/tslib/nn/attention/posemb.py index ae560cfb7d..a002e9d2f5 100644 --- a/src/gluonts/nursery/daf/tslib/nn/attention/posemb.py +++ b/src/gluonts/nursery/daf/tslib/nn/attention/posemb.py @@ -63,12 +63,10 @@ def __init__( self.max_len = max_len self.sub_shape = sub_shape - self._weights = nn.ParameterList( - [ - nn.Parameter(Tensor(size, dim)) - for size, dim in zip(self.sub_shape, self.d_sub_embeds) - ] - ) + self._weights = nn.ParameterList([ + nn.Parameter(Tensor(size, dim)) + for size, dim in zip(self.sub_shape, self.d_sub_embeds) + ]) self._reset_parameters() def _reset_parameters(self): diff --git a/src/gluonts/nursery/daf/tslib/nn/attention/selfattn.py b/src/gluonts/nursery/daf/tslib/nn/attention/selfattn.py index bf5d98a83b..0d8964d804 100644 --- a/src/gluonts/nursery/daf/tslib/nn/attention/selfattn.py +++ b/src/gluonts/nursery/daf/tslib/nn/attention/selfattn.py @@ -309,7 +309,7 @@ def forward( value: Tensor, shape: Tensor, *, - mask: Optional[BoolTensor] = None + mask: Optional[BoolTensor] = None, ) -> Tensor: q, k, v = self._compute_qkv(value, shape) score = self._compute_attn_score(q, k, mask) diff --git a/src/gluonts/nursery/few_shot_prediction/src/meta/datasets/artificial.py b/src/gluonts/nursery/few_shot_prediction/src/meta/datasets/artificial.py index 5a2dab200b..abce0e6f77 100644 --- a/src/gluonts/nursery/few_shot_prediction/src/meta/datasets/artificial.py +++ b/src/gluonts/nursery/few_shot_prediction/src/meta/datasets/artificial.py @@ -209,21 +209,17 @@ def generate(self) -> None: def generate_split( self, split: Literal["train", "val", "test"], n_samples: int ) -> None: - queries, support_sets = zip( - *[ - generate_artificial_tuplets( - dataset_name=self.dataset_name, - context_length=self.context_length, - support_length=self.support_length, - prediction_length=self.prediction_length, - support_set_size=self.support_set_size, - item_id=i, - ) - for i in tqdm( - range(n_samples), desc="generating artificial data" - ) - ] - ) + queries, support_sets = zip(*[ + generate_artificial_tuplets( + dataset_name=self.dataset_name, + context_length=self.context_length, + support_length=self.support_length, + prediction_length=self.prediction_length, + support_set_size=self.support_set_size, + item_id=i, + ) + for i in tqdm(range(n_samples), desc="generating artificial data") + ]) _write_data_to_file(self.root / split / "data.json", queries) _write_data_to_file( self.root / split / ".support_set.json", support_sets diff --git a/src/gluonts/nursery/few_shot_prediction/src/meta/datasets/cheat.py b/src/gluonts/nursery/few_shot_prediction/src/meta/datasets/cheat.py index 8f53989c7f..964c0fc65a 100644 --- a/src/gluonts/nursery/few_shot_prediction/src/meta/datasets/cheat.py +++ b/src/gluonts/nursery/few_shot_prediction/src/meta/datasets/cheat.py @@ -234,39 +234,35 @@ def train_dataloader(self) -> DataLoader[TripletBatch]: def val_dataloader(self) -> DataLoader[TripletBatch]: splits = [self.splits.val(d_name) for d_name in self.dataset_names_val] - return list( - [ - DataLoader( - TripletDataset( - queries=split.data(), support_sets=split.support_set() - ), - collate_fn=TripletBatch.collate, - batch_size=self.batch_size_val_test, - num_workers=self.num_workers, - pin_memory=True, - ) - for split in splits - ] - ) + return list([ + DataLoader( + TripletDataset( + queries=split.data(), support_sets=split.support_set() + ), + collate_fn=TripletBatch.collate, + batch_size=self.batch_size_val_test, + num_workers=self.num_workers, + pin_memory=True, + ) + for split in splits + ]) def test_dataloader(self) -> DataLoader[TripletBatch]: splits = [ self.splits.test(d_name) for d_name in self.dataset_names_test ] - return list( - [ - DataLoader( - TripletDataset( - queries=split.data(), support_sets=split.support_set() - ), - collate_fn=TripletBatch.collate, - batch_size=self.batch_size_val_test, - num_workers=self.num_workers, - pin_memory=True, - ) - for split in splits - ] - ) + return list([ + DataLoader( + TripletDataset( + queries=split.data(), support_sets=split.support_set() + ), + collate_fn=TripletBatch.collate, + batch_size=self.batch_size_val_test, + num_workers=self.num_workers, + pin_memory=True, + ) + for split in splits + ]) def generate(self) -> None: if self.root.exists(): @@ -301,18 +297,14 @@ def generate_split( always_cheat=True, query_length_scale=5.0, ) -> None: - queries, support_sets = zip( - *[ - self.generate_artificial_tuplets( - item_id=i, - always_cheat=always_cheat, - query_length_scale=query_length_scale, - ) - for i in tqdm( - range(n_samples), desc="generating artificial data" - ) - ] - ) + queries, support_sets = zip(*[ + self.generate_artificial_tuplets( + item_id=i, + always_cheat=always_cheat, + query_length_scale=query_length_scale, + ) + for i in tqdm(range(n_samples), desc="generating artificial data") + ]) _write_data_to_file(self.root / split / "data.json", queries) _write_data_to_file( self.root / split / ".support_set.json", support_sets @@ -585,14 +577,12 @@ def train_dataloader(self) -> DataLoader[TripletBatch]: for d_name in self.dataset_names_train ] - datasets = ConcatDataset( - [ - TripletDataset( - queries=split.data(), support_sets=split.support_set() - ) - for split in splits - ] - ) + datasets = ConcatDataset([ + TripletDataset( + queries=split.data(), support_sets=split.support_set() + ) + for split in splits + ]) return DataLoader( datasets, @@ -638,19 +628,15 @@ def generate_split( query_length_scale: float = 5.0, counterfactual_size: int = 0, ) -> None: - queries, support_sets = zip( - *[ - self.generate_artificial_tuplets( - item_id=i, - always_cheat=always_cheat, - query_length_scale=query_length_scale, - counterfactual_size=counterfactual_size, - ) - for i in tqdm( - range(n_samples), desc="generating artificial data" - ) - ] - ) + queries, support_sets = zip(*[ + self.generate_artificial_tuplets( + item_id=i, + always_cheat=always_cheat, + query_length_scale=query_length_scale, + counterfactual_size=counterfactual_size, + ) + for i in tqdm(range(n_samples), desc="generating artificial data") + ]) _write_data_to_file(self.root / split / "data.json", queries) _write_data_to_file( self.root / split / ".support_set.json", support_sets diff --git a/src/gluonts/nursery/few_shot_prediction/src/meta/datasets/datasets.py b/src/gluonts/nursery/few_shot_prediction/src/meta/datasets/datasets.py index ddfec8fbf8..c2152b9cc9 100644 --- a/src/gluonts/nursery/few_shot_prediction/src/meta/datasets/datasets.py +++ b/src/gluonts/nursery/few_shot_prediction/src/meta/datasets/datasets.py @@ -132,11 +132,9 @@ def sample_datasets( random_state=random_state.randint(low=0, high=10000), ) folds.append((train_split, val_split, test_split)) - assert not any( - [ - set(train_split) & (set(val_split)), - set(train_split) & set(test_split), - set(val_split) & set(test_split), - ] - ), "Splits should not intersect!" + assert not any([ + set(train_split) & (set(val_split)), + set(train_split) & set(test_split), + set(val_split) & set(test_split), + ]), "Splits should not intersect!" return folds diff --git a/src/gluonts/nursery/few_shot_prediction/src/meta/datasets/super.py b/src/gluonts/nursery/few_shot_prediction/src/meta/datasets/super.py index 2676713e3b..10574993a5 100644 --- a/src/gluonts/nursery/few_shot_prediction/src/meta/datasets/super.py +++ b/src/gluonts/nursery/few_shot_prediction/src/meta/datasets/super.py @@ -195,20 +195,16 @@ def test_dataloader(self) -> DataLoader[TripletBatch]: def get_log_batches(self, n_logging_samples: int) -> Tuple[TripletBatch]: its_train = cycle( - list( - [ - iter(dm.sampling_triplet_dataset("train")) - for dm in self.data_modules_train - ] - ) + list([ + iter(dm.sampling_triplet_dataset("train")) + for dm in self.data_modules_train + ]) ) its_val = cycle( - list( - [ - iter(dm.sequential_triplet_dataset("val")) - for dm in self.data_modules_val - ] - ) + list([ + iter(dm.sequential_triplet_dataset("val")) + for dm in self.data_modules_val + ]) ) def get_log_batch(its): diff --git a/src/gluonts/nursery/few_shot_prediction/src/meta/models/module.py b/src/gluonts/nursery/few_shot_prediction/src/meta/models/module.py index c5f7b3ad7f..799317d38d 100644 --- a/src/gluonts/nursery/few_shot_prediction/src/meta/models/module.py +++ b/src/gluonts/nursery/few_shot_prediction/src/meta/models/module.py @@ -63,12 +63,9 @@ def __init__( self.val_quantile_width = nn.ModuleList( [quantile_width.clone() for _ in range(len(val_dataset_names))] ) - self.test_quantile_width = nn.ModuleList( - [ - quantile_width.clone() - for _ in range(len(test_dataset_names)) - ] - ) + self.test_quantile_width = nn.ModuleList([ + quantile_width.clone() for _ in range(len(test_dataset_names)) + ]) self.val_dataset_names = val_dataset_names ( diff --git a/src/gluonts/nursery/few_shot_prediction/src/scripts/data.py b/src/gluonts/nursery/few_shot_prediction/src/scripts/data.py index 0f0faf6589..4e5028cc54 100644 --- a/src/gluonts/nursery/few_shot_prediction/src/scripts/data.py +++ b/src/gluonts/nursery/few_shot_prediction/src/scripts/data.py @@ -139,15 +139,13 @@ def statistics(): ) n_total_normalized = n_total / sum(n_total) - datasets, lengths_normalized, n_total_normalized = zip( - *( - sorted( - list(zip(datasets, lengths_normalized, n_total_normalized)), - key=lambda x: x[1], - reverse=True, - ) + datasets, lengths_normalized, n_total_normalized = zip(*( + sorted( + list(zip(datasets, lengths_normalized, n_total_normalized)), + key=lambda x: x[1], + reverse=True, ) - ) + )) x = np.arange(len(datasets)) diff --git a/src/gluonts/nursery/few_shot_prediction/src/scripts/train.py b/src/gluonts/nursery/few_shot_prediction/src/scripts/train.py index b250d44577..556c6a0e65 100644 --- a/src/gluonts/nursery/few_shot_prediction/src/scripts/train.py +++ b/src/gluonts/nursery/few_shot_prediction/src/scripts/train.py @@ -297,39 +297,35 @@ def main( # add callback that only works with attention attention_models = ["iwata", "cnn_iwata", "tcn"] if model_name in attention_models: - callbacks.extend( - [ - ForecastSupportSetAttentionPlotLoggerCallback( - log_batch_train, - quantiles=quantiles, - split="train", - every_n_epochs=log_plot_every_n_epochs, - ), - ForecastSupportSetAttentionPlotLoggerCallback( - log_batch_val, - quantiles=quantiles, - split="val", - every_n_epochs=log_plot_every_n_epochs, - ), - ] - ) + callbacks.extend([ + ForecastSupportSetAttentionPlotLoggerCallback( + log_batch_train, + quantiles=quantiles, + split="train", + every_n_epochs=log_plot_every_n_epochs, + ), + ForecastSupportSetAttentionPlotLoggerCallback( + log_batch_val, + quantiles=quantiles, + split="val", + every_n_epochs=log_plot_every_n_epochs, + ), + ]) else: - callbacks.extend( - [ - ForecastPlotLoggerCallback( - log_batch_val, - quantiles=quantiles, - split="val", - every_n_epochs=log_plot_every_n_epochs, - ), - ForecastPlotLoggerCallback( - log_batch_train, - quantiles=quantiles, - split="train", - every_n_epochs=log_plot_every_n_epochs, - ), - ] - ) + callbacks.extend([ + ForecastPlotLoggerCallback( + log_batch_val, + quantiles=quantiles, + split="val", + every_n_epochs=log_plot_every_n_epochs, + ), + ForecastPlotLoggerCallback( + log_batch_train, + quantiles=quantiles, + split="train", + every_n_epochs=log_plot_every_n_epochs, + ), + ]) # -------------------- train model --------------------------------------------------- trainer = pl.Trainer( diff --git a/src/gluonts/nursery/robust-mts-attack/multivariate/datasets/dataset.py b/src/gluonts/nursery/robust-mts-attack/multivariate/datasets/dataset.py index 768fe06070..92f615ef67 100644 --- a/src/gluonts/nursery/robust-mts-attack/multivariate/datasets/dataset.py +++ b/src/gluonts/nursery/robust-mts-attack/multivariate/datasets/dataset.py @@ -42,13 +42,11 @@ def extract_dataset(dataset_name: str): def pivot_dataset(dataset): ds_list = list(dataset) - return [ - { - "item": "0", - "start": ds_list[0]["start"], - "target": np.vstack([d["target"] for d in ds_list]), - } - ] + return [{ + "item": "0", + "start": ds_list[0]["start"], + "target": np.vstack([d["target"] for d in ds_list]), + }] class MultivariateDatasetInfo(NamedTuple): @@ -240,16 +238,14 @@ def taxi_30min(max_target_dim: int = None): ) -datasets = OrderedDict( - [ - ("solar", solar), - ("exchange_rate", exchange_rate), - ("electricity", electricity), - ("traffic", traffic), - ("wikipedia", wiki), - ("taxi_30min", taxi_30min), - ] -) +datasets = OrderedDict([ + ("solar", solar), + ("exchange_rate", exchange_rate), + ("electricity", electricity), + ("traffic", traffic), + ("wikipedia", wiki), + ("taxi_30min", taxi_30min), +]) if __name__ == "__main__": extract_dataset("electricity_nips") diff --git a/src/gluonts/nursery/robust-mts-attack/pts/dataset/repository/_m5.py b/src/gluonts/nursery/robust-mts-attack/pts/dataset/repository/_m5.py index 74d2aa7b09..a94bec02c2 100644 --- a/src/gluonts/nursery/robust-mts-attack/pts/dataset/repository/_m5.py +++ b/src/gluonts/nursery/robust-mts-attack/pts/dataset/repository/_m5.py @@ -233,15 +233,13 @@ def get_sell_price(item_id, store_id): meta_file = dataset_path / "metadata.json" with open(meta_file, "w") as f: f.write( - json.dumps( - { - "freq": pandas_freq, - "prediction_length": prediction_length, - "feat_static_cat": feat_static_cat, - "feat_dynamic_real": feat_dynamic_real, - "cardinality": len(train_ds), - } - ) + json.dumps({ + "freq": pandas_freq, + "prediction_length": prediction_length, + "feat_static_cat": feat_static_cat, + "feat_dynamic_real": feat_dynamic_real, + "cardinality": len(train_ds), + }) ) # Build testing set diff --git a/src/gluonts/nursery/robust-mts-attack/pts/feature/holiday.py b/src/gluonts/nursery/robust-mts-attack/pts/feature/holiday.py index aef6823c36..253eb7db5b 100644 --- a/src/gluonts/nursery/robust-mts-attack/pts/feature/holiday.py +++ b/src/gluonts/nursery/robust-mts-attack/pts/feature/holiday.py @@ -88,17 +88,13 @@ def __call__(self, dates): dates Pandas series with Datetimeindex timestamps. """ - return np.vstack( - [ - np.hstack( - [ - self.kernel_function((index - ref_date).days) - for index in dates - ] - ) - for ref_date in self.reference_dates - ] - ).sum(0, keepdims=True) + return np.vstack([ + np.hstack([ + self.kernel_function((index - ref_date).days) + for index in dates + ]) + for ref_date in self.reference_dates + ]).sum(0, keepdims=True) class CustomHolidayFeatureSet: @@ -170,16 +166,12 @@ def __call__(self, dates): dates Pandas series with Datetimeindex timestamps. """ - return np.vstack( - [ - np.hstack( - [ - self.kernel_function( - distance_to_holiday(custom_holiday)(index) - ) - for index in dates - ] + return np.vstack([ + np.hstack([ + self.kernel_function( + distance_to_holiday(custom_holiday)(index) ) - for custom_holiday in self.custom_holidays - ] - ) + for index in dates + ]) + for custom_holiday in self.custom_holidays + ]) diff --git a/src/gluonts/nursery/robust-mts-attack/pts/model/causal_deepar/causal_deepar_network.py b/src/gluonts/nursery/robust-mts-attack/pts/model/causal_deepar/causal_deepar_network.py index 12fa525867..7abe78f816 100644 --- a/src/gluonts/nursery/robust-mts-attack/pts/model/causal_deepar/causal_deepar_network.py +++ b/src/gluonts/nursery/robust-mts-attack/pts/model/causal_deepar/causal_deepar_network.py @@ -241,21 +241,17 @@ def unroll_encoder( # from (batch_size, sub_seq_len, *target_shape, num_lags) # to (batch_size, sub_seq_len, prod(target_shape) * num_lags) - input_lags = lags_scaled.reshape( - ( - -1, - subsequences_length, - len(self.lags_seq) * prod(self.target_shape), - ) - ) - - input_control_lags = control_lags_scaled.reshape( - ( - -1, - subsequences_length, - len(self.lags_seq) * prod(self.target_shape), - ) - ) + input_lags = lags_scaled.reshape(( + -1, + subsequences_length, + len(self.lags_seq) * prod(self.target_shape), + )) + + input_control_lags = control_lags_scaled.reshape(( + -1, + subsequences_length, + len(self.lags_seq) * prod(self.target_shape), + )) # (batch_size, sub_seq_len, input_dim) inputs = torch.cat( @@ -562,13 +558,11 @@ def sampling_decoder( samples = torch.cat(future_samples, dim=1) # (batch_size, num_samples, prediction_length, *target_shape) - return samples.reshape( - ( - (-1, self.num_parallel_samples) - + (self.prediction_length,) - + self.target_shape - ) - ) + return samples.reshape(( + (-1, self.num_parallel_samples) + + (self.prediction_length,) + + self.target_shape + )) # noinspection PyMethodOverriding,PyPep8Naming def forward( diff --git a/src/gluonts/nursery/robust-mts-attack/pts/model/deepar/deepar_network.py b/src/gluonts/nursery/robust-mts-attack/pts/model/deepar/deepar_network.py index 8d66c91ec0..2f0314856c 100644 --- a/src/gluonts/nursery/robust-mts-attack/pts/model/deepar/deepar_network.py +++ b/src/gluonts/nursery/robust-mts-attack/pts/model/deepar/deepar_network.py @@ -209,13 +209,11 @@ def unroll_encoder( # from (batch_size, sub_seq_len, *target_shape, num_lags) # to (batch_size, sub_seq_len, prod(target_shape) * num_lags) - input_lags = lags_scaled.reshape( - ( - -1, - subsequences_length, - len(self.lags_seq) * prod(self.target_shape), - ) - ) + input_lags = lags_scaled.reshape(( + -1, + subsequences_length, + len(self.lags_seq) * prod(self.target_shape), + )) # (batch_size, sub_seq_len, input_dim) inputs = torch.cat( @@ -440,13 +438,11 @@ def sampling_decoder( samples = torch.cat(future_samples, dim=1) # (batch_size, num_samples, prediction_length, *target_shape) - return samples.reshape( - ( - (-1, self.num_parallel_samples) - + (self.prediction_length,) - + self.target_shape - ) - ) + return samples.reshape(( + (-1, self.num_parallel_samples) + + (self.prediction_length,) + + self.target_shape + )) # noinspection PyMethodOverriding,PyPep8Naming def forward( diff --git a/src/gluonts/nursery/robust-mts-attack/pts/model/deepvar/deepvar_estimator.py b/src/gluonts/nursery/robust-mts-attack/pts/model/deepvar/deepvar_estimator.py index 4bea3ddf23..5adb3c9364 100644 --- a/src/gluonts/nursery/robust-mts-attack/pts/model/deepvar/deepvar_estimator.py +++ b/src/gluonts/nursery/robust-mts-attack/pts/model/deepvar/deepvar_estimator.py @@ -260,12 +260,10 @@ def create_instance_splitter(self, mode: str): target_dim=self.target_dim, ) if self.use_marginal_transformation - else RenameFields( - { - f"past_{FieldName.TARGET}": f"past_{FieldName.TARGET}_cdf", - f"future_{FieldName.TARGET}": f"future_{FieldName.TARGET}_cdf", - } - ) + else RenameFields({ + f"past_{FieldName.TARGET}": f"past_{FieldName.TARGET}_cdf", + f"future_{FieldName.TARGET}": f"future_{FieldName.TARGET}_cdf", + }) ) def create_training_network( diff --git a/src/gluonts/nursery/robust-mts-attack/pts/model/deepvar/deepvar_network.py b/src/gluonts/nursery/robust-mts-attack/pts/model/deepvar/deepvar_network.py index 64353b4adf..e035ccef03 100644 --- a/src/gluonts/nursery/robust-mts-attack/pts/model/deepvar/deepvar_network.py +++ b/src/gluonts/nursery/robust-mts-attack/pts/model/deepvar/deepvar_network.py @@ -596,14 +596,12 @@ def repeat(tensor, dim=0): samples = torch.cat(future_samples, dim=1) locs = torch.cat(loc, dim=1) # (batch_size, num_samples, prediction_length, target_dim) - return samples.reshape( - ( - -1, - self.num_parallel_samples, - self.prediction_length, - self.target_dim, - ) - ) # , locs.reshape( + return samples.reshape(( + -1, + self.num_parallel_samples, + self.prediction_length, + self.target_dim, + )) # , locs.reshape( # -1, # self.num_parallel_samples, # self.prediction_length, diff --git a/src/gluonts/nursery/robust-mts-attack/pts/model/n_beats/n_beats_ensemble.py b/src/gluonts/nursery/robust-mts-attack/pts/model/n_beats/n_beats_ensemble.py index ff9234273d..ab21ca12aa 100644 --- a/src/gluonts/nursery/robust-mts-attack/pts/model/n_beats/n_beats_ensemble.py +++ b/src/gluonts/nursery/robust-mts-attack/pts/model/n_beats/n_beats_ensemble.py @@ -214,12 +214,10 @@ def __init__( self.freq = freq self.prediction_length = prediction_length - assert meta_loss_function is None or all( - [ - loss_function in VALID_LOSS_FUNCTIONS - for loss_function in meta_loss_function - ] - ), f"Each loss function has to be one of the following: {VALID_LOSS_FUNCTIONS}." + assert meta_loss_function is None or all([ + loss_function in VALID_LOSS_FUNCTIONS + for loss_function in meta_loss_function + ]), f"Each loss function has to be one of the following: {VALID_LOSS_FUNCTIONS}." assert meta_context_length is None or all( [context_length > 0 for context_length in meta_context_length] ), "The value of each `context_length` should be > 0" diff --git a/src/gluonts/nursery/robust-mts-attack/pts/model/n_beats/n_beats_estimator.py b/src/gluonts/nursery/robust-mts-attack/pts/model/n_beats/n_beats_estimator.py index c4ef409bc7..75d9b2adbd 100644 --- a/src/gluonts/nursery/robust-mts-attack/pts/model/n_beats/n_beats_estimator.py +++ b/src/gluonts/nursery/robust-mts-attack/pts/model/n_beats/n_beats_estimator.py @@ -158,22 +158,20 @@ def _validate_nbeats_argument( # that can be digested by our model by only splitting the target in two, a # conditioning part and a to-predict part, for each training example. def create_transformation(self) -> Transformation: - return Chain( - [ - RemoveFields( - field_names=[ - FieldName.FEAT_STATIC_REAL, - FieldName.FEAT_DYNAMIC_REAL, - FieldName.FEAT_DYNAMIC_CAT, - ] - ), - AddObservedValuesIndicator( - target_field=FieldName.TARGET, - output_field=FieldName.OBSERVED_VALUES, - dtype=self.dtype, - ), - ] - ) + return Chain([ + RemoveFields( + field_names=[ + FieldName.FEAT_STATIC_REAL, + FieldName.FEAT_DYNAMIC_REAL, + FieldName.FEAT_DYNAMIC_CAT, + ] + ), + AddObservedValuesIndicator( + target_field=FieldName.TARGET, + output_field=FieldName.OBSERVED_VALUES, + dtype=self.dtype, + ), + ]) def create_instance_splitter(self, mode: str): assert mode in ["training", "validation", "test"] diff --git a/src/gluonts/nursery/robust-mts-attack/pts/model/tempflow/tempflow_estimator.py b/src/gluonts/nursery/robust-mts-attack/pts/model/tempflow/tempflow_estimator.py index 48c230e04c..4049a36c8e 100644 --- a/src/gluonts/nursery/robust-mts-attack/pts/model/tempflow/tempflow_estimator.py +++ b/src/gluonts/nursery/robust-mts-attack/pts/model/tempflow/tempflow_estimator.py @@ -135,43 +135,39 @@ def __init__( ) def create_transformation(self) -> Transformation: - return Chain( - [ - AsNumpyArray( - field=FieldName.TARGET, - expected_ndim=2, - ), - # maps the target to (1, T) - # if the target data is uni dimensional - ExpandDimArray( - field=FieldName.TARGET, - axis=None, - ), - AddObservedValuesIndicator( - target_field=FieldName.TARGET, - output_field=FieldName.OBSERVED_VALUES, - ), - AddTimeFeatures( - start_field=FieldName.START, - target_field=FieldName.TARGET, - output_field=FieldName.FEAT_TIME, - time_features=self.time_features, - pred_length=self.prediction_length, - ), - VstackFeatures( - output_field=FieldName.FEAT_TIME, - input_fields=[FieldName.FEAT_TIME], - ), - SetFieldIfNotPresent( - field=FieldName.FEAT_STATIC_CAT, value=[0] - ), - TargetDimIndicator( - field_name="target_dimension_indicator", - target_field=FieldName.TARGET, - ), - AsNumpyArray(field=FieldName.FEAT_STATIC_CAT, expected_ndim=1), - ] - ) + return Chain([ + AsNumpyArray( + field=FieldName.TARGET, + expected_ndim=2, + ), + # maps the target to (1, T) + # if the target data is uni dimensional + ExpandDimArray( + field=FieldName.TARGET, + axis=None, + ), + AddObservedValuesIndicator( + target_field=FieldName.TARGET, + output_field=FieldName.OBSERVED_VALUES, + ), + AddTimeFeatures( + start_field=FieldName.START, + target_field=FieldName.TARGET, + output_field=FieldName.FEAT_TIME, + time_features=self.time_features, + pred_length=self.prediction_length, + ), + VstackFeatures( + output_field=FieldName.FEAT_TIME, + input_fields=[FieldName.FEAT_TIME], + ), + SetFieldIfNotPresent(field=FieldName.FEAT_STATIC_CAT, value=[0]), + TargetDimIndicator( + field_name="target_dimension_indicator", + target_field=FieldName.TARGET, + ), + AsNumpyArray(field=FieldName.FEAT_STATIC_CAT, expected_ndim=1), + ]) def create_instance_splitter(self, mode: str): assert mode in ["training", "validation", "test"] @@ -195,12 +191,10 @@ def create_instance_splitter(self, mode: str): FieldName.OBSERVED_VALUES, ], ) + ( - RenameFields( - { - f"past_{FieldName.TARGET}": f"past_{FieldName.TARGET}_cdf", - f"future_{FieldName.TARGET}": f"future_{FieldName.TARGET}_cdf", - } - ) + RenameFields({ + f"past_{FieldName.TARGET}": f"past_{FieldName.TARGET}_cdf", + f"future_{FieldName.TARGET}": f"future_{FieldName.TARGET}_cdf", + }) ) def create_training_network( diff --git a/src/gluonts/nursery/robust-mts-attack/pts/model/tempflow/tempflow_network.py b/src/gluonts/nursery/robust-mts-attack/pts/model/tempflow/tempflow_network.py index 6fc6ae1fa2..b070f3f8f7 100644 --- a/src/gluonts/nursery/robust-mts-attack/pts/model/tempflow/tempflow_network.py +++ b/src/gluonts/nursery/robust-mts-attack/pts/model/tempflow/tempflow_network.py @@ -550,14 +550,12 @@ def repeat(tensor, dim=0): samples = torch.cat(future_samples, dim=1) # (batch_size, num_samples, prediction_length, target_dim) - return samples.reshape( - ( - -1, - self.num_parallel_samples, - self.prediction_length, - self.target_dim, - ) - ) + return samples.reshape(( + -1, + self.num_parallel_samples, + self.prediction_length, + self.target_dim, + )) def forward( self, diff --git a/src/gluonts/nursery/robust-mts-attack/pts/model/tft/tft_estimator.py b/src/gluonts/nursery/robust-mts-attack/pts/model/tft/tft_estimator.py index fdb47f7d87..d894ac7277 100644 --- a/src/gluonts/nursery/robust-mts-attack/pts/model/tft/tft_estimator.py +++ b/src/gluonts/nursery/robust-mts-attack/pts/model/tft/tft_estimator.py @@ -129,12 +129,10 @@ def __init__( def create_transformation(self) -> Transformation: transforms = ( [AsNumpyArray(field=FieldName.TARGET, expected_ndim=1)] - + ( - [ - AsNumpyArray(field=name, expected_ndim=1) - for name in self.static_cardinalities.keys() - ] - ) + + ([ + AsNumpyArray(field=name, expected_ndim=1) + for name in self.static_cardinalities.keys() + ]) + [ AsNumpyArray(field=name, expected_ndim=1) for name in chain( @@ -168,34 +166,30 @@ def create_transformation(self) -> Transformation: ) if self.static_cardinalities: - transforms.extend( - [ - VstackFeatures( - output_field=FieldName.FEAT_STATIC_CAT, - input_fields=list(self.static_cardinalities.keys()), - h_stack=True, - ), - AsNumpyArray( - field=FieldName.FEAT_STATIC_CAT, - expected_ndim=1, - dtype=np.long, - ), - ] - ) + transforms.extend([ + VstackFeatures( + output_field=FieldName.FEAT_STATIC_CAT, + input_fields=list(self.static_cardinalities.keys()), + h_stack=True, + ), + AsNumpyArray( + field=FieldName.FEAT_STATIC_CAT, + expected_ndim=1, + dtype=np.long, + ), + ]) else: - transforms.extend( - [ - SetField( - output_field=FieldName.FEAT_STATIC_CAT, - value=[0], - ), - AsNumpyArray( - field=FieldName.FEAT_STATIC_CAT, - expected_ndim=1, - dtype=np.long, - ), - ] - ) + transforms.extend([ + SetField( + output_field=FieldName.FEAT_STATIC_CAT, + value=[0], + ), + AsNumpyArray( + field=FieldName.FEAT_STATIC_CAT, + expected_ndim=1, + dtype=np.long, + ), + ]) if self.static_feature_dims: transforms.append( @@ -206,50 +200,44 @@ def create_transformation(self) -> Transformation: ) ) else: - transforms.extend( - [ - SetField( - output_field=FieldName.FEAT_STATIC_REAL, - value=[0.0], - ), - AsNumpyArray( - field=FieldName.FEAT_STATIC_REAL, expected_ndim=1 - ), - ] - ) + transforms.extend([ + SetField( + output_field=FieldName.FEAT_STATIC_REAL, + value=[0.0], + ), + AsNumpyArray( + field=FieldName.FEAT_STATIC_REAL, expected_ndim=1 + ), + ]) if self.dynamic_cardinalities: - transforms.extend( - [ - VstackFeatures( - output_field=FieldName.FEAT_DYNAMIC_CAT, - input_fields=list(self.dynamic_cardinalities.keys()), - ), - AsNumpyArray( - field=FieldName.FEAT_DYNAMIC_CAT, - expected_ndim=2, - dtype=np.long, - ), - ] - ) + transforms.extend([ + VstackFeatures( + output_field=FieldName.FEAT_DYNAMIC_CAT, + input_fields=list(self.dynamic_cardinalities.keys()), + ), + AsNumpyArray( + field=FieldName.FEAT_DYNAMIC_CAT, + expected_ndim=2, + dtype=np.long, + ), + ]) else: - transforms.extend( - [ - SetField( - output_field=FieldName.FEAT_DYNAMIC_CAT, - value=[[0]], - ), - AsNumpyArray( - field=FieldName.FEAT_DYNAMIC_CAT, - expected_ndim=2, - dtype=np.long, - ), - BroadcastTo( - field=FieldName.FEAT_DYNAMIC_CAT, - ext_length=self.prediction_length, - ), - ] - ) + transforms.extend([ + SetField( + output_field=FieldName.FEAT_DYNAMIC_CAT, + value=[[0]], + ), + AsNumpyArray( + field=FieldName.FEAT_DYNAMIC_CAT, + expected_ndim=2, + dtype=np.long, + ), + BroadcastTo( + field=FieldName.FEAT_DYNAMIC_CAT, + ext_length=self.prediction_length, + ), + ]) input_fields = [FieldName.FEAT_TIME, FieldName.FEAT_AGE] if self.dynamic_feature_dims: @@ -262,36 +250,30 @@ def create_transformation(self) -> Transformation: ) if self.past_dynamic_cardinalities: - transforms.extend( - [ - VstackFeatures( - output_field=FieldName.PAST_FEAT_DYNAMIC + "_cat", - input_fields=list( - self.past_dynamic_cardinalities.keys() - ), - ), - AsNumpyArray( - field=FieldName.PAST_FEAT_DYNAMIC + "_cat", - expected_ndim=2, - dtype=np.long, - ), - ] - ) + transforms.extend([ + VstackFeatures( + output_field=FieldName.PAST_FEAT_DYNAMIC + "_cat", + input_fields=list(self.past_dynamic_cardinalities.keys()), + ), + AsNumpyArray( + field=FieldName.PAST_FEAT_DYNAMIC + "_cat", + expected_ndim=2, + dtype=np.long, + ), + ]) else: - transforms.extend( - [ - SetField( - output_field=FieldName.PAST_FEAT_DYNAMIC + "_cat", - value=[[0]], - ), - AsNumpyArray( - field=FieldName.PAST_FEAT_DYNAMIC + "_cat", - expected_ndim=2, - dtype=np.long, - ), - BroadcastTo(field=FieldName.PAST_FEAT_DYNAMIC + "_cat"), - ] - ) + transforms.extend([ + SetField( + output_field=FieldName.PAST_FEAT_DYNAMIC + "_cat", + value=[[0]], + ), + AsNumpyArray( + field=FieldName.PAST_FEAT_DYNAMIC + "_cat", + expected_ndim=2, + dtype=np.long, + ), + BroadcastTo(field=FieldName.PAST_FEAT_DYNAMIC + "_cat"), + ]) if self.past_dynamic_feature_dims: transforms.append( @@ -301,18 +283,16 @@ def create_transformation(self) -> Transformation: ) ) else: - transforms.extend( - [ - SetField( - output_field=FieldName.PAST_FEAT_DYNAMIC_REAL, - value=[[0.0]], - ), - AsNumpyArray( - field=FieldName.PAST_FEAT_DYNAMIC_REAL, expected_ndim=2 - ), - BroadcastTo(field=FieldName.PAST_FEAT_DYNAMIC_REAL), - ] - ) + transforms.extend([ + SetField( + output_field=FieldName.PAST_FEAT_DYNAMIC_REAL, + value=[[0.0]], + ), + AsNumpyArray( + field=FieldName.PAST_FEAT_DYNAMIC_REAL, expected_ndim=2 + ), + BroadcastTo(field=FieldName.PAST_FEAT_DYNAMIC_REAL), + ]) return Chain(transforms) diff --git a/src/gluonts/nursery/robust-mts-attack/pts/model/tft/tft_modules.py b/src/gluonts/nursery/robust-mts-attack/pts/model/tft/tft_modules.py index ef6698f505..437f2c6ca3 100644 --- a/src/gluonts/nursery/robust-mts-attack/pts/model/tft/tft_modules.py +++ b/src/gluonts/nursery/robust-mts-attack/pts/model/tft/tft_modules.py @@ -38,14 +38,12 @@ def __init__( self.feature_slices = feature_dims self.feature_dims = feature_dims - self._projector = nn.ModuleList( - [ - nn.Linear(in_features=in_feature, out_features=out_features) - for in_feature, out_features in zip( - self.feature_dims, embedding_dims - ) - ] - ) + self._projector = nn.ModuleList([ + nn.Linear(in_features=in_feature, out_features=out_features) + for in_feature, out_features in zip( + self.feature_dims, embedding_dims + ) + ]) def forward(self, features: torch.Tensor) -> List[torch.Tensor]: if self.__num_features > 1: @@ -160,12 +158,10 @@ def __init__( dropout=dropout, ) - self.variable_network = nn.ModuleList( - [ - GatedResidualNetwork(d_hidden=d_hidden, dropout=dropout) - for _ in range(n_vars) - ] - ) + self.variable_network = nn.ModuleList([ + GatedResidualNetwork(d_hidden=d_hidden, dropout=dropout) + for _ in range(n_vars) + ]) def forward( self, diff --git a/src/gluonts/nursery/robust-mts-attack/pts/model/time_grad/epsilon_theta.py b/src/gluonts/nursery/robust-mts-attack/pts/model/time_grad/epsilon_theta.py index a66b5b003c..bc7462641f 100644 --- a/src/gluonts/nursery/robust-mts-attack/pts/model/time_grad/epsilon_theta.py +++ b/src/gluonts/nursery/robust-mts-attack/pts/model/time_grad/epsilon_theta.py @@ -120,16 +120,14 @@ def __init__( self.cond_upsampler = CondUpsampler( target_dim=target_dim, cond_length=cond_length ) - self.residual_layers = nn.ModuleList( - [ - ResidualBlock( - residual_channels=residual_channels, - dilation=2 ** (i % dilation_cycle_length), - hidden_size=residual_hidden, - ) - for i in range(residual_layers) - ] - ) + self.residual_layers = nn.ModuleList([ + ResidualBlock( + residual_channels=residual_channels, + dilation=2 ** (i % dilation_cycle_length), + hidden_size=residual_hidden, + ) + for i in range(residual_layers) + ]) self.skip_projection = nn.Conv1d( residual_channels, residual_channels, 3 ) diff --git a/src/gluonts/nursery/robust-mts-attack/pts/model/time_grad/time_grad_estimator.py b/src/gluonts/nursery/robust-mts-attack/pts/model/time_grad/time_grad_estimator.py index 8ccb09d9c9..cb60cd4e39 100644 --- a/src/gluonts/nursery/robust-mts-attack/pts/model/time_grad/time_grad_estimator.py +++ b/src/gluonts/nursery/robust-mts-attack/pts/model/time_grad/time_grad_estimator.py @@ -136,43 +136,39 @@ def __init__( ) def create_transformation(self) -> Transformation: - return Chain( - [ - AsNumpyArray( - field=FieldName.TARGET, - expected_ndim=2, - ), - # maps the target to (1, T) - # if the target data is uni dimensional - ExpandDimArray( - field=FieldName.TARGET, - axis=None, - ), - AddObservedValuesIndicator( - target_field=FieldName.TARGET, - output_field=FieldName.OBSERVED_VALUES, - ), - AddTimeFeatures( - start_field=FieldName.START, - target_field=FieldName.TARGET, - output_field=FieldName.FEAT_TIME, - time_features=self.time_features, - pred_length=self.prediction_length, - ), - VstackFeatures( - output_field=FieldName.FEAT_TIME, - input_fields=[FieldName.FEAT_TIME], - ), - SetFieldIfNotPresent( - field=FieldName.FEAT_STATIC_CAT, value=[0] - ), - TargetDimIndicator( - field_name="target_dimension_indicator", - target_field=FieldName.TARGET, - ), - AsNumpyArray(field=FieldName.FEAT_STATIC_CAT, expected_ndim=1), - ] - ) + return Chain([ + AsNumpyArray( + field=FieldName.TARGET, + expected_ndim=2, + ), + # maps the target to (1, T) + # if the target data is uni dimensional + ExpandDimArray( + field=FieldName.TARGET, + axis=None, + ), + AddObservedValuesIndicator( + target_field=FieldName.TARGET, + output_field=FieldName.OBSERVED_VALUES, + ), + AddTimeFeatures( + start_field=FieldName.START, + target_field=FieldName.TARGET, + output_field=FieldName.FEAT_TIME, + time_features=self.time_features, + pred_length=self.prediction_length, + ), + VstackFeatures( + output_field=FieldName.FEAT_TIME, + input_fields=[FieldName.FEAT_TIME], + ), + SetFieldIfNotPresent(field=FieldName.FEAT_STATIC_CAT, value=[0]), + TargetDimIndicator( + field_name="target_dimension_indicator", + target_field=FieldName.TARGET, + ), + AsNumpyArray(field=FieldName.FEAT_STATIC_CAT, expected_ndim=1), + ]) def create_instance_splitter(self, mode: str): assert mode in ["training", "validation", "test"] @@ -196,12 +192,10 @@ def create_instance_splitter(self, mode: str): FieldName.OBSERVED_VALUES, ], ) + ( - RenameFields( - { - f"past_{FieldName.TARGET}": f"past_{FieldName.TARGET}_cdf", - f"future_{FieldName.TARGET}": f"future_{FieldName.TARGET}_cdf", - } - ) + RenameFields({ + f"past_{FieldName.TARGET}": f"past_{FieldName.TARGET}_cdf", + f"future_{FieldName.TARGET}": f"future_{FieldName.TARGET}_cdf", + }) ) def create_training_network( diff --git a/src/gluonts/nursery/robust-mts-attack/pts/model/time_grad/time_grad_network.py b/src/gluonts/nursery/robust-mts-attack/pts/model/time_grad/time_grad_network.py index ecd657bf2f..50c6d68973 100644 --- a/src/gluonts/nursery/robust-mts-attack/pts/model/time_grad/time_grad_network.py +++ b/src/gluonts/nursery/robust-mts-attack/pts/model/time_grad/time_grad_network.py @@ -565,14 +565,12 @@ def repeat(tensor, dim=0): samples = torch.cat(future_samples, dim=1) # (batch_size, num_samples, prediction_length, target_dim) - return samples.reshape( - ( - -1, - self.num_parallel_samples, - self.prediction_length, - self.target_dim, - ) - ) + return samples.reshape(( + -1, + self.num_parallel_samples, + self.prediction_length, + self.target_dim, + )) def forward( self, diff --git a/src/gluonts/nursery/robust-mts-attack/pts/model/transformer/transformer_network.py b/src/gluonts/nursery/robust-mts-attack/pts/model/transformer/transformer_network.py index 3ac17fef6e..85326df0e2 100644 --- a/src/gluonts/nursery/robust-mts-attack/pts/model/transformer/transformer_network.py +++ b/src/gluonts/nursery/robust-mts-attack/pts/model/transformer/transformer_network.py @@ -224,13 +224,11 @@ def create_network_input( # from (batch_size, sub_seq_len, *target_shape, num_lags) # to (batch_size, sub_seq_len, prod(target_shape) * num_lags) - input_lags = lags_scaled.reshape( - ( - -1, - subsequences_length, - len(self.lags_seq) * prod(self.target_shape), - ) - ) + input_lags = lags_scaled.reshape(( + -1, + subsequences_length, + len(self.lags_seq) * prod(self.target_shape), + )) # (batch_size, sub_seq_len, input_dim) inputs = torch.cat( @@ -425,13 +423,11 @@ def sampling_decoder( samples = torch.cat(future_samples, dim=1) # (batch_size, num_samples, *target_shape, prediction_length) - return samples.reshape( - ( - (-1, self.num_parallel_samples) - + self.target_shape - + (self.prediction_length,) - ) - ) + return samples.reshape(( + (-1, self.num_parallel_samples) + + self.target_shape + + (self.prediction_length,) + )) def forward( self, diff --git a/src/gluonts/nursery/robust-mts-attack/pts/model/transformer_tempflow/transformer_tempflow_estimator.py b/src/gluonts/nursery/robust-mts-attack/pts/model/transformer_tempflow/transformer_tempflow_estimator.py index 26664034b2..b875df3c58 100644 --- a/src/gluonts/nursery/robust-mts-attack/pts/model/transformer_tempflow/transformer_tempflow_estimator.py +++ b/src/gluonts/nursery/robust-mts-attack/pts/model/transformer_tempflow/transformer_tempflow_estimator.py @@ -146,49 +146,45 @@ def create_transformation(self) -> Transformation: if not self.use_feat_dynamic_real: remove_field_names.append(FieldName.FEAT_DYNAMIC_REAL) - return Chain( - [ - RemoveFields(field_names=remove_field_names), - AsNumpyArray( - field=FieldName.TARGET, - expected_ndim=2, + return Chain([ + RemoveFields(field_names=remove_field_names), + AsNumpyArray( + field=FieldName.TARGET, + expected_ndim=2, + ), + # maps the target to (1, T) + # if the target data is uni dimensional + ExpandDimArray( + field=FieldName.TARGET, + axis=None, + ), + AddObservedValuesIndicator( + target_field=FieldName.TARGET, + output_field=FieldName.OBSERVED_VALUES, + ), + AddTimeFeatures( + start_field=FieldName.START, + target_field=FieldName.TARGET, + output_field=FieldName.FEAT_TIME, + time_features=self.time_features, + pred_length=self.prediction_length, + ), + VstackFeatures( + output_field=FieldName.FEAT_TIME, + input_fields=[FieldName.FEAT_TIME] + + ( + [FieldName.FEAT_DYNAMIC_REAL] + if self.use_feat_dynamic_real + else [] ), - # maps the target to (1, T) - # if the target data is uni dimensional - ExpandDimArray( - field=FieldName.TARGET, - axis=None, - ), - AddObservedValuesIndicator( - target_field=FieldName.TARGET, - output_field=FieldName.OBSERVED_VALUES, - ), - AddTimeFeatures( - start_field=FieldName.START, - target_field=FieldName.TARGET, - output_field=FieldName.FEAT_TIME, - time_features=self.time_features, - pred_length=self.prediction_length, - ), - VstackFeatures( - output_field=FieldName.FEAT_TIME, - input_fields=[FieldName.FEAT_TIME] - + ( - [FieldName.FEAT_DYNAMIC_REAL] - if self.use_feat_dynamic_real - else [] - ), - ), - SetFieldIfNotPresent( - field=FieldName.FEAT_STATIC_CAT, value=[0] - ), - TargetDimIndicator( - field_name="target_dimension_indicator", - target_field=FieldName.TARGET, - ), - AsNumpyArray(field=FieldName.FEAT_STATIC_CAT, expected_ndim=1), - ] - ) + ), + SetFieldIfNotPresent(field=FieldName.FEAT_STATIC_CAT, value=[0]), + TargetDimIndicator( + field_name="target_dimension_indicator", + target_field=FieldName.TARGET, + ), + AsNumpyArray(field=FieldName.FEAT_STATIC_CAT, expected_ndim=1), + ]) def create_instance_splitter(self, mode: str): assert mode in ["training", "validation", "test"] @@ -212,12 +208,10 @@ def create_instance_splitter(self, mode: str): FieldName.OBSERVED_VALUES, ], ) + ( - RenameFields( - { - f"past_{FieldName.TARGET}": f"past_{FieldName.TARGET}_cdf", - f"future_{FieldName.TARGET}": f"future_{FieldName.TARGET}_cdf", - } - ) + RenameFields({ + f"past_{FieldName.TARGET}": f"past_{FieldName.TARGET}_cdf", + f"future_{FieldName.TARGET}": f"future_{FieldName.TARGET}_cdf", + }) ) def create_training_network( diff --git a/src/gluonts/nursery/robust-mts-attack/pts/model/transformer_tempflow/transformer_tempflow_network.py b/src/gluonts/nursery/robust-mts-attack/pts/model/transformer_tempflow/transformer_tempflow_network.py index 55d60d1537..ffd78d9a77 100644 --- a/src/gluonts/nursery/robust-mts-attack/pts/model/transformer_tempflow/transformer_tempflow_network.py +++ b/src/gluonts/nursery/robust-mts-attack/pts/model/transformer_tempflow/transformer_tempflow_network.py @@ -552,14 +552,12 @@ def repeat(tensor, dim=0): samples = torch.cat(future_samples, dim=1) # (batch_size, num_samples, prediction_length, target_dim) - return samples.reshape( - ( - -1, - self.num_parallel_samples, - self.prediction_length, - self.target_dim, - ) - ) + return samples.reshape(( + -1, + self.num_parallel_samples, + self.prediction_length, + self.target_dim, + )) def forward( self, diff --git a/src/gluonts/nursery/robust-mts-attack/pts/modules/feature.py b/src/gluonts/nursery/robust-mts-attack/pts/modules/feature.py index e8f61a25cc..98c046dcc8 100644 --- a/src/gluonts/nursery/robust-mts-attack/pts/modules/feature.py +++ b/src/gluonts/nursery/robust-mts-attack/pts/modules/feature.py @@ -31,12 +31,10 @@ def create_embedding(c: int, d: int) -> nn.Embedding: embedding = nn.Embedding(c, d) return embedding - self.__embedders = nn.ModuleList( - [ - create_embedding(c, d) - for c, d in zip(cardinalities, embedding_dims) - ] - ) + self.__embedders = nn.ModuleList([ + create_embedding(c, d) + for c, d in zip(cardinalities, embedding_dims) + ]) def forward(self, features: torch.Tensor) -> torch.Tensor: if self.__num_features > 1: diff --git a/src/gluonts/nursery/robust-mts-attack/pts/modules/gaussian_diffusion.py b/src/gluonts/nursery/robust-mts-attack/pts/modules/gaussian_diffusion.py index 73195e7957..5666592254 100644 --- a/src/gluonts/nursery/robust-mts-attack/pts/modules/gaussian_diffusion.py +++ b/src/gluonts/nursery/robust-mts-attack/pts/modules/gaussian_diffusion.py @@ -312,7 +312,7 @@ def log_prob(self, x, cond, *args, **kwargs): cond.reshape(B * T, 1, -1), time, *args, - **kwargs + **kwargs, ) return loss diff --git a/src/gluonts/nursery/robust-mts-attack/read_pickle.py b/src/gluonts/nursery/robust-mts-attack/read_pickle.py index e43284f8b3..038ddbd3a0 100644 --- a/src/gluonts/nursery/robust-mts-attack/read_pickle.py +++ b/src/gluonts/nursery/robust-mts-attack/read_pickle.py @@ -53,12 +53,10 @@ def create_table(path): "+-", np.asarray(result.mse[key]).std() * c, ) - mse.append( - ( - np.asarray(result.mse[key]).mean(), - np.asarray(result.mse[key]).std() * c, - ) - ) + mse.append(( + np.asarray(result.mse[key]).mean(), + np.asarray(result.mse[key]).std() * c, + )) print("mape loss:") for key in result.mape.keys(): @@ -68,12 +66,10 @@ def create_table(path): "+-", np.asarray(result.mape[key]).std() * c, ) - mape.append( - ( - np.asarray(result.mape[key]).mean(), - np.asarray(result.mape[key]).std() * c, - ) - ) + mape.append(( + np.asarray(result.mape[key]).mean(), + np.asarray(result.mape[key]).std() * c, + )) print("wQL:") for key in result.ql.keys(): @@ -83,12 +79,10 @@ def create_table(path): "+-", np.asarray(result.ql[key]).std() * c, ) - wql.append( - ( - np.asarray(result.ql[key]).mean(), - np.asarray(result.ql[key]).std() * c, - ) - ) + wql.append(( + np.asarray(result.ql[key]).mean(), + np.asarray(result.ql[key]).std() * c, + )) with open("table_" + types + ".txt", "w") as f: for i in range(len(mse)): diff --git a/src/gluonts/nursery/robust-mts-attack/utils.py b/src/gluonts/nursery/robust-mts-attack/utils.py index 75a36ec169..c40e2fd6a2 100644 --- a/src/gluonts/nursery/robust-mts-attack/utils.py +++ b/src/gluonts/nursery/robust-mts-attack/utils.py @@ -235,12 +235,10 @@ def calc_loss( target_items, quantiles=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], ): - testset_size = sum( - [ - attack_data[i].true_future_target.shape[0] - for i in range(len(attack_data)) - ] - ) + testset_size = sum([ + attack_data[i].true_future_target.shape[0] + for i in range(len(attack_data)) + ]) mse = { key: np.zeros((testset_size, len(attack_idx), len(target_items))) for key in forecasts.keys() diff --git a/src/gluonts/nursery/san/_estimator.py b/src/gluonts/nursery/san/_estimator.py index 3f50433d54..f72edabb97 100644 --- a/src/gluonts/nursery/san/_estimator.py +++ b/src/gluonts/nursery/san/_estimator.py @@ -129,21 +129,19 @@ def create_transformation(self) -> Transformation: ) ) else: - transforms.extend( - [ - SetField( - output_field=FieldName.FEAT_DYNAMIC_REAL, - value=[[]] - * (self.context_length + self.prediction_length), - ), - AsNumpyArray( - field=FieldName.FEAT_DYNAMIC_REAL, - expected_ndim=2, - ), - # SwapAxes(input_fields= - # [FieldName.FEAT_DYNAMIC_REAL], axes=(0,1)), - ] - ) + transforms.extend([ + SetField( + output_field=FieldName.FEAT_DYNAMIC_REAL, + value=[[]] + * (self.context_length + self.prediction_length), + ), + AsNumpyArray( + field=FieldName.FEAT_DYNAMIC_REAL, + expected_ndim=2, + ), + # SwapAxes(input_fields= + # [FieldName.FEAT_DYNAMIC_REAL], axes=(0,1)), + ]) if self.use_feat_dynamic_cat: transforms.append( AsNumpyArray( @@ -155,26 +153,24 @@ def create_transformation(self) -> Transformation: # Manually set dummy dynamic categorical features and split by time # Unknown issue in dataloader if leave splitting to # InstanceSplitter - transforms.extend( - [ - SetField( - output_field="past_" + FieldName.FEAT_DYNAMIC_CAT, - value=[[]] * self.context_length, - ), - AsNumpyArray( - field="past_" + FieldName.FEAT_DYNAMIC_CAT, - expected_ndim=2, - ), - SetField( - output_field="future_" + FieldName.FEAT_DYNAMIC_CAT, - value=[[]] * self.prediction_length, - ), - AsNumpyArray( - field="future_" + FieldName.FEAT_DYNAMIC_CAT, - expected_ndim=2, - ), - ] - ) + transforms.extend([ + SetField( + output_field="past_" + FieldName.FEAT_DYNAMIC_CAT, + value=[[]] * self.context_length, + ), + AsNumpyArray( + field="past_" + FieldName.FEAT_DYNAMIC_CAT, + expected_ndim=2, + ), + SetField( + output_field="future_" + FieldName.FEAT_DYNAMIC_CAT, + value=[[]] * self.prediction_length, + ), + AsNumpyArray( + field="future_" + FieldName.FEAT_DYNAMIC_CAT, + expected_ndim=2, + ), + ]) if self.use_feat_static_real: transforms.append( AsNumpyArray( @@ -183,18 +179,16 @@ def create_transformation(self) -> Transformation: ) ) else: - transforms.extend( - [ - SetField( - output_field=FieldName.FEAT_STATIC_REAL, - value=[], - ), - AsNumpyArray( - field=FieldName.FEAT_STATIC_REAL, - expected_ndim=1, - ), - ] - ) + transforms.extend([ + SetField( + output_field=FieldName.FEAT_STATIC_REAL, + value=[], + ), + AsNumpyArray( + field=FieldName.FEAT_STATIC_REAL, + expected_ndim=1, + ), + ]) if self.use_feat_static_cat: transforms.append( AsNumpyArray( @@ -203,37 +197,35 @@ def create_transformation(self) -> Transformation: ) ) - transforms.extend( - [ - AsNumpyArray(field=FieldName.TARGET, expected_ndim=1), - AddObservedValuesIndicator( - target_field=FieldName.TARGET, - output_field=FieldName.OBSERVED_VALUES, - ), - AddTimeFeatures( - start_field=FieldName.START, - target_field=FieldName.TARGET, - output_field=FieldName.FEAT_TIME, - time_features=self.time_features, - pred_length=self.prediction_length, - ), - AddAgeFeature( - target_field=FieldName.TARGET, - output_field=FieldName.FEAT_AGE, - pred_length=self.prediction_length, - log_scale=True, - ), - VstackFeatures( - output_field=FieldName.FEAT_DYNAMIC_REAL, - input_fields=[FieldName.FEAT_TIME, FieldName.FEAT_AGE] - + ( - [FieldName.FEAT_DYNAMIC_REAL] - if self.use_feat_dynamic_real - else [] - ), + transforms.extend([ + AsNumpyArray(field=FieldName.TARGET, expected_ndim=1), + AddObservedValuesIndicator( + target_field=FieldName.TARGET, + output_field=FieldName.OBSERVED_VALUES, + ), + AddTimeFeatures( + start_field=FieldName.START, + target_field=FieldName.TARGET, + output_field=FieldName.FEAT_TIME, + time_features=self.time_features, + pred_length=self.prediction_length, + ), + AddAgeFeature( + target_field=FieldName.TARGET, + output_field=FieldName.FEAT_AGE, + pred_length=self.prediction_length, + log_scale=True, + ), + VstackFeatures( + output_field=FieldName.FEAT_DYNAMIC_REAL, + input_fields=[FieldName.FEAT_TIME, FieldName.FEAT_AGE] + + ( + [FieldName.FEAT_DYNAMIC_REAL] + if self.use_feat_dynamic_real + else [] ), - ] - ) + ), + ]) return Chain(transforms) def _create_instance_splitter(self, mode: str): diff --git a/src/gluonts/nursery/temporal_hierarchical_forecasting/model/cop_deepar/_estimator.py b/src/gluonts/nursery/temporal_hierarchical_forecasting/model/cop_deepar/_estimator.py index 8c8b297fa6..48769c0f71 100644 --- a/src/gluonts/nursery/temporal_hierarchical_forecasting/model/cop_deepar/_estimator.py +++ b/src/gluonts/nursery/temporal_hierarchical_forecasting/model/cop_deepar/_estimator.py @@ -110,12 +110,10 @@ def _get_time_features_agg_level( ) # shape: (T, num_features) - full_time_feat = np.array( - [ - feat_map(full_date_range) - for feat_map in time_features_from_frequency_str(freq) - ] - ).T + full_time_feat = np.array([ + feat_map(full_date_range) + for feat_map in time_features_from_frequency_str(freq) + ]).T age_feature = np.log10( 2.0 + np.arange(num_periods, dtype=agg_estimator.dtype) diff --git a/src/gluonts/nursery/temporal_hierarchical_forecasting/model/cop_deepar/_network.py b/src/gluonts/nursery/temporal_hierarchical_forecasting/model/cop_deepar/_network.py index 72b8329e25..8acbc37f7e 100644 --- a/src/gluonts/nursery/temporal_hierarchical_forecasting/model/cop_deepar/_network.py +++ b/src/gluonts/nursery/temporal_hierarchical_forecasting/model/cop_deepar/_network.py @@ -226,12 +226,10 @@ def get_target_related_feat_at_agg_level( mx.nd.zeros_like(future_observed_values_agg), ) - target_related_feat_agg.update( - { - "future_target": future_target_agg, - "future_observed_values": future_observed_values_agg, - } - ) + target_related_feat_agg.update({ + "future_target": future_target_agg, + "future_observed_values": future_observed_values_agg, + }) return target_related_feat_agg @@ -423,14 +421,12 @@ def hybrid_forward( ) // window_size embeddings_at_all_levels_ls.append( - rnn_outputs.reshape( - ( - rnn_outputs.shape[0], - num_windows, - -1, - rnn_outputs.shape[-1], - ) - ) + rnn_outputs.reshape(( + rnn_outputs.shape[0], + num_windows, + -1, + rnn_outputs.shape[-1], + )) ) target_at_all_levels_ls.append( @@ -835,13 +831,11 @@ def hybrid_forward( ) reconciled_samples_at_bottom_level = ( - reconciled_samples_at_bottom_level.reshape( - ( - reconciled_samples_at_bottom_level.shape[0], - reconciled_samples_at_bottom_level.shape[1], - -1, - ) - ) + reconciled_samples_at_bottom_level.reshape(( + reconciled_samples_at_bottom_level.shape[0], + reconciled_samples_at_bottom_level.shape[1], + -1, + )) ) return reconciled_samples_at_bottom_level diff --git a/src/gluonts/nursery/temporal_hierarchical_forecasting/model/cop_deepar/gnn.py b/src/gluonts/nursery/temporal_hierarchical_forecasting/model/cop_deepar/gnn.py index f081630c04..8aeefc42f8 100644 --- a/src/gluonts/nursery/temporal_hierarchical_forecasting/model/cop_deepar/gnn.py +++ b/src/gluonts/nursery/temporal_hierarchical_forecasting/model/cop_deepar/gnn.py @@ -23,7 +23,7 @@ def __init__( num_layers: int, adj_matrix: Tensor, use_mlp: bool = True, - **kwargs + **kwargs, ): super().__init__(**kwargs) diff --git a/src/gluonts/nursery/temporal_hierarchical_forecasting/utils/utils.py b/src/gluonts/nursery/temporal_hierarchical_forecasting/utils/utils.py index 289f590e30..4a8e02229b 100644 --- a/src/gluonts/nursery/temporal_hierarchical_forecasting/utils/utils.py +++ b/src/gluonts/nursery/temporal_hierarchical_forecasting/utils/utils.py @@ -355,12 +355,10 @@ def mapping_matrix_at_level(level: int): M[:, start_ix:end_ix] = M[:, start_ix:end_ix] / row_sum[None, :] return M - mapping_matrices = np.array( - [ - mapping_matrix_at_level(level=level) - for level in range(len(cum_num_nodes_per_level)) - ] - ) + mapping_matrices = np.array([ + mapping_matrix_at_level(level=level) + for level in range(len(cum_num_nodes_per_level)) + ]) mean_mapping_matrix = np.mean(mapping_matrices, axis=0) reconciliation_mat = np.matmul(S, mean_mapping_matrix) diff --git a/src/gluonts/nursery/tsbench/src/cli/analysis/scripts/ensemble_recommender.py b/src/gluonts/nursery/tsbench/src/cli/analysis/scripts/ensemble_recommender.py index 811c9ee021..914e3bd217 100644 --- a/src/gluonts/nursery/tsbench/src/cli/analysis/scripts/ensemble_recommender.py +++ b/src/gluonts/nursery/tsbench/src/cli/analysis/scripts/ensemble_recommender.py @@ -100,7 +100,7 @@ def main( surrogate[surrogate["name"]] if surrogate["name"] in surrogate else {} - ) + ), ) # Then, we can create the recommender diff --git a/src/gluonts/nursery/tsbench/src/cli/analysis/scripts/recommender.py b/src/gluonts/nursery/tsbench/src/cli/analysis/scripts/recommender.py index f15e065af7..9698c231b2 100644 --- a/src/gluonts/nursery/tsbench/src/cli/analysis/scripts/recommender.py +++ b/src/gluonts/nursery/tsbench/src/cli/analysis/scripts/recommender.py @@ -123,7 +123,7 @@ def main( surrogate[surrogate["name"]] if surrogate["name"] in surrogate else {} - ) + ), ) elif recommender == "optimal": recommender_args["tracker"] = tracker diff --git a/src/gluonts/nursery/tsbench/src/cli/analysis/scripts/surrogate.py b/src/gluonts/nursery/tsbench/src/cli/analysis/scripts/surrogate.py index 8bcf74c9cd..e054e49e8a 100644 --- a/src/gluonts/nursery/tsbench/src/cli/analysis/scripts/surrogate.py +++ b/src/gluonts/nursery/tsbench/src/cli/analysis/scripts/surrogate.py @@ -92,7 +92,7 @@ def main( input_flags=inputs, output_normalization=outputs["normalization"], impute_simulatable=outputs["imputation"], - **(_config[surrogate] if surrogate in _config else {}) + **(_config[surrogate] if surrogate in _config else {}), ) # And evaluate it diff --git a/src/gluonts/nursery/tsbench/src/cli/utils/config.py b/src/gluonts/nursery/tsbench/src/cli/utils/config.py index 41425e22e7..4b5e9b6400 100644 --- a/src/gluonts/nursery/tsbench/src/cli/utils/config.py +++ b/src/gluonts/nursery/tsbench/src/cli/utils/config.py @@ -74,12 +74,10 @@ def explode_key_values( """ all_combinations = { primary: ( - itertools.product( - *[ - [(option["key"], value) for value in option["values"]] - for option in choices - ] - ) + itertools.product(*[ + [(option["key"], value) for value in option["values"]] + for option in choices + ]) if choices else [] ) @@ -95,12 +93,9 @@ def explode_key_values( primary_config = {primary_key: primary} for key, value in item: if isinstance(key, (list, tuple)): - primary_config.update( - { - process_key(primary, k): v - for k, v in zip(key, value) - } - ) + primary_config.update({ + process_key(primary, k): v for k, v in zip(key, value) + }) else: primary_config[process_key(primary, key)] = value configs.append(primary_config) @@ -137,12 +132,10 @@ def process_key(model: str, key: str) -> str: for seed in seeds: for dataset in datasets: for model_config in configs: - all_configurations.append( - { - "seed": seed, - "dataset": dataset, - **model_config, - } - ) + all_configurations.append({ + "seed": seed, + "dataset": dataset, + **model_config, + }) return all_configurations diff --git a/src/gluonts/nursery/tsbench/src/tsbench/config/dataset/datasets.py b/src/gluonts/nursery/tsbench/src/tsbench/config/dataset/datasets.py index 9283e7c600..328cdb42c1 100644 --- a/src/gluonts/nursery/tsbench/src/tsbench/config/dataset/datasets.py +++ b/src/gluonts/nursery/tsbench/src/tsbench/config/dataset/datasets.py @@ -1026,16 +1026,14 @@ def _extract_data( series = [] for i, store_data in data.groupby("Store"): sorted_data = store_data.sort_values("Date") - series.append( - { - "item_id": int(i) - 1, - "start": sorted_data.Date.min(), - "target": sorted_data.Sales.to_list(), - "feat_static_cat": [ - int(i) - 1, - ], - } - ) + series.append({ + "item_id": int(i) - 1, + "start": sorted_data.Date.min(), + "target": sorted_data.Sales.to_list(), + "feat_static_cat": [ + int(i) - 1, + ], + }) return metadata, series @@ -1094,20 +1092,18 @@ def _extract_data( sorted_data.unit_sales.to_numpy(), index=pd.DatetimeIndex(sorted_data.date), ) - series.append( - { - "item_id": i, - "start": sorted_data.date.min(), - "target": sales.resample("D") - .first() - .fillna(value=0) - .to_list(), - "feat_static_cat": [ - int(store_id) - 1, - int(item_id), - ], - } - ) + series.append({ + "item_id": i, + "start": sorted_data.date.min(), + "target": sales.resample("D") + .first() + .fillna(value=0) + .to_list(), + "feat_static_cat": [ + int(store_id) - 1, + int(item_id), + ], + }) return metadata, series @@ -1163,17 +1159,15 @@ def _extract_data( ): department_id = np.where(department_ids == department)[0][0] sorted_data = group_data.sort_values("Date") - series.append( - { - "item_id": i, - "start": sorted_data.Date.min(), - "target": sorted_data.Weekly_Sales.to_list(), - "feat_static_cat": [ - int(store_id) - 1, - int(department_id), - ], - } - ) + series.append({ + "item_id": i, + "start": sorted_data.Date.min(), + "target": sorted_data.Weekly_Sales.to_list(), + "feat_static_cat": [ + int(store_id) - 1, + int(department_id), + ], + }) return metadata, series @@ -1227,18 +1221,16 @@ def _extract_data( sorted_data.visitors.to_numpy(), index=pd.DatetimeIndex(sorted_data.visit_date), ) - series.append( - { - "item_id": i, - "start": sorted_data.visit_date.min(), - "target": visitors.resample("D") - .first() - .fillna(value=0) - .to_list(), - "feat_static_cat": [ - int(store_id), - ], - } - ) + series.append({ + "item_id": i, + "start": sorted_data.visit_date.min(), + "target": visitors.resample("D") + .first() + .fillna(value=0) + .to_list(), + "feat_static_cat": [ + int(store_id), + ], + }) return metadata, series diff --git a/src/gluonts/nursery/tsbench/src/tsbench/evaluations/aws/analytics.py b/src/gluonts/nursery/tsbench/src/tsbench/evaluations/aws/analytics.py index f65674ae55..2a5203a099 100644 --- a/src/gluonts/nursery/tsbench/src/tsbench/evaluations/aws/analytics.py +++ b/src/gluonts/nursery/tsbench/src/tsbench/evaluations/aws/analytics.py @@ -373,13 +373,11 @@ def _fetch_training_jobs( "MaxResults": 100, "Resource": "TrainingJob", "SearchExpression": { - "Filters": [ - { - "Name": "Tags.Experiment", - "Operator": "Equals", - "Value": experiment, - } - ], + "Filters": [{ + "Name": "Tags.Experiment", + "Operator": "Equals", + "Value": experiment, + }], }, } diff --git a/src/gluonts/nursery/tsbench/src/tsbench/evaluations/tracking/_info.py b/src/gluonts/nursery/tsbench/src/tsbench/evaluations/tracking/_info.py index 662817ea4d..31099cb938 100644 --- a/src/gluonts/nursery/tsbench/src/tsbench/evaluations/tracking/_info.py +++ b/src/gluonts/nursery/tsbench/src/tsbench/evaluations/tracking/_info.py @@ -142,34 +142,24 @@ def extract_job_infos( ] # And average the performance - averaged_performance = Performance( - **{ - metric: Metric( - np.mean( - [getattr(p, metric).mean for p in performances] - ), - np.std( - [getattr(p, metric).mean for p in performances] - ), - ) - for metric in Performance.metrics() - } - ) + averaged_performance = Performance(**{ + metric: Metric( + np.mean([getattr(p, metric).mean for p in performances]), + np.std([getattr(p, metric).mean for p in performances]), + ) + for metric in Performance.metrics() + }) # Get validation scores if available try: - val_ncrps = np.mean( - [ - job.metrics[c]["evaluation"]["val_ncrps"] - for (job, c) in zip(jobs, choices) - ] - ) - val_loss = np.mean( - [ - job.metrics[c]["evaluation"]["val_loss"] - for (job, c) in zip(jobs, choices) - ] - ).item() + val_ncrps = np.mean([ + job.metrics[c]["evaluation"]["val_ncrps"] + for (job, c) in zip(jobs, choices) + ]) + val_loss = np.mean([ + job.metrics[c]["evaluation"]["val_loss"] + for (job, c) in zip(jobs, choices) + ]).item() val_scores = ValidationScores(val_ncrps, val_loss) except KeyError: val_scores = None diff --git a/src/gluonts/nursery/tsbench/src/tsbench/evaluations/tracking/ensemble.py b/src/gluonts/nursery/tsbench/src/tsbench/evaluations/tracking/ensemble.py index d6dfe5c460..208d11a587 100644 --- a/src/gluonts/nursery/tsbench/src/tsbench/evaluations/tracking/ensemble.py +++ b/src/gluonts/nursery/tsbench/src/tsbench/evaluations/tracking/ensemble.py @@ -39,12 +39,10 @@ def __init__(self, directory: Path): continue with Path(file).open("rb") as f: data = pickle.load(f) - configurations.extend( - [ - Config(frozenset(x["configurations"]), x["dataset"]) - for x in data - ] - ) + configurations.extend([ + Config(frozenset(x["configurations"]), x["dataset"]) + for x in data + ]) performances.extend([x["performance"] for x in data]) self.performance_map: Dict[Config[EnsembleConfig], Performance] = dict( diff --git a/src/gluonts/nursery/tsbench/src/tsbench/forecasts/evaluation.py b/src/gluonts/nursery/tsbench/src/tsbench/forecasts/evaluation.py index 7c24adc08e..cfb8b52150 100644 --- a/src/gluonts/nursery/tsbench/src/tsbench/forecasts/evaluation.py +++ b/src/gluonts/nursery/tsbench/src/tsbench/forecasts/evaluation.py @@ -101,18 +101,14 @@ def performance(cls, evaluations: list[Evaluation]) -> Performance: Metric(0, 0) if m == "num_model_parameters" else Metric( - np.mean( - [ - metric[m] if m in metric else np.nan - for metric in metrics - ] - ), - np.std( - [ - metric[m] if m in metric else np.nan - for metric in metrics - ] - ), + np.mean([ + metric[m] if m in metric else np.nan + for metric in metrics + ]), + np.std([ + metric[m] if m in metric else np.nan + for metric in metrics + ]), ) ) for m in Performance.metrics() diff --git a/src/gluonts/nursery/tsbench/src/tsbench/recommender/greedy.py b/src/gluonts/nursery/tsbench/src/tsbench/recommender/greedy.py index effac277ca..30377c7990 100644 --- a/src/gluonts/nursery/tsbench/src/tsbench/recommender/greedy.py +++ b/src/gluonts/nursery/tsbench/src/tsbench/recommender/greedy.py @@ -88,20 +88,16 @@ def fit( transformer = QuantileTransformer( n_quantiles=min(1000, self.metrics.shape[0]) ) - self.metrics = np.stack( - [ - transformer.fit_transform(dataset_metrics) - for dataset_metrics in self.metrics - ] - ) + self.metrics = np.stack([ + transformer.fit_transform(dataset_metrics) + for dataset_metrics in self.metrics + ]) else: transformer = StandardScaler() - self.metrics = np.stack( - [ - transformer.fit_transform(dataset_metrics) - for dataset_metrics in self.metrics - ] - ) + self.metrics = np.stack([ + transformer.fit_transform(dataset_metrics) + for dataset_metrics in self.metrics + ]) def recommend( self, @@ -123,12 +119,10 @@ def recommend( # true Pareto front. if not self.enforce_single_objective and len(self.objectives) > 1: reference = np.ones(len(self.objectives)) - hypervolumes = np.array( - [ - pygmo.hypervolume(dataset_metrics).compute(reference) # type: ignore - for dataset_metrics in self.metrics - ] - ) + hypervolumes = np.array([ + pygmo.hypervolume(dataset_metrics).compute(reference) # type: ignore + for dataset_metrics in self.metrics + ]) available_choices = list(range(len(model_configs))) result = [] @@ -152,14 +146,12 @@ def recommend( else: # Otherwise, we need to compute the hypervolumes for all datasets reference = np.ones(len(self.objectives)) - config_hypervolumes = np.array( - [ - pygmo.hypervolume( # type: ignore - dataset_performances, - ).compute(reference) - for dataset_performances in all_performances - ] - ) + config_hypervolumes = np.array([ + pygmo.hypervolume( # type: ignore + dataset_performances, + ).compute(reference) + for dataset_performances in all_performances + ]) # And then compute the cumulative hypervolume error error = (hypervolumes - config_hypervolumes).sum() # type: ignore @@ -177,10 +169,8 @@ def recommend( def _dummy_performance() -> Performance: - return Performance.from_dict( - { - mm: np.nan - for m in Performance.metrics() - for mm in [f"{m}_mean", f"{m}_std"] - } - ) + return Performance.from_dict({ + mm: np.nan + for m in Performance.metrics() + for mm in [f"{m}_mean", f"{m}_std"] + }) diff --git a/src/gluonts/nursery/tsbench/src/tsbench/surrogate/nonparametric.py b/src/gluonts/nursery/tsbench/src/tsbench/surrogate/nonparametric.py index 5172fe66d2..9ec0f19aeb 100644 --- a/src/gluonts/nursery/tsbench/src/tsbench/surrogate/nonparametric.py +++ b/src/gluonts/nursery/tsbench/src/tsbench/surrogate/nonparametric.py @@ -67,13 +67,11 @@ def __init__( tracker, predict, output_normalization, impute_simulatable ) - self.use_dataset_features = any( - [ - use_simple_dataset_features, - use_seasonal_naive_performance, - use_catch22_features, - ] - ) + self.use_dataset_features = any([ + use_simple_dataset_features, + use_seasonal_naive_performance, + use_catch22_features, + ]) if self.use_dataset_features: self.config_transformer = ConfigTransformer( add_model_features=False, @@ -97,26 +95,22 @@ def _fit( # Then, we assign the model performances and dataset features self.model_performances_ = { - model: np.stack( - [ - p["performance"] - for p in sorted( - data, - key=lambda x: x["dataset"].name(), # type: ignore - ) - ] - ) + model: np.stack([ + p["performance"] + for p in sorted( + data, + key=lambda x: x["dataset"].name(), # type: ignore + ) + ]) for model, data in performances.items() } # We use the seasonal naive model config here since it is ignored anyway if self.use_dataset_features: - self.dataset_features_ = self.config_transformer.fit_transform( - [ - Config(SeasonalNaiveModelConfig(), d) - for d in sorted(datasets, key=lambda x: x.name()) # type: ignore - ] - ) + self.dataset_features_ = self.config_transformer.fit_transform([ + Config(SeasonalNaiveModelConfig(), d) + for d in sorted(datasets, key=lambda x: x.name()) # type: ignore + ]) def _predict( self, X: List[Config[ModelConfig]] diff --git a/src/gluonts/nursery/tsbench/src/tsbench/surrogate/transformers/config.py b/src/gluonts/nursery/tsbench/src/tsbench/surrogate/transformers/config.py index ce5aa18f02..c7e2bb9cf9 100644 --- a/src/gluonts/nursery/tsbench/src/tsbench/surrogate/transformers/config.py +++ b/src/gluonts/nursery/tsbench/src/tsbench/surrogate/transformers/config.py @@ -61,14 +61,12 @@ def __init__( add_catch_22_features: Whether a dataset's catch22 features ought to be added. tracker: An optional tracker to obtain the performance of Seasonal Naïve. """ - assert any( - [ - add_model_features, - add_dataset_statistics, - add_seasonal_naive_performance, - add_catch22_features, - ] - ), "ConfigTransformer must be given at least some group of features." + assert any([ + add_model_features, + add_dataset_statistics, + add_seasonal_naive_performance, + add_catch22_features, + ]), "ConfigTransformer must be given at least some group of features." assert ( not add_seasonal_naive_performance or tracker is not None ), "Tracker must be set if seasonal naive performance is used." @@ -337,12 +335,10 @@ def transform( if self.transform_full_config: return cast( npt.NDArray[np.float32], - self.pipeline.transform( - [ - x.model.asdict() - for x in cast(List[Config[ModelConfig]], X) - ] - ), + self.pipeline.transform([ + x.model.asdict() + for x in cast(List[Config[ModelConfig]], X) + ]), ) return cast( npt.NDArray[np.float32], @@ -383,16 +379,14 @@ def feature_names_(self) -> list[str]: def fit( self, X: list[Config[ModelConfig]], _y: Any = None ) -> DatasetStatisticsEncoder: - self.pipeline.fit( - [ - { - **x.dataset.stats(), - "frequency": x.dataset.meta.freq, - "prediction_length": x.dataset.meta.prediction_length, - } - for x in X - ] - ) + self.pipeline.fit([ + { + **x.dataset.stats(), + "frequency": x.dataset.meta.freq, + "prediction_length": x.dataset.meta.prediction_length, + } + for x in X + ]) return self def transform( @@ -430,14 +424,12 @@ def transform( def _get_performance_array( self, X: list[Config[ModelConfig]] ) -> npt.NDArray[np.float32]: - return np.array( - [ - self.tracker.get_performance( - Config(SeasonalNaiveModelConfig(), x.dataset) - ).ncrps.mean - for x in X - ] - )[:, None] + return np.array([ + self.tracker.get_performance( + Config(SeasonalNaiveModelConfig(), x.dataset) + ).ncrps.mean + for x in X + ])[:, None] class DatasetCatch22Encoder(Encoder): diff --git a/src/gluonts/nursery/tsbench/src/tsbench/surrogate/transformers/performance.py b/src/gluonts/nursery/tsbench/src/tsbench/surrogate/transformers/performance.py index 28f32ce34e..9ec903f2d8 100644 --- a/src/gluonts/nursery/tsbench/src/tsbench/surrogate/transformers/performance.py +++ b/src/gluonts/nursery/tsbench/src/tsbench/surrogate/transformers/performance.py @@ -112,13 +112,10 @@ def transform( def inverse_transform( self, X: npt.NDArray[np.float32], _y: Any = None ) -> list[Performance]: - df = pd.DataFrame(X, columns=self.feature_names_).assign( - **{ - col: np.nan - for col in set(self.all_feature_names_) - - set(self.feature_names_) - } - ) + df = pd.DataFrame(X, columns=self.feature_names_).assign(**{ + col: np.nan + for col in set(self.all_feature_names_) - set(self.feature_names_) + }) return [ Performance.from_dict(row.to_dict()) for _, row in df.iterrows() ] diff --git a/src/gluonts/shell/sagemaker/dyn.py b/src/gluonts/shell/sagemaker/dyn.py index bb1bc4d57b..fae3e3137f 100644 --- a/src/gluonts/shell/sagemaker/dyn.py +++ b/src/gluonts/shell/sagemaker/dyn.py @@ -49,33 +49,29 @@ def copy_install(self, path: Path): shutil.copytree(path, self.packages / path.name) def pip_install(self, path: Path): - subprocess.check_call( - [ - sys.executable, - "-m", - "pip", - "install", - "--upgrade", - "--target", - str(self.packages), - str(path), - ] - ) + subprocess.check_call([ + sys.executable, + "-m", + "pip", + "install", + "--upgrade", + "--target", + str(self.packages), + str(path), + ]) def install_requirement(self, path: Path): - subprocess.check_call( - [ - sys.executable, - "-m", - "pip", - "install", - "--upgrade", - "--target", - str(self.packages), - "--requirement", - str(path), - ] - ) + subprocess.check_call([ + sys.executable, + "-m", + "pip", + "install", + "--upgrade", + "--target", + str(self.packages), + "--requirement", + str(path), + ]) def install(self, path): if path.is_file(): diff --git a/src/gluonts/time_feature/holiday.py b/src/gluonts/time_feature/holiday.py index 9d395a0bb0..bb6422627c 100644 --- a/src/gluonts/time_feature/holiday.py +++ b/src/gluonts/time_feature/holiday.py @@ -215,16 +215,10 @@ def __call__(self, dates): dates Pandas series with Datetimeindex timestamps. """ - return np.vstack( - [ - np.hstack( - [ - self.kernel_function( - SPECIAL_DATE_FEATURES[feat_name](index) - ) - for index in dates - ] - ) - for feat_name in self.feature_names - ] - ) + return np.vstack([ + np.hstack([ + self.kernel_function(SPECIAL_DATE_FEATURES[feat_name](index)) + for index in dates + ]) + for feat_name in self.feature_names + ]) diff --git a/src/gluonts/torch/model/estimator.py b/src/gluonts/torch/model/estimator.py index b8a1147d44..41757efa70 100644 --- a/src/gluonts/torch/model/estimator.py +++ b/src/gluonts/torch/model/estimator.py @@ -198,13 +198,11 @@ def train_model( ) custom_callbacks = self.trainer_kwargs.pop("callbacks", []) - trainer = pl.Trainer( - **{ - "accelerator": "auto", - "callbacks": [checkpoint] + custom_callbacks, - **self.trainer_kwargs, - } - ) + trainer = pl.Trainer(**{ + "accelerator": "auto", + "callbacks": [checkpoint] + custom_callbacks, + **self.trainer_kwargs, + }) trainer.fit( model=training_network, diff --git a/src/gluonts/torch/model/i_transformer/estimator.py b/src/gluonts/torch/model/i_transformer/estimator.py index 541b16753c..71855820dd 100644 --- a/src/gluonts/torch/model/i_transformer/estimator.py +++ b/src/gluonts/torch/model/i_transformer/estimator.py @@ -232,7 +232,7 @@ def create_training_data_loader( data: Dataset, module: ITransformerLightningModule, shuffle_buffer_length: Optional[int] = None, - **kwargs + **kwargs, ) -> Iterable: data = Cyclic(data).stream() instances = self._create_instance_splitter(module, "training").apply( diff --git a/src/gluonts/torch/model/lag_tst/estimator.py b/src/gluonts/torch/model/lag_tst/estimator.py index 9dac576ab6..7e8ad1e759 100644 --- a/src/gluonts/torch/model/lag_tst/estimator.py +++ b/src/gluonts/torch/model/lag_tst/estimator.py @@ -225,7 +225,7 @@ def create_training_data_loader( data: Dataset, module: LagTSTLightningModule, shuffle_buffer_length: Optional[int] = None, - **kwargs + **kwargs, ) -> Iterable: data = Cyclic(data).stream() instances = self._create_instance_splitter(module, "training").apply( diff --git a/src/gluonts/torch/model/patch_tst/estimator.py b/src/gluonts/torch/model/patch_tst/estimator.py index c9e8c51e4f..94ae5b5444 100644 --- a/src/gluonts/torch/model/patch_tst/estimator.py +++ b/src/gluonts/torch/model/patch_tst/estimator.py @@ -229,7 +229,7 @@ def create_training_data_loader( data: Dataset, module: PatchTSTLightningModule, shuffle_buffer_length: Optional[int] = None, - **kwargs + **kwargs, ) -> Iterable: data = Cyclic(data).stream() instances = self._create_instance_splitter(module, "training").apply( diff --git a/src/gluonts/torch/model/patch_tst/module.py b/src/gluonts/torch/model/patch_tst/module.py index 4e829e2ea1..c4a7149dcb 100644 --- a/src/gluonts/torch/model/patch_tst/module.py +++ b/src/gluonts/torch/model/patch_tst/module.py @@ -38,12 +38,10 @@ def _init_weight(out: torch.Tensor) -> torch.Tensor: Features are not interleaved. The cos features are in the 2nd half of the vector. [dim // 2:] """ n_pos, dim = out.shape - position_enc = np.array( - [ - [pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] - for pos in range(n_pos) - ] - ) + position_enc = np.array([ + [pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] + for pos in range(n_pos) + ]) # set early to avoid an error in pytorch-1.8+ out.requires_grad = False diff --git a/src/gluonts/torch/model/tft/layers.py b/src/gluonts/torch/model/tft/layers.py index d19ed7f059..5a12260b70 100644 --- a/src/gluonts/torch/model/tft/layers.py +++ b/src/gluonts/torch/model/tft/layers.py @@ -56,12 +56,10 @@ def __init__( self.feature_dims = feature_dims self._num_features = len(feature_dims) - self._projectors = nn.ModuleList( - [ - nn.Linear(out_features=d, in_features=c) - for c, d in zip(feature_dims, embedding_dims) - ] - ) + self._projectors = nn.ModuleList([ + nn.Linear(out_features=d, in_features=c) + for c, d in zip(feature_dims, embedding_dims) + ]) def forward(self, features: torch.Tensor) -> List[torch.Tensor]: """ @@ -189,12 +187,10 @@ def __init__( d_static=self.d_hidden if add_static else None, dropout=dropout, ) - self.variable_networks = nn.ModuleList( - [ - GatedResidualNetwork(d_hidden=d_hidden, dropout=dropout) - for _ in range(num_vars) - ] - ) + self.variable_networks = nn.ModuleList([ + GatedResidualNetwork(d_hidden=d_hidden, dropout=dropout) + for _ in range(num_vars) + ]) def forward( self, diff --git a/src/gluonts/torch/model/wavenet/estimator.py b/src/gluonts/torch/model/wavenet/estimator.py index 234ecff237..6aacdc3a9a 100644 --- a/src/gluonts/torch/model/wavenet/estimator.py +++ b/src/gluonts/torch/model/wavenet/estimator.py @@ -264,60 +264,56 @@ def create_transformation(self) -> Transformation: remove_field_names.append(FieldName.FEAT_STATIC_REAL) if self.num_feat_dynamic_real == 0: remove_field_names.append(FieldName.FEAT_DYNAMIC_REAL) - return Chain( - [ - RemoveFields(field_names=remove_field_names), - ( - SetField(output_field=FieldName.FEAT_STATIC_CAT, value=[0]) - if self.num_feat_static_cat == 0 - else Identity() - ), - ( - SetField( - output_field=FieldName.FEAT_STATIC_REAL, value=[0.0] - ) - if self.num_feat_static_real == 0 - else Identity() - ), - AsNumpyArray( - field=FieldName.FEAT_STATIC_CAT, expected_ndim=1, dtype=int - ), - AsNumpyArray( - field=FieldName.FEAT_STATIC_REAL, - expected_ndim=1, - dtype=np.float32, - ), - AsNumpyArray(field=FieldName.TARGET, expected_ndim=1), - AddObservedValuesIndicator( - target_field=FieldName.TARGET, - output_field=FieldName.OBSERVED_VALUES, - ), - AddTimeFeatures( - start_field=FieldName.START, - target_field=FieldName.TARGET, - output_field=FieldName.FEAT_TIME, - time_features=self.time_features, - pred_length=self.prediction_length, - ), - AddAgeFeature( - target_field=FieldName.TARGET, - output_field=FieldName.FEAT_AGE, - pred_length=self.prediction_length, - ), - VstackFeatures( - output_field=FieldName.FEAT_TIME, - input_fields=[FieldName.FEAT_TIME, FieldName.FEAT_AGE] - + ( - [FieldName.FEAT_DYNAMIC_REAL] - if self.num_feat_dynamic_real > 0 - else [] - ), - ), - AsNumpyArray( - FieldName.FEAT_TIME, expected_ndim=2, dtype=np.float32 + return Chain([ + RemoveFields(field_names=remove_field_names), + ( + SetField(output_field=FieldName.FEAT_STATIC_CAT, value=[0]) + if self.num_feat_static_cat == 0 + else Identity() + ), + ( + SetField(output_field=FieldName.FEAT_STATIC_REAL, value=[0.0]) + if self.num_feat_static_real == 0 + else Identity() + ), + AsNumpyArray( + field=FieldName.FEAT_STATIC_CAT, expected_ndim=1, dtype=int + ), + AsNumpyArray( + field=FieldName.FEAT_STATIC_REAL, + expected_ndim=1, + dtype=np.float32, + ), + AsNumpyArray(field=FieldName.TARGET, expected_ndim=1), + AddObservedValuesIndicator( + target_field=FieldName.TARGET, + output_field=FieldName.OBSERVED_VALUES, + ), + AddTimeFeatures( + start_field=FieldName.START, + target_field=FieldName.TARGET, + output_field=FieldName.FEAT_TIME, + time_features=self.time_features, + pred_length=self.prediction_length, + ), + AddAgeFeature( + target_field=FieldName.TARGET, + output_field=FieldName.FEAT_AGE, + pred_length=self.prediction_length, + ), + VstackFeatures( + output_field=FieldName.FEAT_TIME, + input_fields=[FieldName.FEAT_TIME, FieldName.FEAT_AGE] + + ( + [FieldName.FEAT_DYNAMIC_REAL] + if self.num_feat_dynamic_real > 0 + else [] ), - ] - ) + ), + AsNumpyArray( + FieldName.FEAT_TIME, expected_ndim=2, dtype=np.float32 + ), + ]) def _create_instance_splitter(self, mode: str): assert mode in ["training", "validation", "test"] @@ -348,7 +344,7 @@ def create_training_data_loader( data: Dataset, module: WaveNetLightningModule, shuffle_buffer_length: Optional[int] = None, - **kwargs + **kwargs, ) -> Iterable: data = Cyclic(data).stream() instances = self._create_instance_splitter("training").apply( diff --git a/src/gluonts/transform/feature.py b/src/gluonts/transform/feature.py index f5f519d5a7..52648469c9 100644 --- a/src/gluonts/transform/feature.py +++ b/src/gluonts/transform/feature.py @@ -192,12 +192,10 @@ def __call__(self, values: np.ndarray) -> np.ndarray: last_value_imputation = LastValueImputation() value_no_nans = last_value_imputation(values) - adjusted_values_to_causality = np.concatenate( - ( - np.repeat(value_no_nans[0], self.window_size + 1), - value_no_nans[:-1], - ) - ) + adjusted_values_to_causality = np.concatenate(( + np.repeat(value_no_nans[0], self.window_size + 1), + value_no_nans[:-1], + )) cumsum = np.cumsum(adjusted_values_to_causality) @@ -519,25 +517,23 @@ def map_transform(self, data: DataEntry, is_train: bool) -> DataEntry: # compute the aggregate lags for each time point of the time series agg_vals = np.concatenate( [ - np.zeros( - (max(self.valid_lags) * self.ratio + self.half_window + 1,) - ), + np.zeros(( + max(self.valid_lags) * self.ratio + self.half_window + 1, + )), t_agg.values, ], axis=0, ) - lags = np.vstack( - [ - agg_vals[ - -(l * self.ratio - self.half_window + len(t)) : ( - -(l * self.ratio - self.half_window) - if -(l * self.ratio - self.half_window) != 0 - else None - ) - ] - for l in self.valid_lags + lags = np.vstack([ + agg_vals[ + -(l * self.ratio - self.half_window + len(t)) : ( + -(l * self.ratio - self.half_window) + if -(l * self.ratio - self.half_window) != 0 + else None + ) ] - ) + for l in self.valid_lags + ]) # update the data entry data[self.feature_name] = np.nan_to_num(lags) diff --git a/src/gluonts/zebras/_period.py b/src/gluonts/zebras/_period.py index 4cda3178a9..ed9f52af00 100644 --- a/src/gluonts/zebras/_period.py +++ b/src/gluonts/zebras/_period.py @@ -114,12 +114,9 @@ def dayofyear(self) -> np.ndarray: def week(self) -> np.ndarray: # Note: In Python 3.9 `isocalendar()` returns a named tuple, but we # need to support 3.7 and 3.8, so we use index one for the week. - return np.array( - [ - cal.isocalendar()[1] - for cal in self.data.astype(datetime.datetime) - ] - ) + return np.array([ + cal.isocalendar()[1] for cal in self.data.astype(datetime.datetime) + ]) def __add__(self, other): if _is_number(other): diff --git a/src/gluonts/zebras/_time_frame.py b/src/gluonts/zebras/_time_frame.py index dec7e591e8..82246bcf12 100644 --- a/src/gluonts/zebras/_time_frame.py +++ b/src/gluonts/zebras/_time_frame.py @@ -204,23 +204,19 @@ def move_axis(data, name): head = self.head(5) tail = self.tail(5) - columns.update( - { - col: [ - *(move_axis(head[col], col)), - f"[ ... {len(self) - 10} ... ]", - *(move_axis(tail[col], col)), - ] - for col in self.columns - } - ) + columns.update({ + col: [ + *(move_axis(head[col], col)), + f"[ ... {len(self) - 10} ... ]", + *(move_axis(tail[col], col)), + ] + for col in self.columns + }) else: - columns.update( - { - name: move_axis(values, name) - for name, values in self.columns.items() - } - ) + columns.update({ + name: move_axis(values, name) + for name, values in self.columns.items() + }) return columns @@ -233,14 +229,10 @@ def _repr_html_(self): ] if self.static: - html.extend( - [ - "

Static Data

", - html_table( - {name: [val] for name, val in self.static.items()} - ), - ] - ) + html.extend([ + "

Static Data

", + html_table({name: [val] for name, val in self.static.items()}), + ]) return "\n".join(html) diff --git a/src/gluonts/zebras/schema.py b/src/gluonts/zebras/schema.py index 5bfa777fb2..4ab348eb29 100644 --- a/src/gluonts/zebras/schema.py +++ b/src/gluonts/zebras/schema.py @@ -247,13 +247,11 @@ def load_timeframe( columns = {self.time_series_ref: ref} - columns.update( - { - name: field.load_from(data, name, length=length) - for name, field in self.columns.items() - if name != self.time_series_ref - } - ) + columns.update({ + name: field.load_from(data, name, length=length) + for name, field in self.columns.items() + if name != self.time_series_ref + }) else: columns = {} diff --git a/test/conftest.py b/test/conftest.py index b624259a15..69222f462a 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -103,13 +103,11 @@ def _sine7( ) train_dataset = ListDataset( - [ - { - "start": index[0], - "item_id": "all_items", - "target": Y[:, :-prediction_length], - } - ], + [{ + "start": index[0], + "item_id": "all_items", + "target": Y[:, :-prediction_length], + }], freq=index.freqstr, one_dim_target=False, ) diff --git a/test/dataset/test_data_loader.py b/test/dataset/test_data_loader.py index 8dd5684178..3f937628d6 100644 --- a/test/dataset/test_data_loader.py +++ b/test/dataset/test_data_loader.py @@ -231,12 +231,9 @@ def test_as_stacked_batches(): def test_as_stacked_batches_iter(): step = 10 - data = iter( - [ - {"x": np.arange(start, start + step)} - for start in range(0, 100, step) - ] - ) + data = iter([ + {"x": np.arange(start, start + step)} for start in range(0, 100, step) + ]) stream = as_stacked_batches(data, batch_size=2) @@ -255,12 +252,9 @@ def test_as_stacked_batches_iter(): def test_as_stacked_batches_iter_num_batches(): step = 10 - data = iter( - [ - {"x": np.arange(start, start + step)} - for start in range(0, 100, step) - ] - ) + data = iter([ + {"x": np.arange(start, start + step)} for start in range(0, 100, step) + ]) stream = as_stacked_batches(data, batch_size=2, num_batches_per_epoch=3) @@ -281,12 +275,10 @@ def test_as_stacked_batches_iter_num_batches(): def test_as_stacked_batches_num_batches_iter_cycle(): step = 10 data = iter( - Cyclic( - [ - {"x": np.arange(start, start + step)} - for start in range(0, 100, step) - ] - ) + Cyclic([ + {"x": np.arange(start, start + step)} + for start in range(0, 100, step) + ]) ) stream = as_stacked_batches(data, batch_size=2, num_batches_per_epoch=3) diff --git a/test/dataset/test_dataset_mutability.py b/test/dataset/test_dataset_mutability.py index b5c0bbd8fb..d3eba63901 100644 --- a/test/dataset/test_dataset_mutability.py +++ b/test/dataset/test_dataset_mutability.py @@ -25,21 +25,15 @@ AddObservedValuesIndicator, ) -ds1 = [ - { - "start": pd.Period("2020/01/01", freq="1D"), - "target": np.array( - [1, 2, 3, np.nan, 5, np.nan, 7, np.nan, np.nan, 10] - ), - } -] +ds1 = [{ + "start": pd.Period("2020/01/01", freq="1D"), + "target": np.array([1, 2, 3, np.nan, 5, np.nan, 7, np.nan, np.nan, 10]), +}] ds2 = ListDataset( - [ - { - "start": "2020/01/01", - "target": [1, 2, 3, np.nan, 5, np.nan, 7, np.nan, np.nan, 10], - } - ], + [{ + "start": "2020/01/01", + "target": [1, 2, 3, np.nan, 5, np.nan, 7, np.nan, np.nan, 10], + }], freq="1D", ) diff --git a/test/dataset/test_multivariate_grouper.py b/test/dataset/test_multivariate_grouper.py index 1f1ac34e75..cba2a3f7bc 100644 --- a/test/dataset/test_multivariate_grouper.py +++ b/test/dataset/test_multivariate_grouper.py @@ -43,22 +43,18 @@ MULTIVARIATE_TS = [ [{"start": "2014-09-07", "target": [[1, 2, 3, 4], [5, 6, 7, 8]]}], - [ - { - "start": "2014-09-07", - "target": [[1, 2, 3, 4, 2.5], [6.5, 5, 6, 7, 8]], - } - ], + [{ + "start": "2014-09-07", + "target": [[1, 2, 3, 4, 2.5], [6.5, 5, 6, 7, 8]], + }], [{"start": "2014-09-07", "target": [[1, 2, 3, 4], [0, 0, 0, 0]]}], - [ - { - "start": "2014-09-01", - "target": [ - [2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 1, 2, 3, 4], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - ], - } - ], + [{ + "start": "2014-09-01", + "target": [ + [2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 1, 2, 3, 4], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + }], [{"start": "2014-09-07", "target": [[1, 2, 3, 4, 0], [0, 5, 6, 7, 8]]}], ] diff --git a/test/dataset/test_pandas.py b/test/dataset/test_pandas.py index b3f9834532..569b9de143 100644 --- a/test/dataset/test_pandas.py +++ b/test/dataset/test_pandas.py @@ -241,42 +241,36 @@ def _testcase_dataframes_without_index( dtype=np.float32, ): dataframes = [ - pd.DataFrame.from_dict( - { - "timestamp": pd.period_range( - "2021-01-01 00:00:00", periods=10, freq=freq - ) - .map(str) - .to_list(), - "A": 1 + np.arange(10, dtype=dtype), - "B": 2 + np.arange(10, dtype=dtype), - "C": 3 + np.arange(10, dtype=dtype), - } - ), - pd.DataFrame.from_dict( - { - "timestamp": pd.period_range( - "2021-01-02 00:00:00", periods=20, freq=freq - ) - .map(str) - .to_list(), - "A": 1 + np.arange(20, dtype=dtype), - "B": 2 + np.arange(20, dtype=dtype), - "C": 3 + np.arange(20, dtype=dtype), - } - ), - pd.DataFrame.from_dict( - { - "timestamp": pd.period_range( - "2021-01-03 00:00:00", periods=30, freq=freq - ) - .map(str) - .to_list(), - "A": 1 + np.arange(30, dtype=dtype), - "B": 2 + np.arange(30, dtype=dtype), - "C": 3 + np.arange(30, dtype=dtype), - } - ), + pd.DataFrame.from_dict({ + "timestamp": pd.period_range( + "2021-01-01 00:00:00", periods=10, freq=freq + ) + .map(str) + .to_list(), + "A": 1 + np.arange(10, dtype=dtype), + "B": 2 + np.arange(10, dtype=dtype), + "C": 3 + np.arange(10, dtype=dtype), + }), + pd.DataFrame.from_dict({ + "timestamp": pd.period_range( + "2021-01-02 00:00:00", periods=20, freq=freq + ) + .map(str) + .to_list(), + "A": 1 + np.arange(20, dtype=dtype), + "B": 2 + np.arange(20, dtype=dtype), + "C": 3 + np.arange(20, dtype=dtype), + }), + pd.DataFrame.from_dict({ + "timestamp": pd.period_range( + "2021-01-03 00:00:00", periods=30, freq=freq + ) + .map(str) + .to_list(), + "A": 1 + np.arange(30, dtype=dtype), + "B": 2 + np.arange(30, dtype=dtype), + "C": 3 + np.arange(30, dtype=dtype), + }), ] dataset = pandas.PandasDataset( @@ -307,36 +301,30 @@ def _testcase_dataframes_with_index( dtype=np.float32, ): dataframes = [ - pd.DataFrame.from_dict( - { - "timestamp": index_type( - "2021-01-01 00:00:00", periods=10, freq=freq - ), - "A": 1 + np.arange(10, dtype=dtype), - "B": 2 + np.arange(10, dtype=dtype), - "C": 3 + np.arange(10, dtype=dtype), - } - ).set_index("timestamp"), - pd.DataFrame.from_dict( - { - "timestamp": index_type( - "2021-01-02 00:00:00", periods=20, freq=freq - ), - "A": 1 + np.arange(20, dtype=dtype), - "B": 2 + np.arange(20, dtype=dtype), - "C": 3 + np.arange(20, dtype=dtype), - } - ).set_index("timestamp"), - pd.DataFrame.from_dict( - { - "timestamp": index_type( - "2021-01-03 00:00:00", periods=30, freq=freq - ), - "A": 1 + np.arange(30, dtype=dtype), - "B": 2 + np.arange(30, dtype=dtype), - "C": 3 + np.arange(30, dtype=dtype), - } - ).set_index("timestamp"), + pd.DataFrame.from_dict({ + "timestamp": index_type( + "2021-01-01 00:00:00", periods=10, freq=freq + ), + "A": 1 + np.arange(10, dtype=dtype), + "B": 2 + np.arange(10, dtype=dtype), + "C": 3 + np.arange(10, dtype=dtype), + }).set_index("timestamp"), + pd.DataFrame.from_dict({ + "timestamp": index_type( + "2021-01-02 00:00:00", periods=20, freq=freq + ), + "A": 1 + np.arange(20, dtype=dtype), + "B": 2 + np.arange(20, dtype=dtype), + "C": 3 + np.arange(20, dtype=dtype), + }).set_index("timestamp"), + pd.DataFrame.from_dict({ + "timestamp": index_type( + "2021-01-03 00:00:00", periods=30, freq=freq + ), + "A": 1 + np.arange(30, dtype=dtype), + "B": 2 + np.arange(30, dtype=dtype), + "C": 3 + np.arange(30, dtype=dtype), + }).set_index("timestamp"), ] print(type(dataframes[0].index)) diff --git a/test/dataset/test_split.py b/test/dataset/test_split.py index 33d2cbe98e..606af31a81 100644 --- a/test/dataset/test_split.py +++ b/test/dataset/test_split.py @@ -414,12 +414,10 @@ def test_split_date( @pytest.mark.parametrize( "dataset", [ - [ - { - "start": pd.Period("2021-03-01", freq="D"), - "target": np.ones(shape=(28,)), - } - ], + [{ + "start": pd.Period("2021-03-01", freq="D"), + "target": np.ones(shape=(28,)), + }], ], ) @pytest.mark.parametrize( diff --git a/test/ev/test_aggregations.py b/test/ev/test_aggregations.py index 974f8ad199..48a2033fcd 100644 --- a/test/ev/test_aggregations.py +++ b/test/ev/test_aggregations.py @@ -33,24 +33,20 @@ np.zeros(9), ), ( - np.ma.masked_invalid( - [ - np.full((3, 5), np.nan), - np.full((3, 5), np.nan), - np.full((3, 5), np.nan), - ] - ), + np.ma.masked_invalid([ + np.full((3, 5), np.nan), + np.full((3, 5), np.nan), + np.full((3, 5), np.nan), + ]), 0, np.zeros(5), np.zeros(9), ), ( - np.ma.masked_invalid( - [ - np.array([[0, np.nan], [0, 0]]), - np.array([[0, 5], [-5, np.nan]]), - ] - ), + np.ma.masked_invalid([ + np.array([[0, np.nan], [0, 0]]), + np.array([[0, 5], [-5, np.nan]]), + ]), 0, np.array([-5, 5]), np.array([0, 0, 5, -5]), @@ -91,24 +87,20 @@ def test_Sum(value_stream, res_axis_none, res_axis_0, res_axis_1): np.zeros(9), ), ( - np.ma.masked_invalid( - [ - np.full((3, 5), np.nan), - np.full((3, 5), np.nan), - np.full((3, 5), np.nan), - ] - ), + np.ma.masked_invalid([ + np.full((3, 5), np.nan), + np.full((3, 5), np.nan), + np.full((3, 5), np.nan), + ]), np.nan, np.full(5, np.nan), np.full(9, np.nan), ), ( - np.ma.masked_invalid( - [ - np.array([[0, np.nan], [0, 0]]), - np.array([[0, 5], [-5, np.nan]]), - ] - ), + np.ma.masked_invalid([ + np.array([[0, np.nan], [0, 0]]), + np.array([[0, 5], [-5, np.nan]]), + ]), 0, np.array([-1.25, 2.5]), np.array([0, 0, 2.5, -5]), diff --git a/test/ev/test_metrics_compared_to_previous_approach.py b/test/ev/test_metrics_compared_to_previous_approach.py index 5adc2e5d2b..086504744a 100644 --- a/test/ev/test_metrics_compared_to_previous_approach.py +++ b/test/ev/test_metrics_compared_to_previous_approach.py @@ -127,15 +127,13 @@ def get_data_batches(predictor, test_data): "seasonal_error": np.array( [seasonal_error(input_["target"], seasonality=seasonality)] ), - "naive_2": np.array( - [ - naive_2( - input_["target"], - len(label["target"]), - season_length=seasonality, - ) - ] - ), + "naive_2": np.array([ + naive_2( + input_["target"], + len(label["target"]), + season_length=seasonality, + ) + ]), } yield ChainMap(other_data, forecast_batch) @@ -168,12 +166,9 @@ def get_new_metrics(test_data, predictor, quantile_levels): + MeanSumQuantileLoss([quantile.value for quantile in quantiles]) + MeanWeightedSumQuantileLoss( [quantile.value for quantile in quantiles] - ).add( - *[ - WeightedSumQuantileLoss(q=quantile.value) - for quantile in quantiles - ] - ) + ).add(*[ + WeightedSumQuantileLoss(q=quantile.value) for quantile in quantiles + ]) ) # mask invalid values diff --git a/test/evaluation/test_evaluator.py b/test/evaluation/test_evaluator.py index ca21c56aa8..da57ee389f 100644 --- a/test/evaluation/test_evaluator.py +++ b/test/evaluation/test_evaluator.py @@ -119,134 +119,130 @@ def calculate_metrics( TIMESERIES_M4 = [ - np.array( + np.array([ [ - [ - 2.943_013, - 2.822_251, - 4.196_222, - 1.328_664, - 4.947_390, - 3.333_131, - 1.479_800, - 2.265_094, - 3.413_493, - 3.497_607, - ], - [ - -0.126_781_2, - 3.057_412_2, - 1.901_594_4, - 2.772_549_5, - 3.312_853_1, - 4.411_818_0, - 3.709_025_2, - 4.322_028, - 2.565_359, - 3.074_308, - ], - [ - 2.542_998, - 2.336_757, - 1.417_916, - 1.335_139, - 2.523_035, - 3.645_589, - 3.382_819, - 2.075_960, - 2.643_869, - 2.772_456, - ], - [ - 0.315_685_6, - 1.892_312_1, - 2.476_861_2, - 3.511_628_6, - 4.384_346_5, - 2.960_685_6, - 4.897_572_5, - 3.280_125, - 4.768_556, - 4.958_616, - ], - [ - 2.205_877_3, - 0.782_759_4, - 2.401_420_8, - 2.385_643_4, - 4.845_818_2, - 3.102_322_9, - 3.567_723_7, - 4.878_143, - 3.735_245, - 2.218_113, - ], - ] - ), - np.array( + 2.943_013, + 2.822_251, + 4.196_222, + 1.328_664, + 4.947_390, + 3.333_131, + 1.479_800, + 2.265_094, + 3.413_493, + 3.497_607, + ], [ - [ - 13.11301, - 13.16225, - 14.70622, - 12.00866, - 15.79739, - 14.35313, - 12.66980, - 13.62509, - 14.94349, - 15.19761, - ], - [ - 10.04322, - 13.39741, - 12.41159, - 13.45255, - 14.16285, - 15.43182, - 14.89903, - 15.68203, - 14.09536, - 14.77431, - ], - [ - 12.71300, - 12.67676, - 11.92792, - 12.01514, - 13.37303, - 14.66559, - 14.57282, - 13.43596, - 14.17387, - 14.47246, - ], - [ - 10.48569, - 12.23231, - 12.98686, - 14.19163, - 15.23435, - 13.98069, - 16.08757, - 14.64012, - 16.29856, - 16.65862, - ], - [ - 12.37588, - 11.12276, - 12.91142, - 13.06564, - 15.69582, - 14.12232, - 14.75772, - 16.23814, - 15.26524, - 13.91811, - ], - ] - ), + -0.126_781_2, + 3.057_412_2, + 1.901_594_4, + 2.772_549_5, + 3.312_853_1, + 4.411_818_0, + 3.709_025_2, + 4.322_028, + 2.565_359, + 3.074_308, + ], + [ + 2.542_998, + 2.336_757, + 1.417_916, + 1.335_139, + 2.523_035, + 3.645_589, + 3.382_819, + 2.075_960, + 2.643_869, + 2.772_456, + ], + [ + 0.315_685_6, + 1.892_312_1, + 2.476_861_2, + 3.511_628_6, + 4.384_346_5, + 2.960_685_6, + 4.897_572_5, + 3.280_125, + 4.768_556, + 4.958_616, + ], + [ + 2.205_877_3, + 0.782_759_4, + 2.401_420_8, + 2.385_643_4, + 4.845_818_2, + 3.102_322_9, + 3.567_723_7, + 4.878_143, + 3.735_245, + 2.218_113, + ], + ]), + np.array([ + [ + 13.11301, + 13.16225, + 14.70622, + 12.00866, + 15.79739, + 14.35313, + 12.66980, + 13.62509, + 14.94349, + 15.19761, + ], + [ + 10.04322, + 13.39741, + 12.41159, + 13.45255, + 14.16285, + 15.43182, + 14.89903, + 15.68203, + 14.09536, + 14.77431, + ], + [ + 12.71300, + 12.67676, + 11.92792, + 12.01514, + 13.37303, + 14.66559, + 14.57282, + 13.43596, + 14.17387, + 14.47246, + ], + [ + 10.48569, + 12.23231, + 12.98686, + 14.19163, + 15.23435, + 13.98069, + 16.08757, + 14.64012, + 16.29856, + 16.65862, + ], + [ + 12.37588, + 11.12276, + 12.91142, + 13.06564, + 15.69582, + 14.12232, + 14.75772, + 16.23814, + 15.26524, + 13.91811, + ], + ]), ] RES_M4 = [ diff --git a/test/ext/prophet/test_prophet.py b/test/ext/prophet/test_prophet.py index 5d1335f3cc..865deb9113 100644 --- a/test/ext/prophet/test_prophet.py +++ b/test/ext/prophet/test_prophet.py @@ -32,18 +32,14 @@ def test_feat_dynamic_real_success(freq: str): params = dict(prediction_length=3, prophet_params=dict(n_changepoints=20)) dataset = ListDataset( - data_iter=[ - { - "start": "2017-01-01", - "target": np.array([1.0, 2.0, 3.0, 4.0]), - "feat_dynamic_real": np.array( - [ - [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0], - [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0], - ] - ), - } - ], + data_iter=[{ + "start": "2017-01-01", + "target": np.array([1.0, 2.0, 3.0, 4.0]), + "feat_dynamic_real": np.array([ + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0], + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0], + ]), + }], freq=freq, ) @@ -61,18 +57,14 @@ def test_feat_dynamic_real_bad_size(): params = dict(prediction_length=3, prophet_params={}) dataset = ListDataset( - data_iter=[ - { - "start": "2017-01-01", - "target": np.array([1.0, 2.0, 3.0, 4.0]), - "feat_dynamic_real": np.array( - [ - [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], - [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], - ] - ), - } - ], + data_iter=[{ + "start": "2017-01-01", + "target": np.array([1.0, 2.0, 3.0, 4.0]), + "feat_dynamic_real": np.array([ + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], + ]), + }], freq="1D", ) diff --git a/test/ext/r_forecast/test_r_multi_seasonality.py b/test/ext/r_forecast/test_r_multi_seasonality.py index c130f03f74..ca75283f9d 100644 --- a/test/ext/r_forecast/test_r_multi_seasonality.py +++ b/test/ext/r_forecast/test_r_multi_seasonality.py @@ -34,32 +34,24 @@ period = 24 ## two weeks of data -dataset = [ - { - "start": pd.Period("1990-01-01 00", freq=freq), - "target": np.array( - [ - item - for i in range(70) - for item in np.sin( - 2 * np.pi / period * np.arange(1, period + 1, 1) - ) - ] - ) - + np.random.normal(0, 0.5, period * 70) - + np.array( - [ - item - for i in range(10) - for item in [0 for i in range(5 * 24)] - + [8 for i in range(4)] - + [0 for i in range(20)] - + [8 for i in range(4)] - + [0 for i in range(20)] - ] - ), - } -] +dataset = [{ + "start": pd.Period("1990-01-01 00", freq=freq), + "target": np.array([ + item + for i in range(70) + for item in np.sin(2 * np.pi / period * np.arange(1, period + 1, 1)) + ]) + + np.random.normal(0, 0.5, period * 70) + + np.array([ + item + for i in range(10) + for item in [0 for i in range(5 * 24)] + + [8 for i in range(4)] + + [0 for i in range(20)] + + [8 for i in range(4)] + + [0 for i in range(20)] + ]), +}] def no_quantile_crossing( @@ -149,38 +141,25 @@ def test_compare_arimas(): ## Below shows improvement in metric when proper x_regressors are included # -dataset_xreg = [ - { - "start": pd.Period("1990-01-01 00", freq=freq), - "target": np.array( - [ - item - for i in range(21) - for item in np.sin( - 2 * np.pi / period * np.arange(1, period + 1, 1) - ) - ] - ) - + np.random.normal(0, 0.5, period * 21) - + np.array( - [ - item - for i in range(3) - for item in [0 for i in range(167)] + [8 for i in range(0, 1)] - ] - ), - "feat_dynamic_real": np.array( - [ - [ - item - for i in range(3) - for item in [0 for i in range(167)] - + [1 for i in range(0, 1)] - ] - ] - ), - } -] +dataset_xreg = [{ + "start": pd.Period("1990-01-01 00", freq=freq), + "target": np.array([ + item + for i in range(21) + for item in np.sin(2 * np.pi / period * np.arange(1, period + 1, 1)) + ]) + + np.random.normal(0, 0.5, period * 21) + + np.array([ + item + for i in range(3) + for item in [0 for i in range(167)] + [8 for i in range(0, 1)] + ]), + "feat_dynamic_real": np.array([[ + item + for i in range(3) + for item in [0 for i in range(167)] + [1 for i in range(0, 1)] + ]]), +}] def test_compare_arimas_xreg(): diff --git a/test/ext/rotbaum/test_rotbaum_smoke.py b/test/ext/rotbaum/test_rotbaum_smoke.py index 93d1e96dd5..04c07e2398 100644 --- a/test/ext/rotbaum/test_rotbaum_smoke.py +++ b/test/ext/rotbaum/test_rotbaum_smoke.py @@ -72,31 +72,27 @@ def test_short_history_item_pred(): { "start": "2017-10-11", "item_id": "item_1", - "target": np.array( - [ - 1.0, - 9.0, - 2.0, - 0.0, - 0.0, - 1.0, - 5.0, - 3.0, - 4.0, - 2.0, - 0.0, - 0.0, - 1.0, - 6.0, - ] - ), + "target": np.array([ + 1.0, + 9.0, + 2.0, + 0.0, + 0.0, + 1.0, + 5.0, + 3.0, + 4.0, + 2.0, + 0.0, + 0.0, + 1.0, + 6.0, + ]), "feat_static_cat": np.array([0.0, 0.0], dtype=float), - "past_feat_dynamic_real": np.array( - [ - [1.0222e06 for i in range(14)], - [750.0 for i in range(14)], - ] - ), + "past_feat_dynamic_real": np.array([ + [1.0222e06 for i in range(14)], + [750.0 for i in range(14)], + ]), }, { "start": "2017-10-11", diff --git a/test/ext/statsforecast/test_statsforecast.py b/test/ext/statsforecast/test_statsforecast.py index 3c09327b0a..2bb90db539 100644 --- a/test/ext/statsforecast/test_statsforecast.py +++ b/test/ext/statsforecast/test_statsforecast.py @@ -123,14 +123,12 @@ def test_model_config( ) @pytest.mark.parametrize( "dataset", - [ - [ - dict( - start=pd.Period("2021-02-03 00", freq="H"), - target=np.random.normal(loc=10, scale=0.5, size=(100,)), - ) - ] - ], + [[ + dict( + start=pd.Period("2021-02-03 00", freq="H"), + target=np.random.normal(loc=10, scale=0.5, size=(100,)), + ) + ]], ) def test_predictor_working( predictor: StatsForecastPredictor, dataset: Dataset diff --git a/test/model/npts/test_npts.py b/test/model/npts/test_npts.py index a9b7183b58..97faf4e0b7 100644 --- a/test/model/npts/test_npts.py +++ b/test/model/npts/test_npts.py @@ -102,12 +102,10 @@ def test_climatological_forecaster( kernel_type=KernelType.uniform, ) - dataset = [ - { - "start": pd.Period(train_ts.index[0], freq=freq), - "target": train_ts.values, - } - ] + dataset = [{ + "start": pd.Period(train_ts.index[0], freq=freq), + "target": train_ts.values, + }] # validate that the predictor works with targets with NaNs _test_nans_in_target(predictor, dataset) @@ -267,12 +265,10 @@ def test_npts_forecaster( use_seasonal_model=use_seasonal_model, ) - dataset = [ - { - "start": pd.Period(train_ts.index[0], freq=freq), - "target": train_ts.values, - } - ] + dataset = [{ + "start": pd.Period(train_ts.index[0], freq=freq), + "target": train_ts.values, + }] # validate that the predictor works with targets with NaNs _test_nans_in_target(predictor, dataset) @@ -417,12 +413,9 @@ def test_npts_custom_features( freq=train_ts.index.freq, ) # Dummy feature defining 52 seasons - feat_dynamic_real = [ - [ - (ix % 52) / 51.0 - 0.5 - for ix, timestamp in enumerate(full_time_index) - ] - ] + feat_dynamic_real = [[ + (ix % 52) / 51.0 - 0.5 for ix, timestamp in enumerate(full_time_index) + ]] predictor = NPTSPredictor( prediction_length=pred_length, @@ -433,13 +426,11 @@ def test_npts_custom_features( use_default_time_features=False, # disable default time features ) - dataset = [ - { - "start": pd.Period(train_ts.index[0], freq=freq), - "target": train_ts.values, - "feat_dynamic_real": np.array(feat_dynamic_real), - } - ] + dataset = [{ + "start": pd.Period(train_ts.index[0], freq=freq), + "target": train_ts.values, + "feat_dynamic_real": np.array(feat_dynamic_real), + }] # validate that the predictor works with targets with NaNs _test_nans_in_target(predictor, dataset) diff --git a/test/mx/block/test_scaler.py b/test/mx/block/test_scaler.py index 4edec3483d..a949bdab94 100644 --- a/test/mx/block/test_scaler.py +++ b/test/mx/block/test_scaler.py @@ -20,144 +20,118 @@ test_cases = [ ( scaler.MeanScaler(), - mx.nd.array( - [ - [1.0] * 50, - [0.0] * 25 + [3.0] * 25, - [2.0] * 49 + [1.5] * 1, - [0.0] * 50, - [1.0] * 50, - ] - ), - mx.nd.array( - [ - [1.0] * 50, - [0.0] * 25 + [1.0] * 25, - [0.0] * 49 + [1.0] * 1, - [1.0] * 50, - [0.0] * 50, - ] - ), + mx.nd.array([ + [1.0] * 50, + [0.0] * 25 + [3.0] * 25, + [2.0] * 49 + [1.5] * 1, + [0.0] * 50, + [1.0] * 50, + ]), + mx.nd.array([ + [1.0] * 50, + [0.0] * 25 + [1.0] * 25, + [0.0] * 49 + [1.0] * 1, + [1.0] * 50, + [0.0] * 50, + ]), mx.nd.array([1.0, 3.0, 1.5, 1.00396824, 1.00396824]), ), ( scaler.MeanScaler(default_scale=0.5), - mx.nd.array( - [ - [1.0] * 50, - [0.0] * 25 + [3.0] * 25, - [2.0] * 49 + [1.5] * 1, - [0.0] * 50, - [1.0] * 50, - ] - ), - mx.nd.array( - [ - [0.0] * 50, - [0.0] * 25 + [1.0] * 25, - [0.0] * 49 + [1.0] * 1, - [1.0] * 50, - [0.0] * 50, - ] - ), + mx.nd.array([ + [1.0] * 50, + [0.0] * 25 + [3.0] * 25, + [2.0] * 49 + [1.5] * 1, + [0.0] * 50, + [1.0] * 50, + ]), + mx.nd.array([ + [0.0] * 50, + [0.0] * 25 + [1.0] * 25, + [0.0] * 49 + [1.0] * 1, + [1.0] * 50, + [0.0] * 50, + ]), mx.nd.array([0.5, 3.0, 1.5, 0.5, 0.5]), ), ( scaler.MeanScaler(keepdims=True), - mx.nd.array( - [ - [1.0] * 50, - [0.0] * 25 + [3.0] * 25, - [2.0] * 49 + [1.5] * 1, - [0.0] * 50, - [1.0] * 50, - ] - ), - mx.nd.array( - [ - [1.0] * 50, - [0.0] * 25 + [1.0] * 25, - [0.0] * 49 + [1.0] * 1, - [1.0] * 50, - [0.0] * 50, - ] - ), + mx.nd.array([ + [1.0] * 50, + [0.0] * 25 + [3.0] * 25, + [2.0] * 49 + [1.5] * 1, + [0.0] * 50, + [1.0] * 50, + ]), + mx.nd.array([ + [1.0] * 50, + [0.0] * 25 + [1.0] * 25, + [0.0] * 49 + [1.0] * 1, + [1.0] * 50, + [0.0] * 50, + ]), mx.nd.array([1.0, 3.0, 1.5, 1.00396824, 1.00396824]).expand_dims( axis=1 ), ), ( scaler.MeanScaler(), - mx.nd.array( - [ - [[1.0]] * 50, - [[0.0]] * 25 + [[3.0]] * 25, - [[2.0]] * 49 + [[1.5]] * 1, - [[0.0]] * 50, - [[1.0]] * 50, - ] - ), - mx.nd.array( - [ - [[1.0]] * 50, - [[0.0]] * 25 + [[1.0]] * 25, - [[0.0]] * 49 + [[1.0]] * 1, - [[1.0]] * 50, - [[0.0]] * 50, - ] - ), + mx.nd.array([ + [[1.0]] * 50, + [[0.0]] * 25 + [[3.0]] * 25, + [[2.0]] * 49 + [[1.5]] * 1, + [[0.0]] * 50, + [[1.0]] * 50, + ]), + mx.nd.array([ + [[1.0]] * 50, + [[0.0]] * 25 + [[1.0]] * 25, + [[0.0]] * 49 + [[1.0]] * 1, + [[1.0]] * 50, + [[0.0]] * 50, + ]), mx.nd.array([1.0, 3.0, 1.5, 1.00396824, 1.00396824]).expand_dims( axis=1 ), ), ( scaler.MeanScaler(minimum_scale=1e-8), - mx.nd.array( - [ - [[1.0, 2.0]] * 50, - [[0.0, 0.0]] * 25 + [[3.0, 6.0]] * 25, - [[2.0, 4.0]] * 49 + [[1.5, 3.0]] * 1, - [[0.0, 0.0]] * 50, - [[1.0, 2.0]] * 50, - ] - ), - mx.nd.array( - [ - [[1.0, 1.0]] * 50, - [[0.0, 1.0]] * 25 + [[1.0, 0.0]] * 25, - [[1.0, 0.0]] * 49 + [[0.0, 1.0]] * 1, - [[1.0, 0.0]] * 50, - [[0.0, 1.0]] * 50, - ] - ), - mx.nd.array( - [ - [1.0, 2.0], - [3.0, 1.61111116], - [2.0, 3.0], - [1.28160918, 1.61111116], - [1.28160918, 2.0], - ] - ), + mx.nd.array([ + [[1.0, 2.0]] * 50, + [[0.0, 0.0]] * 25 + [[3.0, 6.0]] * 25, + [[2.0, 4.0]] * 49 + [[1.5, 3.0]] * 1, + [[0.0, 0.0]] * 50, + [[1.0, 2.0]] * 50, + ]), + mx.nd.array([ + [[1.0, 1.0]] * 50, + [[0.0, 1.0]] * 25 + [[1.0, 0.0]] * 25, + [[1.0, 0.0]] * 49 + [[0.0, 1.0]] * 1, + [[1.0, 0.0]] * 50, + [[0.0, 1.0]] * 50, + ]), + mx.nd.array([ + [1.0, 2.0], + [3.0, 1.61111116], + [2.0, 3.0], + [1.28160918, 1.61111116], + [1.28160918, 2.0], + ]), ), ( scaler.MeanScaler(), - mx.nd.array( - [ - [120.0] * 25 + [150.0] * 25, - [0.0] * 10 + [3.0] * 20 + [61.0] * 20, - [0.0] * 50, - [2e-2] * 10 + [0.0] * 30 + [3e-2] * 10, - ] - ), - mx.nd.array( - [ - [1.0] * 25 + [1.0] * 25, - [0.0] * 10 + [1.0] * 20 + [1.0] * 20, - [0.0] * 50, - [1.0] * 10 + [0.0] * 30 + [1.0] * 10, - ] - ), + mx.nd.array([ + [120.0] * 25 + [150.0] * 25, + [0.0] * 10 + [3.0] * 20 + [61.0] * 20, + [0.0] * 50, + [2e-2] * 10 + [0.0] * 30 + [3e-2] * 10, + ]), + mx.nd.array([ + [1.0] * 25 + [1.0] * 25, + [0.0] * 10 + [1.0] * 20 + [1.0] * 20, + [0.0] * 50, + [1.0] * 10 + [0.0] * 30 + [1.0] * 10, + ]), mx.nd.array([135.0, 32.0, 73.00454712, 2.5e-2]), ), ( @@ -207,129 +181,93 @@ test_minmax = [ ( scaler.MinMax(), - mx.nd.array( - [ - [1.0, 2.0, 3.0], - [4.0, 5.0, 6.0], - ] - ), - mx.nd.array( - [ - [1.0, 1.0, 1.0], - [1.0, 1.0, 1.0], - ] - ), - mx.nd.array( - [ - [0.0, 0.5, 1.0], - [0.0, 0.5, 1.0], - ] - ), + mx.nd.array([ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + ]), + mx.nd.array([ + [1.0, 1.0, 1.0], + [1.0, 1.0, 1.0], + ]), + mx.nd.array([ + [0.0, 0.5, 1.0], + [0.0, 0.5, 1.0], + ]), ), ( scaler.MinMax(), - mx.nd.array( - [ - [1.0, 2.0, 3.0], - [4.0, 5.0, 6.0], - ] - ), - mx.nd.array( - [ - [0.0, 1.0, 1.0], - [1.0, 1.0, 0.0], - ] - ), - mx.nd.array( - [ - [0.0, 0, 1.0], - [0.0, 1.0, 0.0], - ] - ), + mx.nd.array([ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + ]), + mx.nd.array([ + [0.0, 1.0, 1.0], + [1.0, 1.0, 0.0], + ]), + mx.nd.array([ + [0.0, 0, 1.0], + [0.0, 1.0, 0.0], + ]), ), ( scaler.MinMax(), - mx.nd.array( - [ - [9.0, 9.0, 9.0], - [4.0, 5.0, 6.0], - ] - ), - mx.nd.array( - [ - [1.0, 1.0, 1.0], - [1.0, 1.0, 1.0], - ] - ), - mx.nd.array( - [ - [1.0, 1.0, 1.0], - [0.0, 0.5, 1.0], - ] - ), + mx.nd.array([ + [9.0, 9.0, 9.0], + [4.0, 5.0, 6.0], + ]), + mx.nd.array([ + [1.0, 1.0, 1.0], + [1.0, 1.0, 1.0], + ]), + mx.nd.array([ + [1.0, 1.0, 1.0], + [0.0, 0.5, 1.0], + ]), ), ( scaler.MinMax(), - mx.nd.array( - [ - [9.0, 9.0, 9.0], - [4.0, 5.0, 6.0], - ] - ), - mx.nd.array( - [ - [0.0, 1.0, 1.0], - [1.0, 1.0, 1.0], - ] - ), - mx.nd.array( - [ - [0.0, 1.0, 1.0], - [0.0, 0.5, 1.0], - ] - ), + mx.nd.array([ + [9.0, 9.0, 9.0], + [4.0, 5.0, 6.0], + ]), + mx.nd.array([ + [0.0, 1.0, 1.0], + [1.0, 1.0, 1.0], + ]), + mx.nd.array([ + [0.0, 1.0, 1.0], + [0.0, 0.5, 1.0], + ]), ), ( scaler.MinMax(), - mx.nd.array( - [ - [0.0, 0.0, 0.0], - [4.0, 5.0, 6.0], - ] - ), - mx.nd.array( - [ - [1.0, 1.0, 1.0], - [1.0, 1.0, 1.0], - ] - ), - mx.nd.array( - [ - [0.0, 0.0, 0.0], - [0.0, 0.5, 1.0], - ] - ), + mx.nd.array([ + [0.0, 0.0, 0.0], + [4.0, 5.0, 6.0], + ]), + mx.nd.array([ + [1.0, 1.0, 1.0], + [1.0, 1.0, 1.0], + ]), + mx.nd.array([ + [0.0, 0.0, 0.0], + [0.0, 0.5, 1.0], + ]), ), ( scaler.MinMax(axis=0), - mx.nd.array( - [ - [1.0, 2.0, 3.0], - [4.0, 5.0, 6.0], - ] - ), - mx.nd.array( - [ - [1.0, 1.0, 1.0], - [1.0, 1.0, 1.0], - ] - ), - mx.nd.array( - [ - [0.0, 0.0, 0.0], - [1.0, 1.0, 1.0], - ] - ), + mx.nd.array([ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + ]), + mx.nd.array([ + [1.0, 1.0, 1.0], + [1.0, 1.0, 1.0], + ]), + mx.nd.array([ + [0.0, 0.0, 0.0], + [1.0, 1.0, 1.0], + ]), ), ] diff --git a/test/mx/distribution/test_distribution_methods.py b/test/mx/distribution/test_distribution_methods.py index 9266fd0428..9ada4f7bd6 100644 --- a/test/mx/distribution/test_distribution_methods.py +++ b/test/mx/distribution/test_distribution_methods.py @@ -147,12 +147,10 @@ ( EmpiricalDistribution, { - "samples": mx.nd.stack( - *[ - mx.nd.arange(start=0, stop=20, step=2), - mx.nd.arange(start=100, stop=0, step=-10), - ] - ).transpose(), + "samples": mx.nd.stack(*[ + mx.nd.arange(start=0, stop=20, step=2), + mx.nd.arange(start=100, stop=0, step=-10), + ]).transpose(), "event_dim": 1, }, ), @@ -254,12 +252,10 @@ ( EmpiricalDistribution, { - "samples": mx.nd.stack( - *[ - mx.nd.arange(start=0, stop=20, step=2), - mx.nd.arange(start=100, stop=0, step=-10), - ] - ).transpose(), + "samples": mx.nd.stack(*[ + mx.nd.arange(start=0, stop=20, step=2), + mx.nd.arange(start=100, stop=0, step=-10), + ]).transpose(), "event_dim": 1, }, ), diff --git a/test/mx/distribution/test_distribution_output_shapes.py b/test/mx/distribution/test_distribution_output_shapes.py index 7f8ffb857f..c55783397c 100644 --- a/test/mx/distribution/test_distribution_output_shapes.py +++ b/test/mx/distribution/test_distribution_output_shapes.py @@ -157,12 +157,10 @@ TEST_CASES_WITHOUT_VARIANCE = [ ( - MixtureDistributionOutput( - [ - MultivariateGaussianOutput(dim=5), - MultivariateGaussianOutput(dim=5), - ] - ), + MixtureDistributionOutput([ + MultivariateGaussianOutput(dim=5), + MultivariateGaussianOutput(dim=5), + ]), mx.nd.random.normal(shape=(3, 4, 10)), [None, mx.nd.ones(shape=(3, 4, 5))], [None], diff --git a/test/mx/distribution/test_distribution_sampling.py b/test/mx/distribution/test_distribution_sampling.py index 80b61882ce..e827221f7f 100644 --- a/test/mx/distribution/test_distribution_sampling.py +++ b/test/mx/distribution/test_distribution_sampling.py @@ -233,20 +233,18 @@ def test_multivariate_sampling(distr, params, dim, serialize_fn) -> None: ) -test_cases_pwl_sqf = [ - ( - PiecewiseLinear, - { - "gamma": mx.nd.array([2]).repeat(axis=0, repeats=2), - "slopes": mx.nd.array([[3, 1, 3, 0.2, 5, 4]]).repeat( - axis=0, repeats=2 - ), - "knot_spacings": mx.nd.array( - [[0.3, 0.2, 0.2, 0.15, 0.1, 0.05]] - ).repeat(axis=0, repeats=2), - }, - ) -] +test_cases_pwl_sqf = [( + PiecewiseLinear, + { + "gamma": mx.nd.array([2]).repeat(axis=0, repeats=2), + "slopes": mx.nd.array([[3, 1, 3, 0.2, 5, 4]]).repeat( + axis=0, repeats=2 + ), + "knot_spacings": mx.nd.array( + [[0.3, 0.2, 0.2, 0.15, 0.1, 0.05]] + ).repeat(axis=0, repeats=2), + }, +)] @pytest.mark.parametrize("distr, params", test_cases_pwl_sqf) diff --git a/test/mx/distribution/test_nan_mixture.py b/test/mx/distribution/test_nan_mixture.py index e806d82b7b..b3bdec423c 100644 --- a/test/mx/distribution/test_nan_mixture.py +++ b/test/mx/distribution/test_nan_mixture.py @@ -63,12 +63,10 @@ def diff(x: np.ndarray, y: np.ndarray) -> np.ndarray: sigma_grad_true[p == 1] = 0 params_gauss_grad = {"mu": mu_grad_true, "sigma": sigma_grad_true} -p_cat = np.array( - [ - [[[0.1, 0.9], [0.9, 0.1], [0.5, 0.5]]], - [[[0.9, 0.1], [0.05, 0.95], [0.45, 0.55]]], - ] -) +p_cat = np.array([ + [[[0.1, 0.9], [0.9, 0.1], [0.5, 0.5]]], + [[[0.9, 0.1], [0.05, 0.95], [0.45, 0.55]]], +]) params_cat = {"log_probs": mx.nd.array(np.log(p_cat))} x_cat = np.array([[[np.nan, 0, 1]], [[np.nan, 0, np.nan]]]) diff --git a/test/mx/distribution/test_piecewise_linear.py b/test/mx/distribution/test_piecewise_linear.py index ea2a44c1ca..406ffb5a77 100644 --- a/test/mx/distribution/test_piecewise_linear.py +++ b/test/mx/distribution/test_piecewise_linear.py @@ -67,12 +67,12 @@ def test_values( ): distr = serialize_fn(distr) target = mx.nd.array(target).reshape(shape=(len(target),)) - expected_target_cdf = np.array(expected_target_cdf).reshape( - (len(expected_target_cdf),) - ) - expected_target_crps = np.array(expected_target_crps).reshape( - (len(expected_target_crps),) - ) + expected_target_cdf = np.array(expected_target_cdf).reshape(( + len(expected_target_cdf), + )) + expected_target_crps = np.array(expected_target_crps).reshape(( + len(expected_target_crps), + )) assert all(np.isclose(distr.cdf(target).asnumpy(), expected_target_cdf)) assert all(np.isclose(distr.crps(target).asnumpy(), expected_target_crps)) diff --git a/test/mx/kernels/test_periodic_kernel.py b/test/mx/kernels/test_periodic_kernel.py index 7424f927ed..329eb8230e 100644 --- a/test/mx/kernels/test_periodic_kernel.py +++ b/test/mx/kernels/test_periodic_kernel.py @@ -56,13 +56,11 @@ nd.array([[0, 1, 3], [2, -1, 1], [1, 0, -1], [-1, -2, 3]]), nd.array([3, 2.1, 4.2]), nd.array([1.3, 2.5, 3.2]), - nd.array( - [ - [[14, 2, 2, 14], [57, 41, 19, 83]], - [[40, 56, 24, 72], [84, 116, 172, 12]], - [[22, 42, 26, 38], [217, 249, 299, 155]], - ] - ), + nd.array([ + [[14, 2, 2, 14], [57, 41, 19, 83]], + [[40, 56, 24, 72], [84, 116, 172, 12]], + [[22, 42, 26, 38], [217, 249, 299, 155]], + ]), ), ] diff --git a/test/mx/kernels/test_rbf_kernel.py b/test/mx/kernels/test_rbf_kernel.py index e60a905e20..3e13d79408 100644 --- a/test/mx/kernels/test_rbf_kernel.py +++ b/test/mx/kernels/test_rbf_kernel.py @@ -55,13 +55,11 @@ nd.array([[0, 1, 3], [2, -1, 1], [1, 0, -1], [-1, -2, 3]]), nd.array([3, 2.1, 4.2]), nd.array([1.3, 2.5, 3.2]), - nd.array( - [ - [[14, 2, 2, 14], [57, 41, 19, 83]], - [[40, 56, 24, 72], [84, 116, 172, 12]], - [[22, 42, 26, 38], [217, 249, 299, 155]], - ] - ), + nd.array([ + [[14, 2, 2, 14], [57, 41, 19, 83]], + [[40, 56, 24, 72], [84, 116, 172, 12]], + [[22, 42, 26, 38], [217, 249, 299, 155]], + ]), ), ] diff --git a/test/mx/model/deepvar_hierarchical/generate_hierarchical_dataset.py b/test/mx/model/deepvar_hierarchical/generate_hierarchical_dataset.py index 8e1737abb2..6e66978d60 100644 --- a/test/mx/model/deepvar_hierarchical/generate_hierarchical_dataset.py +++ b/test/mx/model/deepvar_hierarchical/generate_hierarchical_dataset.py @@ -73,13 +73,11 @@ def sine7(seq_length: int, prediction_length: int): ) train_dataset = ListDataset( - [ - { - "start": index[0], - "item_id": "all_items", - "target": Y[:, :-prediction_length], - } - ], + [{ + "start": index[0], + "item_id": "all_items", + "target": Y[:, :-prediction_length], + }], freq=index.freqstr, one_dim_target=False, ) diff --git a/test/mx/model/deepvar_hierarchical/test_train_prediction_with_hts.py b/test/mx/model/deepvar_hierarchical/test_train_prediction_with_hts.py index 22fc735217..ddec3e01ea 100644 --- a/test/mx/model/deepvar_hierarchical/test_train_prediction_with_hts.py +++ b/test/mx/model/deepvar_hierarchical/test_train_prediction_with_hts.py @@ -89,9 +89,7 @@ def test_train_prediction(features_df: Optional[pd.DataFrame]): forecasts = list(predictor.predict(predictor_input)) assert len(forecasts) == len(dataset) - assert all( - [ - forecast.samples.shape == (100, PREDICTION_LENGTH, hts.num_ts) - for forecast in forecasts - ] - ) + assert all([ + forecast.samples.shape == (100, PREDICTION_LENGTH, hts.num_ts) + for forecast in forecasts + ]) diff --git a/test/mx/model/gp_forecaster/data.py b/test/mx/model/gp_forecaster/data.py index 3b78ce1ae7..3f1450fe8b 100644 --- a/test/mx/model/gp_forecaster/data.py +++ b/test/mx/model/gp_forecaster/data.py @@ -15,5481 +15,5471 @@ def load_gp_params(): - return nd.array( - [ - [412263.3050, 8.0703, 57.3620], - [387.5274, 5.0673, 4.2793], - [41625.4972, 6.7450, 9.5796], - [2639.2794, 3.6458, 5.1566], - [4423.1468, 6.0896, 5.1452], - [19065.0601, 4.2969, 13.6918], - [28449.8824, 4.7820, 10.6377], - [300837.0179, 6.8132, 41.2674], - [1374.7983, 4.3378, 4.0292], - [5304.0383, 6.1224, 1.1865], - ] - ).expand_dims(axis=2) + return nd.array([ + [412263.3050, 8.0703, 57.3620], + [387.5274, 5.0673, 4.2793], + [41625.4972, 6.7450, 9.5796], + [2639.2794, 3.6458, 5.1566], + [4423.1468, 6.0896, 5.1452], + [19065.0601, 4.2969, 13.6918], + [28449.8824, 4.7820, 10.6377], + [300837.0179, 6.8132, 41.2674], + [1374.7983, 4.3378, 4.0292], + [5304.0383, 6.1224, 1.1865], + ]).expand_dims(axis=2) def load_exact_mean(): - return nd.array( + return nd.array([ + [ + 329.91, + 318.8, + 326.58, + 352.84, + 395.77, + 452.58, + 519.87, + 594.14, + 672.03, + 750.42, + 826.4, + 897.05, + 959.22, + 1009.3, + 1043.4, + 1057.1, + 1046.3, + 1007.4, + 938.3, + 838.59, + 710.16, + 557.29, + 386.44, + 370.18, + 341.38, + 330.33, + 338.16, + 364.32, + 406.87, + 462.85, + 528.8, + 601.17, + 676.64, + 752.21, + 825.11, + 892.63, + 951.82, + 999.28, + 1031.1, + 1043.2, + 1031.2, + 991.72, + 922.39, + 822.75, + 694.55, + 541.99, + 371.45, + 376.02, + 347.33, + 336.44, + 344.41, + 370.61, + 412.98, + 468.49, + 533.61, + 604.78, + 678.67, + 752.35, + 823.15, + 888.47, + 945.49, + 990.94, + 1021.1, + 1031.7, + 1018.8, + 978.69, + 909.16, + 809.66, + 681.91, + 529.99, + 360.27, + 377.02, + 348.76, + 338.32, + 346.72, + 373.3, + 415.97, + 471.67, + 536.82, + 607.84, + 681.37, + 754.49, + 824.52, + 888.89, + 944.83, + 989.13, + 1018.1, + 1027.7, + 1013.8, + 973.03, + 903.11, + 803.53, + 676.04, + 524.73, + 355.89, + 374.28, + 346.91, + 337.34, + 346.63, + 374.14, + 417.81, + 474.57, + 540.85, + 613.01, + 687.65, + 761.74, + 832.55, + 897.41, + 953.51, + 997.61, + 1026, + 1034.8, + 1020.1, + 978.36, + 907.64, + 807.53, + 679.87, + 528.79, + 360.61, + 368.96, + 343.04, + 334.91, + 345.67, + 374.79, + 420.26, + 479.07, + 547.66, + 622.33, + 699.59, + 776.26, + 849.41, + 916.22, + 973.7, + 1018.5, + 1047.1, + 1055.3, + 1039.6, + 996.64, + 924.63, + 823.41, + 694.98, + 543.61, + 375.65, + 362.13, + 338.31, + 332.23, + 345.12, + 376.54, + 424.62, + 486.44, + 558.45, + 636.92, + 718.19, + 798.86, + 875.77, + 945.81, + 1005.8, + 1052.2, + 1081.3, + 1089.2, + 1072.4, + 1027.9, + 954.1, + 851.17, + 721.38, + 569.17, + 400.96, + 358.65, + ], + [ + 26.239, + 26.612, + 26.31, + 25.441, + 24.318, + 23.394, + 23.112, + 23.72, + 25.148, + 27.001, + 28.699, + 29.716, + 29.81, + 29.155, + 28.283, + 27.875, + 28.464, + 30.184, + 32.67, + 35.15, + 36.697, + 36.53, + 34.26, + 26.079, + 26.356, + 26.123, + 25.377, + 24.269, + 23.133, + 22.412, + 22.513, + 23.625, + 25.603, + 27.972, + 30.08, + 31.355, + 31.549, + 30.869, + 29.916, + 29.444, + 30.048, + 31.882, + 34.561, + 37.256, + 38.964, + 38.84, + 36.463, + 27.34, + 27.022, + 26.272, + 25.145, + 23.828, + 22.661, + 22.071, + 22.427, + 23.876, + 26.22, + 28.935, + 31.326, + 32.793, + 33.079, + 32.405, + 31.404, + 30.873, + 31.438, + 33.276, + 35.992, + 38.735, + 40.468, + 40.314, + 37.834, + 28.936, + 28.179, + 27.044, + 25.629, + 24.14, + 22.91, + 22.342, + 22.777, + 24.33, + 26.785, + 29.604, + 32.084, + 33.617, + 33.942, + 33.274, + 32.242, + 31.649, + 32.13, + 33.871, + 36.492, + 39.148, + 40.809, + 40.599, + 38.075, + 30.565, + 29.545, + 28.174, + 26.574, + 24.953, + 23.624, + 22.961, + 23.278, + 24.681, + 26.963, + 29.608, + 31.943, + 33.377, + 33.648, + 32.951, + 31.886, + 31.224, + 31.584, + 33.16, + 35.596, + 38.088, + 39.641, + 39.41, + 36.961, + 31.817, + 30.693, + 29.225, + 27.539, + 25.826, + 24.371, + 23.51, + 23.538, + 24.562, + 26.408, + 28.617, + 30.572, + 31.727, + 31.828, + 31.044, + 29.925, + 29.185, + 29.41, + 30.789, + 33.006, + 35.314, + 36.788, + 36.628, + 34.415, + 32.283, + 31.18, + 29.725, + 28.034, + 26.271, + 24.68, + 23.561, + 23.184, + 23.667, + 24.883, + 26.449, + 27.832, + 28.549, + 28.369, + 27.435, + 26.235, + 25.413, + 25.505, + 26.693, + 28.702, + 30.862, + 32.327, + 32.358, + 30.553, + 25.266, + ], + [ + 128.95, + 122.39, + 120.64, + 123.55, + 130.36, + 139.91, + 150.88, + 162.03, + 172.42, + 181.48, + 189.02, + 195.16, + 200.19, + 204.44, + 208.13, + 211.28, + 213.65, + 214.75, + 213.92, + 210.41, + 203.54, + 192.86, + 178.23, + 163.31, + 155.31, + 151.08, + 150.99, + 154.77, + 161.58, + 170.28, + 179.68, + 188.74, + 196.76, + 203.4, + 208.68, + 212.85, + 216.27, + 219.23, + 221.88, + 224.12, + 225.6, + 225.74, + 223.82, + 219.07, + 210.87, + 198.8, + 182.79, + 174.14, + 167.52, + 164.42, + 165.1, + 169.16, + 175.75, + 183.73, + 191.99, + 199.65, + 206.14, + 211.3, + 215.28, + 218.42, + 221.08, + 223.54, + 225.85, + 227.84, + 229.03, + 228.78, + 226.33, + 220.92, + 211.96, + 199.07, + 182.29, + 173.98, + 167.47, + 164.48, + 165.21, + 169.22, + 175.62, + 183.29, + 191.13, + 198.31, + 204.34, + 209.12, + 212.84, + 215.87, + 218.56, + 221.15, + 223.64, + 225.77, + 227.02, + 226.7, + 224.03, + 218.3, + 208.92, + 195.62, + 178.46, + 166, + 158.67, + 155.08, + 155.45, + 159.35, + 165.87, + 173.86, + 182.19, + 189.98, + 196.7, + 202.2, + 206.66, + 210.38, + 213.7, + 216.82, + 219.71, + 222.07, + 223.4, + 223, + 220.13, + 214.1, + 204.4, + 190.8, + 173.44, + 154.11, + 145.46, + 140.96, + 140.92, + 144.95, + 152.12, + 161.23, + 171.03, + 180.51, + 189, + 196.23, + 202.23, + 207.24, + 211.54, + 215.29, + 218.49, + 220.88, + 221.98, + 221.18, + 217.81, + 211.23, + 201.01, + 186.99, + 169.34, + 142.18, + 132.2, + 126.9, + 126.73, + 131.35, + 139.83, + 150.86, + 163.06, + 175.18, + 186.35, + 196.06, + 204.16, + 210.78, + 216.11, + 220.35, + 223.51, + 225.44, + 225.78, + 224.02, + 219.59, + 211.98, + 200.83, + 186.04, + 167.83, + 139.73, + ], + [ + 70.415, + 63.185, + 60.465, + 60.591, + 60.143, + 57.038, + 52.504, + 50.145, + 52.795, + 59.874, + 67.577, + 71.705, + 70.705, + 66.507, + 62.829, + 62.672, + 66.912, + 74.551, + 83.744, + 92.514, + 98.891, + 100.95, + 97.216, + 79.011, + 69.417, + 63.16, + 61.764, + 63.012, + 62.935, + 59.324, + 53.838, + 50.763, + 53.311, + 60.684, + 68.47, + 72.03, + 69.915, + 64.583, + 60.254, + 60.045, + 64.612, + 72.707, + 82.445, + 91.957, + 99.291, + 102.29, + 99.127, + 78.038, + 68.621, + 62.845, + 62.133, + 63.938, + 63.956, + 59.934, + 53.875, + 50.547, + 53.377, + 61.309, + 69.423, + 72.78, + 70.092, + 64.277, + 59.88, + 59.941, + 64.754, + 72.791, + 82.217, + 91.458, + 98.799, + 102.05, + 99.099, + 79.084, + 69.517, + 63.545, + 62.667, + 64.294, + 64.063, + 59.786, + 53.71, + 50.862, + 54.69, + 63.816, + 72.863, + 76.62, + 73.907, + 67.976, + 63.596, + 63.71, + 68.311, + 75.671, + 84.059, + 92.247, + 98.83, + 101.68, + 98.532, + 82.217, + 72.301, + 65.592, + 63.804, + 64.591, + 63.807, + 59.46, + 53.943, + 52.302, + 57.788, + 68.623, + 78.974, + 83.346, + 80.57, + 74.14, + 69.092, + 68.491, + 72.306, + 78.769, + 86.214, + 93.562, + 99.537, + 102.01, + 98.59, + 85.922, + 75.755, + 68.142, + 65.06, + 64.61, + 63.112, + 58.883, + 54.354, + 54.334, + 61.656, + 74.067, + 85.316, + 89.653, + 85.894, + 77.776, + 70.782, + 68.475, + 71.193, + 77.305, + 85.039, + 93.058, + 99.781, + 102.83, + 99.672, + 87.851, + 77.836, + 69.584, + 65.269, + 63.539, + 61.355, + 57.414, + 54.058, + 55.603, + 64.259, + 77.282, + 88.15, + 91.02, + 84.779, + 73.49, + 63.277, + 58.54, + 60.333, + 67.279, + 77.276, + 88.242, + 97.786, + 102.97, + 101.07, + 80.395, + ], + [ + 41.641, + 40.194, + 40.186, + 42.23, + 46.594, + 53.121, + 61.229, + 70.012, + 78.439, + 85.592, + 90.892, + 94.251, + 96.067, + 97.076, + 98.062, + 99.52, + 101.37, + 102.81, + 102.42, + 98.456, + 89.282, + 73.864, + 52.127, + 46.345, + 44.621, + 43.544, + 43.912, + 46.325, + 51.036, + 57.864, + 66.199, + 75.106, + 83.524, + 90.514, + 95.492, + 98.379, + 99.607, + 99.965, + 100.3, + 101.19, + 102.58, + 103.72, + 103.16, + 99.119, + 89.886, + 74.352, + 52.377, + 47.029, + 45.418, + 44.464, + 44.971, + 47.544, + 52.44, + 59.48, + 68.047, + 77.196, + 85.846, + 93.031, + 98.137, + 101.06, + 102.23, + 102.43, + 102.55, + 103.18, + 104.36, + 105.36, + 104.77, + 100.79, + 91.703, + 76.343, + 54.515, + 46.417, + 44.725, + 43.69, + 44.133, + 46.678, + 51.605, + 58.757, + 67.538, + 77.004, + 86.064, + 93.719, + 99.307, + 102.67, + 104.17, + 104.59, + 104.77, + 105.37, + 106.44, + 107.34, + 106.73, + 102.87, + 94.051, + 79.104, + 57.779, + 45.332, + 43.46, + 42.227, + 42.471, + 44.847, + 49.673, + 56.832, + 65.762, + 75.54, + 85.07, + 93.32, + 99.567, + 103.57, + 105.62, + 106.39, + 106.74, + 107.3, + 108.21, + 108.9, + 108.17, + 104.36, + 95.857, + 81.497, + 60.997, + 44.582, + 42.54, + 41.092, + 41.092, + 43.223, + 47.848, + 54.899, + 63.862, + 73.848, + 83.763, + 92.548, + 99.412, + 104.02, + 106.56, + 107.63, + 108.02, + 108.36, + 108.88, + 109.12, + 108.03, + 104.13, + 95.9, + 82.242, + 62.844, + 44.76, + 42.661, + 41.081, + 40.875, + 42.75, + 47.106, + 53.93, + 62.761, + 72.751, + 82.824, + 91.901, + 99.134, + 104.1, + 106.86, + 107.94, + 108.05, + 107.85, + 107.62, + 107.08, + 105.35, + 101.14, + 93.104, + 80.215, + 62.151, + 43.724, + ], + [ + 118.98, + 118.77, + 114.81, + 110.94, + 111.76, + 120.81, + 139.03, + 164.13, + 191.14, + 214.26, + 229.24, + 235.31, + 235.48, + 234.91, + 237.59, + 243.43, + 247.56, + 242.8, + 224.16, + 192.83, + 156.88, + 128.03, + 116.11, + 114.56, + 120.47, + 121.48, + 119.8, + 119.4, + 124.67, + 138.61, + 161.35, + 189.66, + 217.79, + 239.51, + 250.63, + 250.97, + 244.63, + 238.01, + 236.21, + 239.67, + 243.29, + 239.03, + 220.99, + 189.83, + 153.87, + 125.62, + 115.76, + 113.73, + 119.34, + 120.43, + 119.35, + 120.07, + 126.81, + 142.28, + 166.27, + 195.25, + 223.31, + 244.16, + 253.68, + 251.93, + 243.39, + 234.99, + 232.27, + 235.87, + 240.45, + 237.39, + 220.08, + 188.86, + 152.21, + 123.36, + 113.8, + 112.5, + 117.39, + 117.83, + 116.14, + 116.19, + 122.09, + 136.48, + 159.24, + 187.11, + 214.47, + 235.32, + 245.61, + 245.23, + 238.46, + 231.97, + 231.22, + 236.76, + 243.08, + 241.27, + 224.36, + 192.43, + 154.12, + 123.17, + 111.77, + 111.84, + 116.07, + 115.81, + 113.28, + 112.16, + 116.42, + 128.77, + 149.41, + 175.62, + 202.44, + 224.24, + 237.01, + 240.27, + 237.61, + 235.01, + 237.45, + 245.27, + 252.98, + 251.67, + 234.34, + 200.95, + 160.12, + 125.78, + 110.61, + 111.8, + 115.85, + 115.53, + 112.91, + 111.46, + 114.95, + 126.08, + 145.37, + 170.67, + 197.62, + 221.06, + 236.98, + 244.41, + 246.13, + 247.26, + 252.18, + 261.08, + 268.7, + 266.51, + 247.75, + 212.36, + 168.66, + 130.37, + 110.27, + 111.49, + 115.98, + 116.58, + 115.25, + 115.16, + 119.79, + 131.68, + 151.41, + 177.19, + 205.1, + 230.34, + 248.92, + 259.5, + 264.14, + 267.18, + 272.48, + 280.25, + 285.71, + 281.08, + 260.09, + 222.65, + 176.54, + 134.76, + 109.72, + 113.52, + ], + [ + 193.76, + 175.38, + 162.27, + 160.36, + 172.25, + 195.97, + 225.74, + 254.47, + 276.83, + 291.1, + 299.03, + 303.88, + 308.06, + 311.91, + 314.25, + 314.13, + 312.28, + 310.93, + 311.97, + 314.65, + 314.29, + 303.4, + 274.7, + 224.26, + 207.31, + 188.71, + 175.68, + 174.45, + 187.74, + 213.48, + 245.57, + 276.56, + 300.69, + 315.98, + 324.06, + 328.17, + 330.76, + 332.29, + 331.78, + 328.62, + 323.95, + 320.4, + 320.06, + 321.98, + 320.97, + 308.8, + 277.58, + 228.89, + 211.58, + 193.09, + 180.66, + 180.38, + 194.76, + 221.52, + 254.41, + 285.92, + 310.33, + 325.71, + 333.69, + 337.44, + 339.35, + 339.79, + 337.84, + 333.07, + 326.89, + 322.18, + 321.12, + 322.62, + 321.14, + 308.02, + 275.04, + 226.62, + 208.9, + 190.91, + 179.59, + 180.54, + 195.75, + 222.62, + 254.92, + 285.45, + 308.91, + 323.71, + 331.58, + 335.58, + 337.83, + 338.47, + 336.46, + 331.43, + 324.91, + 319.91, + 318.63, + 319.91, + 318.05, + 304.27, + 270.31, + 220.88, + 202.94, + 185.9, + 176.17, + 178.56, + 194.36, + 220.62, + 251.31, + 279.83, + 301.61, + 315.58, + 323.64, + 328.59, + 332.13, + 334.02, + 333.01, + 328.67, + 322.59, + 317.79, + 316.49, + 317.49, + 315.17, + 300.91, + 266.66, + 215.01, + 197.33, + 181.72, + 173.93, + 177.86, + 194.01, + 219.16, + 247.65, + 273.68, + 293.62, + 306.99, + 315.73, + 322.31, + 327.86, + 331.65, + 332.18, + 328.96, + 323.58, + 319.06, + 317.57, + 318.01, + 314.99, + 300.29, + 266.31, + 211.08, + 194.33, + 180.62, + 174.97, + 180.35, + 196.57, + 220.29, + 246.36, + 269.94, + 288.33, + 301.58, + 311.6, + 320.31, + 328.22, + 334.09, + 336.21, + 334.05, + 329.23, + 324.7, + 322.64, + 322.07, + 318.01, + 302.85, + 269.56, + 210.47, + ], [ - [ - 329.91, - 318.8, - 326.58, - 352.84, - 395.77, - 452.58, - 519.87, - 594.14, - 672.03, - 750.42, - 826.4, - 897.05, - 959.22, - 1009.3, - 1043.4, - 1057.1, - 1046.3, - 1007.4, - 938.3, - 838.59, - 710.16, - 557.29, - 386.44, - 370.18, - 341.38, - 330.33, - 338.16, - 364.32, - 406.87, - 462.85, - 528.8, - 601.17, - 676.64, - 752.21, - 825.11, - 892.63, - 951.82, - 999.28, - 1031.1, - 1043.2, - 1031.2, - 991.72, - 922.39, - 822.75, - 694.55, - 541.99, - 371.45, - 376.02, - 347.33, - 336.44, - 344.41, - 370.61, - 412.98, - 468.49, - 533.61, - 604.78, - 678.67, - 752.35, - 823.15, - 888.47, - 945.49, - 990.94, - 1021.1, - 1031.7, - 1018.8, - 978.69, - 909.16, - 809.66, - 681.91, - 529.99, - 360.27, - 377.02, - 348.76, - 338.32, - 346.72, - 373.3, - 415.97, - 471.67, - 536.82, - 607.84, - 681.37, - 754.49, - 824.52, - 888.89, - 944.83, - 989.13, - 1018.1, - 1027.7, - 1013.8, - 973.03, - 903.11, - 803.53, - 676.04, - 524.73, - 355.89, - 374.28, - 346.91, - 337.34, - 346.63, - 374.14, - 417.81, - 474.57, - 540.85, - 613.01, - 687.65, - 761.74, - 832.55, - 897.41, - 953.51, - 997.61, - 1026, - 1034.8, - 1020.1, - 978.36, - 907.64, - 807.53, - 679.87, - 528.79, - 360.61, - 368.96, - 343.04, - 334.91, - 345.67, - 374.79, - 420.26, - 479.07, - 547.66, - 622.33, - 699.59, - 776.26, - 849.41, - 916.22, - 973.7, - 1018.5, - 1047.1, - 1055.3, - 1039.6, - 996.64, - 924.63, - 823.41, - 694.98, - 543.61, - 375.65, - 362.13, - 338.31, - 332.23, - 345.12, - 376.54, - 424.62, - 486.44, - 558.45, - 636.92, - 718.19, - 798.86, - 875.77, - 945.81, - 1005.8, - 1052.2, - 1081.3, - 1089.2, - 1072.4, - 1027.9, - 954.1, - 851.17, - 721.38, - 569.17, - 400.96, - 358.65, - ], - [ - 26.239, - 26.612, - 26.31, - 25.441, - 24.318, - 23.394, - 23.112, - 23.72, - 25.148, - 27.001, - 28.699, - 29.716, - 29.81, - 29.155, - 28.283, - 27.875, - 28.464, - 30.184, - 32.67, - 35.15, - 36.697, - 36.53, - 34.26, - 26.079, - 26.356, - 26.123, - 25.377, - 24.269, - 23.133, - 22.412, - 22.513, - 23.625, - 25.603, - 27.972, - 30.08, - 31.355, - 31.549, - 30.869, - 29.916, - 29.444, - 30.048, - 31.882, - 34.561, - 37.256, - 38.964, - 38.84, - 36.463, - 27.34, - 27.022, - 26.272, - 25.145, - 23.828, - 22.661, - 22.071, - 22.427, - 23.876, - 26.22, - 28.935, - 31.326, - 32.793, - 33.079, - 32.405, - 31.404, - 30.873, - 31.438, - 33.276, - 35.992, - 38.735, - 40.468, - 40.314, - 37.834, - 28.936, - 28.179, - 27.044, - 25.629, - 24.14, - 22.91, - 22.342, - 22.777, - 24.33, - 26.785, - 29.604, - 32.084, - 33.617, - 33.942, - 33.274, - 32.242, - 31.649, - 32.13, - 33.871, - 36.492, - 39.148, - 40.809, - 40.599, - 38.075, - 30.565, - 29.545, - 28.174, - 26.574, - 24.953, - 23.624, - 22.961, - 23.278, - 24.681, - 26.963, - 29.608, - 31.943, - 33.377, - 33.648, - 32.951, - 31.886, - 31.224, - 31.584, - 33.16, - 35.596, - 38.088, - 39.641, - 39.41, - 36.961, - 31.817, - 30.693, - 29.225, - 27.539, - 25.826, - 24.371, - 23.51, - 23.538, - 24.562, - 26.408, - 28.617, - 30.572, - 31.727, - 31.828, - 31.044, - 29.925, - 29.185, - 29.41, - 30.789, - 33.006, - 35.314, - 36.788, - 36.628, - 34.415, - 32.283, - 31.18, - 29.725, - 28.034, - 26.271, - 24.68, - 23.561, - 23.184, - 23.667, - 24.883, - 26.449, - 27.832, - 28.549, - 28.369, - 27.435, - 26.235, - 25.413, - 25.505, - 26.693, - 28.702, - 30.862, - 32.327, - 32.358, - 30.553, - 25.266, - ], - [ - 128.95, - 122.39, - 120.64, - 123.55, - 130.36, - 139.91, - 150.88, - 162.03, - 172.42, - 181.48, - 189.02, - 195.16, - 200.19, - 204.44, - 208.13, - 211.28, - 213.65, - 214.75, - 213.92, - 210.41, - 203.54, - 192.86, - 178.23, - 163.31, - 155.31, - 151.08, - 150.99, - 154.77, - 161.58, - 170.28, - 179.68, - 188.74, - 196.76, - 203.4, - 208.68, - 212.85, - 216.27, - 219.23, - 221.88, - 224.12, - 225.6, - 225.74, - 223.82, - 219.07, - 210.87, - 198.8, - 182.79, - 174.14, - 167.52, - 164.42, - 165.1, - 169.16, - 175.75, - 183.73, - 191.99, - 199.65, - 206.14, - 211.3, - 215.28, - 218.42, - 221.08, - 223.54, - 225.85, - 227.84, - 229.03, - 228.78, - 226.33, - 220.92, - 211.96, - 199.07, - 182.29, - 173.98, - 167.47, - 164.48, - 165.21, - 169.22, - 175.62, - 183.29, - 191.13, - 198.31, - 204.34, - 209.12, - 212.84, - 215.87, - 218.56, - 221.15, - 223.64, - 225.77, - 227.02, - 226.7, - 224.03, - 218.3, - 208.92, - 195.62, - 178.46, - 166, - 158.67, - 155.08, - 155.45, - 159.35, - 165.87, - 173.86, - 182.19, - 189.98, - 196.7, - 202.2, - 206.66, - 210.38, - 213.7, - 216.82, - 219.71, - 222.07, - 223.4, - 223, - 220.13, - 214.1, - 204.4, - 190.8, - 173.44, - 154.11, - 145.46, - 140.96, - 140.92, - 144.95, - 152.12, - 161.23, - 171.03, - 180.51, - 189, - 196.23, - 202.23, - 207.24, - 211.54, - 215.29, - 218.49, - 220.88, - 221.98, - 221.18, - 217.81, - 211.23, - 201.01, - 186.99, - 169.34, - 142.18, - 132.2, - 126.9, - 126.73, - 131.35, - 139.83, - 150.86, - 163.06, - 175.18, - 186.35, - 196.06, - 204.16, - 210.78, - 216.11, - 220.35, - 223.51, - 225.44, - 225.78, - 224.02, - 219.59, - 211.98, - 200.83, - 186.04, - 167.83, - 139.73, - ], - [ - 70.415, - 63.185, - 60.465, - 60.591, - 60.143, - 57.038, - 52.504, - 50.145, - 52.795, - 59.874, - 67.577, - 71.705, - 70.705, - 66.507, - 62.829, - 62.672, - 66.912, - 74.551, - 83.744, - 92.514, - 98.891, - 100.95, - 97.216, - 79.011, - 69.417, - 63.16, - 61.764, - 63.012, - 62.935, - 59.324, - 53.838, - 50.763, - 53.311, - 60.684, - 68.47, - 72.03, - 69.915, - 64.583, - 60.254, - 60.045, - 64.612, - 72.707, - 82.445, - 91.957, - 99.291, - 102.29, - 99.127, - 78.038, - 68.621, - 62.845, - 62.133, - 63.938, - 63.956, - 59.934, - 53.875, - 50.547, - 53.377, - 61.309, - 69.423, - 72.78, - 70.092, - 64.277, - 59.88, - 59.941, - 64.754, - 72.791, - 82.217, - 91.458, - 98.799, - 102.05, - 99.099, - 79.084, - 69.517, - 63.545, - 62.667, - 64.294, - 64.063, - 59.786, - 53.71, - 50.862, - 54.69, - 63.816, - 72.863, - 76.62, - 73.907, - 67.976, - 63.596, - 63.71, - 68.311, - 75.671, - 84.059, - 92.247, - 98.83, - 101.68, - 98.532, - 82.217, - 72.301, - 65.592, - 63.804, - 64.591, - 63.807, - 59.46, - 53.943, - 52.302, - 57.788, - 68.623, - 78.974, - 83.346, - 80.57, - 74.14, - 69.092, - 68.491, - 72.306, - 78.769, - 86.214, - 93.562, - 99.537, - 102.01, - 98.59, - 85.922, - 75.755, - 68.142, - 65.06, - 64.61, - 63.112, - 58.883, - 54.354, - 54.334, - 61.656, - 74.067, - 85.316, - 89.653, - 85.894, - 77.776, - 70.782, - 68.475, - 71.193, - 77.305, - 85.039, - 93.058, - 99.781, - 102.83, - 99.672, - 87.851, - 77.836, - 69.584, - 65.269, - 63.539, - 61.355, - 57.414, - 54.058, - 55.603, - 64.259, - 77.282, - 88.15, - 91.02, - 84.779, - 73.49, - 63.277, - 58.54, - 60.333, - 67.279, - 77.276, - 88.242, - 97.786, - 102.97, - 101.07, - 80.395, - ], - [ - 41.641, - 40.194, - 40.186, - 42.23, - 46.594, - 53.121, - 61.229, - 70.012, - 78.439, - 85.592, - 90.892, - 94.251, - 96.067, - 97.076, - 98.062, - 99.52, - 101.37, - 102.81, - 102.42, - 98.456, - 89.282, - 73.864, - 52.127, - 46.345, - 44.621, - 43.544, - 43.912, - 46.325, - 51.036, - 57.864, - 66.199, - 75.106, - 83.524, - 90.514, - 95.492, - 98.379, - 99.607, - 99.965, - 100.3, - 101.19, - 102.58, - 103.72, - 103.16, - 99.119, - 89.886, - 74.352, - 52.377, - 47.029, - 45.418, - 44.464, - 44.971, - 47.544, - 52.44, - 59.48, - 68.047, - 77.196, - 85.846, - 93.031, - 98.137, - 101.06, - 102.23, - 102.43, - 102.55, - 103.18, - 104.36, - 105.36, - 104.77, - 100.79, - 91.703, - 76.343, - 54.515, - 46.417, - 44.725, - 43.69, - 44.133, - 46.678, - 51.605, - 58.757, - 67.538, - 77.004, - 86.064, - 93.719, - 99.307, - 102.67, - 104.17, - 104.59, - 104.77, - 105.37, - 106.44, - 107.34, - 106.73, - 102.87, - 94.051, - 79.104, - 57.779, - 45.332, - 43.46, - 42.227, - 42.471, - 44.847, - 49.673, - 56.832, - 65.762, - 75.54, - 85.07, - 93.32, - 99.567, - 103.57, - 105.62, - 106.39, - 106.74, - 107.3, - 108.21, - 108.9, - 108.17, - 104.36, - 95.857, - 81.497, - 60.997, - 44.582, - 42.54, - 41.092, - 41.092, - 43.223, - 47.848, - 54.899, - 63.862, - 73.848, - 83.763, - 92.548, - 99.412, - 104.02, - 106.56, - 107.63, - 108.02, - 108.36, - 108.88, - 109.12, - 108.03, - 104.13, - 95.9, - 82.242, - 62.844, - 44.76, - 42.661, - 41.081, - 40.875, - 42.75, - 47.106, - 53.93, - 62.761, - 72.751, - 82.824, - 91.901, - 99.134, - 104.1, - 106.86, - 107.94, - 108.05, - 107.85, - 107.62, - 107.08, - 105.35, - 101.14, - 93.104, - 80.215, - 62.151, - 43.724, - ], - [ - 118.98, - 118.77, - 114.81, - 110.94, - 111.76, - 120.81, - 139.03, - 164.13, - 191.14, - 214.26, - 229.24, - 235.31, - 235.48, - 234.91, - 237.59, - 243.43, - 247.56, - 242.8, - 224.16, - 192.83, - 156.88, - 128.03, - 116.11, - 114.56, - 120.47, - 121.48, - 119.8, - 119.4, - 124.67, - 138.61, - 161.35, - 189.66, - 217.79, - 239.51, - 250.63, - 250.97, - 244.63, - 238.01, - 236.21, - 239.67, - 243.29, - 239.03, - 220.99, - 189.83, - 153.87, - 125.62, - 115.76, - 113.73, - 119.34, - 120.43, - 119.35, - 120.07, - 126.81, - 142.28, - 166.27, - 195.25, - 223.31, - 244.16, - 253.68, - 251.93, - 243.39, - 234.99, - 232.27, - 235.87, - 240.45, - 237.39, - 220.08, - 188.86, - 152.21, - 123.36, - 113.8, - 112.5, - 117.39, - 117.83, - 116.14, - 116.19, - 122.09, - 136.48, - 159.24, - 187.11, - 214.47, - 235.32, - 245.61, - 245.23, - 238.46, - 231.97, - 231.22, - 236.76, - 243.08, - 241.27, - 224.36, - 192.43, - 154.12, - 123.17, - 111.77, - 111.84, - 116.07, - 115.81, - 113.28, - 112.16, - 116.42, - 128.77, - 149.41, - 175.62, - 202.44, - 224.24, - 237.01, - 240.27, - 237.61, - 235.01, - 237.45, - 245.27, - 252.98, - 251.67, - 234.34, - 200.95, - 160.12, - 125.78, - 110.61, - 111.8, - 115.85, - 115.53, - 112.91, - 111.46, - 114.95, - 126.08, - 145.37, - 170.67, - 197.62, - 221.06, - 236.98, - 244.41, - 246.13, - 247.26, - 252.18, - 261.08, - 268.7, - 266.51, - 247.75, - 212.36, - 168.66, - 130.37, - 110.27, - 111.49, - 115.98, - 116.58, - 115.25, - 115.16, - 119.79, - 131.68, - 151.41, - 177.19, - 205.1, - 230.34, - 248.92, - 259.5, - 264.14, - 267.18, - 272.48, - 280.25, - 285.71, - 281.08, - 260.09, - 222.65, - 176.54, - 134.76, - 109.72, - 113.52, - ], - [ - 193.76, - 175.38, - 162.27, - 160.36, - 172.25, - 195.97, - 225.74, - 254.47, - 276.83, - 291.1, - 299.03, - 303.88, - 308.06, - 311.91, - 314.25, - 314.13, - 312.28, - 310.93, - 311.97, - 314.65, - 314.29, - 303.4, - 274.7, - 224.26, - 207.31, - 188.71, - 175.68, - 174.45, - 187.74, - 213.48, - 245.57, - 276.56, - 300.69, - 315.98, - 324.06, - 328.17, - 330.76, - 332.29, - 331.78, - 328.62, - 323.95, - 320.4, - 320.06, - 321.98, - 320.97, - 308.8, - 277.58, - 228.89, - 211.58, - 193.09, - 180.66, - 180.38, - 194.76, - 221.52, - 254.41, - 285.92, - 310.33, - 325.71, - 333.69, - 337.44, - 339.35, - 339.79, - 337.84, - 333.07, - 326.89, - 322.18, - 321.12, - 322.62, - 321.14, - 308.02, - 275.04, - 226.62, - 208.9, - 190.91, - 179.59, - 180.54, - 195.75, - 222.62, - 254.92, - 285.45, - 308.91, - 323.71, - 331.58, - 335.58, - 337.83, - 338.47, - 336.46, - 331.43, - 324.91, - 319.91, - 318.63, - 319.91, - 318.05, - 304.27, - 270.31, - 220.88, - 202.94, - 185.9, - 176.17, - 178.56, - 194.36, - 220.62, - 251.31, - 279.83, - 301.61, - 315.58, - 323.64, - 328.59, - 332.13, - 334.02, - 333.01, - 328.67, - 322.59, - 317.79, - 316.49, - 317.49, - 315.17, - 300.91, - 266.66, - 215.01, - 197.33, - 181.72, - 173.93, - 177.86, - 194.01, - 219.16, - 247.65, - 273.68, - 293.62, - 306.99, - 315.73, - 322.31, - 327.86, - 331.65, - 332.18, - 328.96, - 323.58, - 319.06, - 317.57, - 318.01, - 314.99, - 300.29, - 266.31, - 211.08, - 194.33, - 180.62, - 174.97, - 180.35, - 196.57, - 220.29, - 246.36, - 269.94, - 288.33, - 301.58, - 311.6, - 320.31, - 328.22, - 334.09, - 336.21, - 334.05, - 329.23, - 324.7, - 322.64, - 322.07, - 318.01, - 302.85, - 269.56, - 210.47, - ], - [ - 379.51, - 379.84, - 395.91, - 427.92, - 473.69, - 529.07, - 588.7, - 647.08, - 699.6, - 743.3, - 777.29, - 802.52, - 821.09, - 835.21, - 846.12, - 853.17, - 853.52, - 842.42, - 814.09, - 763.04, - 685.5, - 580.62, - 451.07, - 410.65, - 399.71, - 402.4, - 421.14, - 455.94, - 504.31, - 561.69, - 622.3, - 680.28, - 730.75, - 770.76, - 799.62, - 818.74, - 830.83, - 838.79, - 844.48, - 847.73, - 845.89, - 834.15, - 806.4, - 756.67, - 680.62, - 576.86, - 447.69, - 414.98, - 404.12, - 407.11, - 426.47, - 462.19, - 511.7, - 570.26, - 631.84, - 690.32, - 740.61, - 779.61, - 806.68, - 823.39, - 832.79, - 838.15, - 841.71, - 843.6, - 841.34, - 830.09, - 803.55, - 755.43, - 681.02, - 578.58, - 450.14, - 410.03, - 397.89, - 399.64, - 418.02, - 453.23, - 502.83, - 562.11, - 624.99, - 685.17, - 737.35, - 778.23, - 806.94, - 824.96, - 835.31, - 841.33, - 845.37, - 847.71, - 845.99, - 835.41, - 809.68, - 762.41, - 688.78, - 586.93, - 458.78, - 401.4, - 387.35, - 387.06, - 403.58, - 437.47, - 486.6, - 546.53, - 611.24, - 674.4, - 730.49, - 775.86, - 809.27, - 831.79, - 846.08, - 855.24, - 861.51, - 865.15, - 863.91, - 853.15, - 826.78, - 778.64, - 704.06, - 601.31, - 472.38, - 394.53, - 378.79, - 376.46, - 391, - 423.41, - 472.04, - 532.82, - 600, - 667.24, - 728.84, - 780.73, - 821.08, - 850.34, - 870.52, - 884.23, - 893.36, - 898.07, - 896.23, - 883.52, - 854.27, - 802.83, - 725, - 619.43, - 488.37, - 393.66, - 377.18, - 373.6, - 386.72, - 418.08, - 466.56, - 528.58, - 598.73, - 670.77, - 738.81, - 798.32, - 846.78, - 883.81, - 910.59, - 929.03, - 940.55, - 945.15, - 940.84, - 923.78, - 888.98, - 831.54, - 748.02, - 637.69, - 503.18, - 392.42, - ], - [ - 41.737, - 33.076, - 24.388, - 16.425, - 10.052, - 6.4268, - 6.8474, - 12.206, - 22.299, - 35.43, - 48.68, - 58.883, - 63.898, - 63.561, - 59.771, - 55.585, - 53.758, - 55.447, - 59.744, - 64.226, - 66.185, - 63.872, - 57.142, - 51.474, - 43.482, - 34.651, - 25.854, - 17.828, - 11.468, - 8.0229, - 8.9065, - 15.065, - 26.203, - 40.365, - 54.275, - 64.445, - 68.619, - 66.805, - 61.329, - 55.764, - 53.26, - 55.074, - 60.078, - 65.417, - 67.947, - 65.656, - 58.396, - 51.897, - 43.901, - 35.098, - 26.31, - 18.242, - 11.813, - 8.3517, - 9.3741, - 15.887, - 27.558, - 42.285, - 56.579, - 66.77, - 70.543, - 68.02, - 61.774, - 55.659, - 53.001, - 55.044, - 60.453, - 66.077, - 68.523, - 65.69, - 57.541, - 51.467, - 43.599, - 34.896, - 26.14, - 18.03, - 11.515, - 7.9791, - 8.9882, - 15.567, - 27.362, - 42.223, - 56.617, - 66.867, - 70.691, - 68.278, - 62.263, - 56.511, - 54.28, - 56.674, - 62.18, - 67.505, - 69.216, - 65.307, - 55.944, - 50.993, - 43.278, - 34.645, - 25.895, - 17.776, - 11.289, - 7.8312, - 8.9214, - 15.506, - 27.173, - 41.794, - 55.957, - 66.162, - 70.288, - 68.58, - 63.578, - 58.925, - 57.59, - 60.424, - 65.773, - 70.355, - 70.884, - 65.589, - 54.893, - 51.015, - 43.422, - 34.798, - 26.04, - 18.023, - 11.818, - 8.7764, - 10.249, - 16.969, - 28.395, - 42.451, - 55.976, - 65.827, - 70.181, - 69.366, - 65.74, - 62.552, - 62.332, - 65.609, - 70.634, - 74.263, - 73.492, - 66.863, - 55.035, - 51.552, - 44.063, - 35.433, - 26.729, - 19.01, - 13.424, - 11.212, - 13.437, - 20.492, - 31.624, - 44.814, - 57.244, - 66.265, - 70.467, - 70.32, - 67.965, - 66.186, - 67.005, - 70.615, - 75.233, - 77.94, - 76.092, - 68.545, - 56.119, - 49.486, - ], - [ - 75.287, - 73.758, - 72.659, - 72.035, - 71.962, - 72.496, - 73.613, - 75.186, - 76.995, - 78.785, - 80.34, - 81.53, - 82.325, - 82.764, - 82.909, - 82.804, - 82.474, - 81.942, - 81.26, - 80.525, - 79.851, - 79.319, - 78.904, - 81.595, - 81.221, - 81.035, - 80.886, - 80.722, - 80.593, - 80.601, - 80.828, - 81.284, - 81.896, - 82.541, - 83.098, - 83.49, - 83.696, - 83.727, - 83.595, - 83.288, - 82.784, - 82.081, - 81.234, - 80.368, - 79.638, - 79.167, - 78.954, - 81.439, - 81.784, - 82.216, - 82.507, - 82.564, - 82.436, - 82.264, - 82.193, - 82.307, - 82.599, - 82.99, - 83.378, - 83.682, - 83.857, - 83.881, - 83.732, - 83.374, - 82.773, - 81.934, - 80.937, - 79.941, - 79.139, - 78.685, - 78.595, - 78.829, - 79.15, - 79.549, - 79.794, - 79.799, - 79.638, - 79.482, - 79.503, - 79.802, - 80.367, - 81.098, - 81.857, - 82.52, - 83.002, - 83.256, - 83.251, - 82.957, - 82.36, - 81.493, - 80.463, - 79.451, - 78.668, - 78.274, - 78.283, - 76.309, - 76.127, - 76.043, - 75.873, - 75.58, - 75.277, - 75.155, - 75.391, - 76.059, - 77.103, - 78.362, - 79.636, - 80.745, - 81.574, - 82.07, - 82.22, - 82.031, - 81.527, - 80.773, - 79.89, - 79.055, - 78.46, - 78.234, - 78.365, - 76.287, - 75.462, - 74.722, - 73.96, - 73.203, - 72.608, - 72.384, - 72.692, - 73.569, - 74.899, - 76.456, - 77.981, - 79.261, - 80.174, - 80.69, - 80.844, - 80.699, - 80.328, - 79.815, - 79.271, - 78.83, - 78.618, - 78.695, - 78.994, - 80.444, - 79.175, - 77.895, - 76.574, - 75.301, - 74.273, - 73.708, - 73.754, - 74.41, - 75.518, - 76.813, - 78.012, - 78.905, - 79.401, - 79.532, - 79.402, - 79.144, - 78.872, - 78.674, - 78.616, - 78.746, - 79.081, - 79.576, - 80.082, - 77.239, - ], - ] - ) + 379.51, + 379.84, + 395.91, + 427.92, + 473.69, + 529.07, + 588.7, + 647.08, + 699.6, + 743.3, + 777.29, + 802.52, + 821.09, + 835.21, + 846.12, + 853.17, + 853.52, + 842.42, + 814.09, + 763.04, + 685.5, + 580.62, + 451.07, + 410.65, + 399.71, + 402.4, + 421.14, + 455.94, + 504.31, + 561.69, + 622.3, + 680.28, + 730.75, + 770.76, + 799.62, + 818.74, + 830.83, + 838.79, + 844.48, + 847.73, + 845.89, + 834.15, + 806.4, + 756.67, + 680.62, + 576.86, + 447.69, + 414.98, + 404.12, + 407.11, + 426.47, + 462.19, + 511.7, + 570.26, + 631.84, + 690.32, + 740.61, + 779.61, + 806.68, + 823.39, + 832.79, + 838.15, + 841.71, + 843.6, + 841.34, + 830.09, + 803.55, + 755.43, + 681.02, + 578.58, + 450.14, + 410.03, + 397.89, + 399.64, + 418.02, + 453.23, + 502.83, + 562.11, + 624.99, + 685.17, + 737.35, + 778.23, + 806.94, + 824.96, + 835.31, + 841.33, + 845.37, + 847.71, + 845.99, + 835.41, + 809.68, + 762.41, + 688.78, + 586.93, + 458.78, + 401.4, + 387.35, + 387.06, + 403.58, + 437.47, + 486.6, + 546.53, + 611.24, + 674.4, + 730.49, + 775.86, + 809.27, + 831.79, + 846.08, + 855.24, + 861.51, + 865.15, + 863.91, + 853.15, + 826.78, + 778.64, + 704.06, + 601.31, + 472.38, + 394.53, + 378.79, + 376.46, + 391, + 423.41, + 472.04, + 532.82, + 600, + 667.24, + 728.84, + 780.73, + 821.08, + 850.34, + 870.52, + 884.23, + 893.36, + 898.07, + 896.23, + 883.52, + 854.27, + 802.83, + 725, + 619.43, + 488.37, + 393.66, + 377.18, + 373.6, + 386.72, + 418.08, + 466.56, + 528.58, + 598.73, + 670.77, + 738.81, + 798.32, + 846.78, + 883.81, + 910.59, + 929.03, + 940.55, + 945.15, + 940.84, + 923.78, + 888.98, + 831.54, + 748.02, + 637.69, + 503.18, + 392.42, + ], + [ + 41.737, + 33.076, + 24.388, + 16.425, + 10.052, + 6.4268, + 6.8474, + 12.206, + 22.299, + 35.43, + 48.68, + 58.883, + 63.898, + 63.561, + 59.771, + 55.585, + 53.758, + 55.447, + 59.744, + 64.226, + 66.185, + 63.872, + 57.142, + 51.474, + 43.482, + 34.651, + 25.854, + 17.828, + 11.468, + 8.0229, + 8.9065, + 15.065, + 26.203, + 40.365, + 54.275, + 64.445, + 68.619, + 66.805, + 61.329, + 55.764, + 53.26, + 55.074, + 60.078, + 65.417, + 67.947, + 65.656, + 58.396, + 51.897, + 43.901, + 35.098, + 26.31, + 18.242, + 11.813, + 8.3517, + 9.3741, + 15.887, + 27.558, + 42.285, + 56.579, + 66.77, + 70.543, + 68.02, + 61.774, + 55.659, + 53.001, + 55.044, + 60.453, + 66.077, + 68.523, + 65.69, + 57.541, + 51.467, + 43.599, + 34.896, + 26.14, + 18.03, + 11.515, + 7.9791, + 8.9882, + 15.567, + 27.362, + 42.223, + 56.617, + 66.867, + 70.691, + 68.278, + 62.263, + 56.511, + 54.28, + 56.674, + 62.18, + 67.505, + 69.216, + 65.307, + 55.944, + 50.993, + 43.278, + 34.645, + 25.895, + 17.776, + 11.289, + 7.8312, + 8.9214, + 15.506, + 27.173, + 41.794, + 55.957, + 66.162, + 70.288, + 68.58, + 63.578, + 58.925, + 57.59, + 60.424, + 65.773, + 70.355, + 70.884, + 65.589, + 54.893, + 51.015, + 43.422, + 34.798, + 26.04, + 18.023, + 11.818, + 8.7764, + 10.249, + 16.969, + 28.395, + 42.451, + 55.976, + 65.827, + 70.181, + 69.366, + 65.74, + 62.552, + 62.332, + 65.609, + 70.634, + 74.263, + 73.492, + 66.863, + 55.035, + 51.552, + 44.063, + 35.433, + 26.729, + 19.01, + 13.424, + 11.212, + 13.437, + 20.492, + 31.624, + 44.814, + 57.244, + 66.265, + 70.467, + 70.32, + 67.965, + 66.186, + 67.005, + 70.615, + 75.233, + 77.94, + 76.092, + 68.545, + 56.119, + 49.486, + ], + [ + 75.287, + 73.758, + 72.659, + 72.035, + 71.962, + 72.496, + 73.613, + 75.186, + 76.995, + 78.785, + 80.34, + 81.53, + 82.325, + 82.764, + 82.909, + 82.804, + 82.474, + 81.942, + 81.26, + 80.525, + 79.851, + 79.319, + 78.904, + 81.595, + 81.221, + 81.035, + 80.886, + 80.722, + 80.593, + 80.601, + 80.828, + 81.284, + 81.896, + 82.541, + 83.098, + 83.49, + 83.696, + 83.727, + 83.595, + 83.288, + 82.784, + 82.081, + 81.234, + 80.368, + 79.638, + 79.167, + 78.954, + 81.439, + 81.784, + 82.216, + 82.507, + 82.564, + 82.436, + 82.264, + 82.193, + 82.307, + 82.599, + 82.99, + 83.378, + 83.682, + 83.857, + 83.881, + 83.732, + 83.374, + 82.773, + 81.934, + 80.937, + 79.941, + 79.139, + 78.685, + 78.595, + 78.829, + 79.15, + 79.549, + 79.794, + 79.799, + 79.638, + 79.482, + 79.503, + 79.802, + 80.367, + 81.098, + 81.857, + 82.52, + 83.002, + 83.256, + 83.251, + 82.957, + 82.36, + 81.493, + 80.463, + 79.451, + 78.668, + 78.274, + 78.283, + 76.309, + 76.127, + 76.043, + 75.873, + 75.58, + 75.277, + 75.155, + 75.391, + 76.059, + 77.103, + 78.362, + 79.636, + 80.745, + 81.574, + 82.07, + 82.22, + 82.031, + 81.527, + 80.773, + 79.89, + 79.055, + 78.46, + 78.234, + 78.365, + 76.287, + 75.462, + 74.722, + 73.96, + 73.203, + 72.608, + 72.384, + 72.692, + 73.569, + 74.899, + 76.456, + 77.981, + 79.261, + 80.174, + 80.69, + 80.844, + 80.699, + 80.328, + 79.815, + 79.271, + 78.83, + 78.618, + 78.695, + 78.994, + 80.444, + 79.175, + 77.895, + 76.574, + 75.301, + 74.273, + 73.708, + 73.754, + 74.41, + 75.518, + 76.813, + 78.012, + 78.905, + 79.401, + 79.532, + 79.402, + 79.144, + 78.872, + 78.674, + 78.616, + 78.746, + 79.081, + 79.576, + 80.082, + 77.239, + ], + ]) def load_exact_std(): - return nd.array( + return nd.array([ + [ + 62.202, + 61.116, + 60.813, + 60.657, + 60.515, + 60.403, + 60.333, + 60.297, + 60.282, + 60.276, + 60.275, + 60.275, + 60.276, + 60.282, + 60.297, + 60.333, + 60.403, + 60.515, + 60.657, + 60.813, + 61.116, + 62.202, + 65.704, + 62.155, + 59.724, + 59.199, + 59.129, + 59.061, + 58.969, + 58.9, + 58.866, + 58.85, + 58.838, + 58.826, + 58.819, + 58.819, + 58.826, + 58.838, + 58.85, + 58.866, + 58.9, + 58.969, + 59.061, + 59.129, + 59.199, + 59.724, + 62.155, + 61.374, + 59.422, + 59.048, + 58.973, + 58.87, + 58.764, + 58.706, + 58.692, + 58.689, + 58.68, + 58.665, + 58.654, + 58.654, + 58.665, + 58.68, + 58.689, + 58.692, + 58.706, + 58.764, + 58.87, + 58.973, + 59.048, + 59.422, + 61.374, + 61.321, + 59.489, + 59.128, + 59.023, + 58.892, + 58.773, + 58.716, + 58.708, + 58.711, + 58.704, + 58.688, + 58.675, + 58.675, + 58.688, + 58.704, + 58.711, + 58.708, + 58.716, + 58.773, + 58.892, + 59.023, + 59.128, + 59.489, + 61.321, + 61.374, + 59.422, + 59.048, + 58.973, + 58.87, + 58.764, + 58.706, + 58.692, + 58.689, + 58.68, + 58.665, + 58.654, + 58.654, + 58.665, + 58.68, + 58.689, + 58.692, + 58.706, + 58.764, + 58.87, + 58.973, + 59.048, + 59.422, + 61.374, + 62.155, + 59.724, + 59.199, + 59.129, + 59.061, + 58.969, + 58.9, + 58.866, + 58.85, + 58.838, + 58.826, + 58.819, + 58.819, + 58.826, + 58.838, + 58.85, + 58.866, + 58.9, + 58.969, + 59.061, + 59.129, + 59.199, + 59.724, + 62.155, + 65.704, + 62.202, + 61.116, + 60.813, + 60.657, + 60.515, + 60.403, + 60.333, + 60.297, + 60.282, + 60.276, + 60.275, + 60.275, + 60.276, + 60.282, + 60.297, + 60.333, + 60.403, + 60.515, + 60.657, + 60.813, + 61.116, + 62.202, + 65.704, + 65.704, + ], + [ + 4.6693, + 4.6057, + 4.5913, + 4.5803, + 4.5721, + 4.5684, + 4.5667, + 4.5655, + 4.5644, + 4.5637, + 4.5633, + 4.5633, + 4.5637, + 4.5644, + 4.5655, + 4.5667, + 4.5684, + 4.5721, + 4.5803, + 4.5913, + 4.6057, + 4.6693, + 4.9388, + 4.6701, + 4.4794, + 4.4525, + 4.4488, + 4.4409, + 4.4349, + 4.4328, + 4.4317, + 4.4306, + 4.4299, + 4.4295, + 4.4293, + 4.4293, + 4.4295, + 4.4299, + 4.4306, + 4.4317, + 4.4328, + 4.4349, + 4.4409, + 4.4488, + 4.4525, + 4.4794, + 4.6701, + 4.618, + 4.4607, + 4.44, + 4.4336, + 4.4245, + 4.4199, + 4.4191, + 4.4183, + 4.4169, + 4.4162, + 4.416, + 4.416, + 4.416, + 4.416, + 4.4162, + 4.4169, + 4.4183, + 4.4191, + 4.4199, + 4.4245, + 4.4336, + 4.44, + 4.4607, + 4.618, + 4.616, + 4.4654, + 4.4433, + 4.434, + 4.4239, + 4.4199, + 4.4198, + 4.419, + 4.4175, + 4.4167, + 4.4166, + 4.4166, + 4.4166, + 4.4166, + 4.4167, + 4.4175, + 4.419, + 4.4198, + 4.4199, + 4.4239, + 4.434, + 4.4433, + 4.4654, + 4.616, + 4.618, + 4.4607, + 4.44, + 4.4336, + 4.4245, + 4.4199, + 4.4191, + 4.4183, + 4.4169, + 4.4162, + 4.416, + 4.416, + 4.416, + 4.416, + 4.4162, + 4.4169, + 4.4183, + 4.4191, + 4.4199, + 4.4245, + 4.4336, + 4.44, + 4.4607, + 4.618, + 4.6701, + 4.4794, + 4.4525, + 4.4488, + 4.4409, + 4.4349, + 4.4328, + 4.4317, + 4.4306, + 4.4299, + 4.4295, + 4.4293, + 4.4293, + 4.4295, + 4.4299, + 4.4306, + 4.4317, + 4.4328, + 4.4349, + 4.4409, + 4.4488, + 4.4525, + 4.4794, + 4.6701, + 4.9388, + 4.6693, + 4.6057, + 4.5913, + 4.5803, + 4.5721, + 4.5684, + 4.5667, + 4.5655, + 4.5644, + 4.5637, + 4.5633, + 4.5633, + 4.5637, + 4.5644, + 4.5655, + 4.5667, + 4.5684, + 4.5721, + 4.5803, + 4.5913, + 4.6057, + 4.6693, + 4.9388, + 4.9388, + ], + [ + 10.6, + 10.437, + 10.384, + 10.335, + 10.304, + 10.29, + 10.28, + 10.271, + 10.264, + 10.261, + 10.259, + 10.259, + 10.261, + 10.264, + 10.271, + 10.28, + 10.29, + 10.304, + 10.335, + 10.384, + 10.437, + 10.6, + 11.514, + 10.618, + 10.069, + 10.028, + 10.002, + 9.9658, + 9.9464, + 9.9396, + 9.9341, + 9.9288, + 9.9256, + 9.9239, + 9.9228, + 9.9228, + 9.9239, + 9.9256, + 9.9288, + 9.9341, + 9.9396, + 9.9464, + 9.9658, + 10.002, + 10.028, + 10.069, + 10.618, + 10.505, + 10.051, + 10.003, + 9.9655, + 9.9329, + 9.9221, + 9.9184, + 9.9114, + 9.9045, + 9.9017, + 9.9012, + 9.9009, + 9.9009, + 9.9012, + 9.9017, + 9.9045, + 9.9114, + 9.9184, + 9.9221, + 9.9329, + 9.9655, + 10.003, + 10.051, + 10.505, + 10.492, + 10.043, + 9.9813, + 9.9381, + 9.9089, + 9.9031, + 9.9013, + 9.8936, + 9.8858, + 9.8829, + 9.8828, + 9.8828, + 9.8828, + 9.8828, + 9.8829, + 9.8858, + 9.8936, + 9.9013, + 9.9031, + 9.9089, + 9.9381, + 9.9813, + 10.043, + 10.492, + 10.505, + 10.051, + 10.003, + 9.9655, + 9.9329, + 9.9221, + 9.9184, + 9.9114, + 9.9045, + 9.9017, + 9.9012, + 9.9009, + 9.9009, + 9.9012, + 9.9017, + 9.9045, + 9.9114, + 9.9184, + 9.9221, + 9.9329, + 9.9655, + 10.003, + 10.051, + 10.505, + 10.618, + 10.069, + 10.028, + 10.002, + 9.9658, + 9.9464, + 9.9396, + 9.9341, + 9.9288, + 9.9256, + 9.9239, + 9.9228, + 9.9228, + 9.9239, + 9.9256, + 9.9288, + 9.9341, + 9.9396, + 9.9464, + 9.9658, + 10.002, + 10.028, + 10.069, + 10.618, + 11.514, + 10.6, + 10.437, + 10.384, + 10.335, + 10.304, + 10.29, + 10.28, + 10.271, + 10.264, + 10.261, + 10.259, + 10.259, + 10.261, + 10.264, + 10.271, + 10.28, + 10.29, + 10.304, + 10.335, + 10.384, + 10.437, + 10.6, + 11.514, + 11.514, + ], + [ + 5.8917, + 5.847, + 5.803, + 5.7866, + 5.7789, + 5.7729, + 5.7707, + 5.7687, + 5.7674, + 5.7671, + 5.7668, + 5.7668, + 5.7671, + 5.7674, + 5.7687, + 5.7707, + 5.7729, + 5.7789, + 5.7866, + 5.803, + 5.847, + 5.8917, + 6.4598, + 5.9005, + 5.5807, + 5.5595, + 5.5222, + 5.5167, + 5.5118, + 5.507, + 5.5059, + 5.5046, + 5.5038, + 5.5035, + 5.5033, + 5.5033, + 5.5035, + 5.5038, + 5.5046, + 5.5059, + 5.507, + 5.5118, + 5.5167, + 5.5222, + 5.5595, + 5.5807, + 5.9005, + 5.8645, + 5.566, + 5.5328, + 5.502, + 5.4976, + 5.4912, + 5.4879, + 5.487, + 5.4852, + 5.4848, + 5.4846, + 5.4843, + 5.4843, + 5.4846, + 5.4848, + 5.4852, + 5.487, + 5.4879, + 5.4912, + 5.4976, + 5.502, + 5.5328, + 5.566, + 5.8645, + 5.8339, + 5.534, + 5.507, + 5.4832, + 5.4774, + 5.4694, + 5.467, + 5.4663, + 5.4643, + 5.4639, + 5.4638, + 5.4634, + 5.4634, + 5.4638, + 5.4639, + 5.4643, + 5.4663, + 5.467, + 5.4694, + 5.4774, + 5.4832, + 5.507, + 5.534, + 5.8339, + 5.8645, + 5.566, + 5.5328, + 5.502, + 5.4976, + 5.4912, + 5.4879, + 5.487, + 5.4852, + 5.4848, + 5.4846, + 5.4843, + 5.4843, + 5.4846, + 5.4848, + 5.4852, + 5.487, + 5.4879, + 5.4912, + 5.4976, + 5.502, + 5.5328, + 5.566, + 5.8645, + 5.9005, + 5.5807, + 5.5595, + 5.5222, + 5.5167, + 5.5118, + 5.507, + 5.5059, + 5.5046, + 5.5038, + 5.5035, + 5.5033, + 5.5033, + 5.5035, + 5.5038, + 5.5046, + 5.5059, + 5.507, + 5.5118, + 5.5167, + 5.5222, + 5.5595, + 5.5807, + 5.9005, + 6.4598, + 5.8917, + 5.847, + 5.803, + 5.7866, + 5.7789, + 5.7729, + 5.7707, + 5.7687, + 5.7674, + 5.7671, + 5.7668, + 5.7668, + 5.7671, + 5.7674, + 5.7687, + 5.7707, + 5.7729, + 5.7789, + 5.7866, + 5.803, + 5.847, + 5.8917, + 6.4598, + 6.4598, + ], + [ + 5.6763, + 5.5919, + 5.567, + 5.5445, + 5.5302, + 5.5238, + 5.5196, + 5.5158, + 5.513, + 5.5114, + 5.5108, + 5.5108, + 5.5114, + 5.513, + 5.5158, + 5.5196, + 5.5238, + 5.5302, + 5.5445, + 5.567, + 5.5919, + 5.6763, + 6.1243, + 5.6824, + 5.4021, + 5.3783, + 5.3679, + 5.3509, + 5.3416, + 5.3385, + 5.336, + 5.3336, + 5.3322, + 5.3315, + 5.3309, + 5.3309, + 5.3315, + 5.3322, + 5.3336, + 5.336, + 5.3385, + 5.3416, + 5.3509, + 5.3679, + 5.3783, + 5.4021, + 5.6824, + 5.6208, + 5.3897, + 5.3653, + 5.3488, + 5.3329, + 5.3275, + 5.3261, + 5.3231, + 5.3201, + 5.3189, + 5.3187, + 5.3184, + 5.3184, + 5.3187, + 5.3189, + 5.3201, + 5.3231, + 5.3261, + 5.3275, + 5.3329, + 5.3488, + 5.3653, + 5.3897, + 5.6208, + 5.617, + 5.3904, + 5.3595, + 5.339, + 5.324, + 5.3207, + 5.3202, + 5.3169, + 5.3134, + 5.3122, + 5.3121, + 5.312, + 5.312, + 5.3121, + 5.3122, + 5.3134, + 5.3169, + 5.3202, + 5.3207, + 5.324, + 5.339, + 5.3595, + 5.3904, + 5.617, + 5.6208, + 5.3897, + 5.3653, + 5.3488, + 5.3329, + 5.3275, + 5.3261, + 5.3231, + 5.3201, + 5.3189, + 5.3187, + 5.3184, + 5.3184, + 5.3187, + 5.3189, + 5.3201, + 5.3231, + 5.3261, + 5.3275, + 5.3329, + 5.3488, + 5.3653, + 5.3897, + 5.6208, + 5.6824, + 5.4021, + 5.3783, + 5.3679, + 5.3509, + 5.3416, + 5.3385, + 5.336, + 5.3336, + 5.3322, + 5.3315, + 5.3309, + 5.3309, + 5.3315, + 5.3322, + 5.3336, + 5.336, + 5.3385, + 5.3416, + 5.3509, + 5.3679, + 5.3783, + 5.4021, + 5.6824, + 6.1243, + 5.6763, + 5.5919, + 5.567, + 5.5445, + 5.5302, + 5.5238, + 5.5196, + 5.5158, + 5.513, + 5.5114, + 5.5108, + 5.5108, + 5.5114, + 5.513, + 5.5158, + 5.5196, + 5.5238, + 5.5302, + 5.5445, + 5.567, + 5.5919, + 5.6763, + 6.1243, + 6.1243, + ], [ - [ - 62.202, - 61.116, - 60.813, - 60.657, - 60.515, - 60.403, - 60.333, - 60.297, - 60.282, - 60.276, - 60.275, - 60.275, - 60.276, - 60.282, - 60.297, - 60.333, - 60.403, - 60.515, - 60.657, - 60.813, - 61.116, - 62.202, - 65.704, - 62.155, - 59.724, - 59.199, - 59.129, - 59.061, - 58.969, - 58.9, - 58.866, - 58.85, - 58.838, - 58.826, - 58.819, - 58.819, - 58.826, - 58.838, - 58.85, - 58.866, - 58.9, - 58.969, - 59.061, - 59.129, - 59.199, - 59.724, - 62.155, - 61.374, - 59.422, - 59.048, - 58.973, - 58.87, - 58.764, - 58.706, - 58.692, - 58.689, - 58.68, - 58.665, - 58.654, - 58.654, - 58.665, - 58.68, - 58.689, - 58.692, - 58.706, - 58.764, - 58.87, - 58.973, - 59.048, - 59.422, - 61.374, - 61.321, - 59.489, - 59.128, - 59.023, - 58.892, - 58.773, - 58.716, - 58.708, - 58.711, - 58.704, - 58.688, - 58.675, - 58.675, - 58.688, - 58.704, - 58.711, - 58.708, - 58.716, - 58.773, - 58.892, - 59.023, - 59.128, - 59.489, - 61.321, - 61.374, - 59.422, - 59.048, - 58.973, - 58.87, - 58.764, - 58.706, - 58.692, - 58.689, - 58.68, - 58.665, - 58.654, - 58.654, - 58.665, - 58.68, - 58.689, - 58.692, - 58.706, - 58.764, - 58.87, - 58.973, - 59.048, - 59.422, - 61.374, - 62.155, - 59.724, - 59.199, - 59.129, - 59.061, - 58.969, - 58.9, - 58.866, - 58.85, - 58.838, - 58.826, - 58.819, - 58.819, - 58.826, - 58.838, - 58.85, - 58.866, - 58.9, - 58.969, - 59.061, - 59.129, - 59.199, - 59.724, - 62.155, - 65.704, - 62.202, - 61.116, - 60.813, - 60.657, - 60.515, - 60.403, - 60.333, - 60.297, - 60.282, - 60.276, - 60.275, - 60.275, - 60.276, - 60.282, - 60.297, - 60.333, - 60.403, - 60.515, - 60.657, - 60.813, - 61.116, - 62.202, - 65.704, - 65.704, - ], - [ - 4.6693, - 4.6057, - 4.5913, - 4.5803, - 4.5721, - 4.5684, - 4.5667, - 4.5655, - 4.5644, - 4.5637, - 4.5633, - 4.5633, - 4.5637, - 4.5644, - 4.5655, - 4.5667, - 4.5684, - 4.5721, - 4.5803, - 4.5913, - 4.6057, - 4.6693, - 4.9388, - 4.6701, - 4.4794, - 4.4525, - 4.4488, - 4.4409, - 4.4349, - 4.4328, - 4.4317, - 4.4306, - 4.4299, - 4.4295, - 4.4293, - 4.4293, - 4.4295, - 4.4299, - 4.4306, - 4.4317, - 4.4328, - 4.4349, - 4.4409, - 4.4488, - 4.4525, - 4.4794, - 4.6701, - 4.618, - 4.4607, - 4.44, - 4.4336, - 4.4245, - 4.4199, - 4.4191, - 4.4183, - 4.4169, - 4.4162, - 4.416, - 4.416, - 4.416, - 4.416, - 4.4162, - 4.4169, - 4.4183, - 4.4191, - 4.4199, - 4.4245, - 4.4336, - 4.44, - 4.4607, - 4.618, - 4.616, - 4.4654, - 4.4433, - 4.434, - 4.4239, - 4.4199, - 4.4198, - 4.419, - 4.4175, - 4.4167, - 4.4166, - 4.4166, - 4.4166, - 4.4166, - 4.4167, - 4.4175, - 4.419, - 4.4198, - 4.4199, - 4.4239, - 4.434, - 4.4433, - 4.4654, - 4.616, - 4.618, - 4.4607, - 4.44, - 4.4336, - 4.4245, - 4.4199, - 4.4191, - 4.4183, - 4.4169, - 4.4162, - 4.416, - 4.416, - 4.416, - 4.416, - 4.4162, - 4.4169, - 4.4183, - 4.4191, - 4.4199, - 4.4245, - 4.4336, - 4.44, - 4.4607, - 4.618, - 4.6701, - 4.4794, - 4.4525, - 4.4488, - 4.4409, - 4.4349, - 4.4328, - 4.4317, - 4.4306, - 4.4299, - 4.4295, - 4.4293, - 4.4293, - 4.4295, - 4.4299, - 4.4306, - 4.4317, - 4.4328, - 4.4349, - 4.4409, - 4.4488, - 4.4525, - 4.4794, - 4.6701, - 4.9388, - 4.6693, - 4.6057, - 4.5913, - 4.5803, - 4.5721, - 4.5684, - 4.5667, - 4.5655, - 4.5644, - 4.5637, - 4.5633, - 4.5633, - 4.5637, - 4.5644, - 4.5655, - 4.5667, - 4.5684, - 4.5721, - 4.5803, - 4.5913, - 4.6057, - 4.6693, - 4.9388, - 4.9388, - ], - [ - 10.6, - 10.437, - 10.384, - 10.335, - 10.304, - 10.29, - 10.28, - 10.271, - 10.264, - 10.261, - 10.259, - 10.259, - 10.261, - 10.264, - 10.271, - 10.28, - 10.29, - 10.304, - 10.335, - 10.384, - 10.437, - 10.6, - 11.514, - 10.618, - 10.069, - 10.028, - 10.002, - 9.9658, - 9.9464, - 9.9396, - 9.9341, - 9.9288, - 9.9256, - 9.9239, - 9.9228, - 9.9228, - 9.9239, - 9.9256, - 9.9288, - 9.9341, - 9.9396, - 9.9464, - 9.9658, - 10.002, - 10.028, - 10.069, - 10.618, - 10.505, - 10.051, - 10.003, - 9.9655, - 9.9329, - 9.9221, - 9.9184, - 9.9114, - 9.9045, - 9.9017, - 9.9012, - 9.9009, - 9.9009, - 9.9012, - 9.9017, - 9.9045, - 9.9114, - 9.9184, - 9.9221, - 9.9329, - 9.9655, - 10.003, - 10.051, - 10.505, - 10.492, - 10.043, - 9.9813, - 9.9381, - 9.9089, - 9.9031, - 9.9013, - 9.8936, - 9.8858, - 9.8829, - 9.8828, - 9.8828, - 9.8828, - 9.8828, - 9.8829, - 9.8858, - 9.8936, - 9.9013, - 9.9031, - 9.9089, - 9.9381, - 9.9813, - 10.043, - 10.492, - 10.505, - 10.051, - 10.003, - 9.9655, - 9.9329, - 9.9221, - 9.9184, - 9.9114, - 9.9045, - 9.9017, - 9.9012, - 9.9009, - 9.9009, - 9.9012, - 9.9017, - 9.9045, - 9.9114, - 9.9184, - 9.9221, - 9.9329, - 9.9655, - 10.003, - 10.051, - 10.505, - 10.618, - 10.069, - 10.028, - 10.002, - 9.9658, - 9.9464, - 9.9396, - 9.9341, - 9.9288, - 9.9256, - 9.9239, - 9.9228, - 9.9228, - 9.9239, - 9.9256, - 9.9288, - 9.9341, - 9.9396, - 9.9464, - 9.9658, - 10.002, - 10.028, - 10.069, - 10.618, - 11.514, - 10.6, - 10.437, - 10.384, - 10.335, - 10.304, - 10.29, - 10.28, - 10.271, - 10.264, - 10.261, - 10.259, - 10.259, - 10.261, - 10.264, - 10.271, - 10.28, - 10.29, - 10.304, - 10.335, - 10.384, - 10.437, - 10.6, - 11.514, - 11.514, - ], - [ - 5.8917, - 5.847, - 5.803, - 5.7866, - 5.7789, - 5.7729, - 5.7707, - 5.7687, - 5.7674, - 5.7671, - 5.7668, - 5.7668, - 5.7671, - 5.7674, - 5.7687, - 5.7707, - 5.7729, - 5.7789, - 5.7866, - 5.803, - 5.847, - 5.8917, - 6.4598, - 5.9005, - 5.5807, - 5.5595, - 5.5222, - 5.5167, - 5.5118, - 5.507, - 5.5059, - 5.5046, - 5.5038, - 5.5035, - 5.5033, - 5.5033, - 5.5035, - 5.5038, - 5.5046, - 5.5059, - 5.507, - 5.5118, - 5.5167, - 5.5222, - 5.5595, - 5.5807, - 5.9005, - 5.8645, - 5.566, - 5.5328, - 5.502, - 5.4976, - 5.4912, - 5.4879, - 5.487, - 5.4852, - 5.4848, - 5.4846, - 5.4843, - 5.4843, - 5.4846, - 5.4848, - 5.4852, - 5.487, - 5.4879, - 5.4912, - 5.4976, - 5.502, - 5.5328, - 5.566, - 5.8645, - 5.8339, - 5.534, - 5.507, - 5.4832, - 5.4774, - 5.4694, - 5.467, - 5.4663, - 5.4643, - 5.4639, - 5.4638, - 5.4634, - 5.4634, - 5.4638, - 5.4639, - 5.4643, - 5.4663, - 5.467, - 5.4694, - 5.4774, - 5.4832, - 5.507, - 5.534, - 5.8339, - 5.8645, - 5.566, - 5.5328, - 5.502, - 5.4976, - 5.4912, - 5.4879, - 5.487, - 5.4852, - 5.4848, - 5.4846, - 5.4843, - 5.4843, - 5.4846, - 5.4848, - 5.4852, - 5.487, - 5.4879, - 5.4912, - 5.4976, - 5.502, - 5.5328, - 5.566, - 5.8645, - 5.9005, - 5.5807, - 5.5595, - 5.5222, - 5.5167, - 5.5118, - 5.507, - 5.5059, - 5.5046, - 5.5038, - 5.5035, - 5.5033, - 5.5033, - 5.5035, - 5.5038, - 5.5046, - 5.5059, - 5.507, - 5.5118, - 5.5167, - 5.5222, - 5.5595, - 5.5807, - 5.9005, - 6.4598, - 5.8917, - 5.847, - 5.803, - 5.7866, - 5.7789, - 5.7729, - 5.7707, - 5.7687, - 5.7674, - 5.7671, - 5.7668, - 5.7668, - 5.7671, - 5.7674, - 5.7687, - 5.7707, - 5.7729, - 5.7789, - 5.7866, - 5.803, - 5.847, - 5.8917, - 6.4598, - 6.4598, - ], - [ - 5.6763, - 5.5919, - 5.567, - 5.5445, - 5.5302, - 5.5238, - 5.5196, - 5.5158, - 5.513, - 5.5114, - 5.5108, - 5.5108, - 5.5114, - 5.513, - 5.5158, - 5.5196, - 5.5238, - 5.5302, - 5.5445, - 5.567, - 5.5919, - 5.6763, - 6.1243, - 5.6824, - 5.4021, - 5.3783, - 5.3679, - 5.3509, - 5.3416, - 5.3385, - 5.336, - 5.3336, - 5.3322, - 5.3315, - 5.3309, - 5.3309, - 5.3315, - 5.3322, - 5.3336, - 5.336, - 5.3385, - 5.3416, - 5.3509, - 5.3679, - 5.3783, - 5.4021, - 5.6824, - 5.6208, - 5.3897, - 5.3653, - 5.3488, - 5.3329, - 5.3275, - 5.3261, - 5.3231, - 5.3201, - 5.3189, - 5.3187, - 5.3184, - 5.3184, - 5.3187, - 5.3189, - 5.3201, - 5.3231, - 5.3261, - 5.3275, - 5.3329, - 5.3488, - 5.3653, - 5.3897, - 5.6208, - 5.617, - 5.3904, - 5.3595, - 5.339, - 5.324, - 5.3207, - 5.3202, - 5.3169, - 5.3134, - 5.3122, - 5.3121, - 5.312, - 5.312, - 5.3121, - 5.3122, - 5.3134, - 5.3169, - 5.3202, - 5.3207, - 5.324, - 5.339, - 5.3595, - 5.3904, - 5.617, - 5.6208, - 5.3897, - 5.3653, - 5.3488, - 5.3329, - 5.3275, - 5.3261, - 5.3231, - 5.3201, - 5.3189, - 5.3187, - 5.3184, - 5.3184, - 5.3187, - 5.3189, - 5.3201, - 5.3231, - 5.3261, - 5.3275, - 5.3329, - 5.3488, - 5.3653, - 5.3897, - 5.6208, - 5.6824, - 5.4021, - 5.3783, - 5.3679, - 5.3509, - 5.3416, - 5.3385, - 5.336, - 5.3336, - 5.3322, - 5.3315, - 5.3309, - 5.3309, - 5.3315, - 5.3322, - 5.3336, - 5.336, - 5.3385, - 5.3416, - 5.3509, - 5.3679, - 5.3783, - 5.4021, - 5.6824, - 6.1243, - 5.6763, - 5.5919, - 5.567, - 5.5445, - 5.5302, - 5.5238, - 5.5196, - 5.5158, - 5.513, - 5.5114, - 5.5108, - 5.5108, - 5.5114, - 5.513, - 5.5158, - 5.5196, - 5.5238, - 5.5302, - 5.5445, - 5.567, - 5.5919, - 5.6763, - 6.1243, - 6.1243, - ], - [ - 15.404, - 15.259, - 15.168, - 15.109, - 15.087, - 15.07, - 15.059, - 15.055, - 15.051, - 15.049, - 15.047, - 15.047, - 15.049, - 15.051, - 15.055, - 15.059, - 15.07, - 15.087, - 15.109, - 15.168, - 15.259, - 15.404, - 16.812, - 15.43, - 14.606, - 14.573, - 14.493, - 14.455, - 14.448, - 14.437, - 14.429, - 14.426, - 14.424, - 14.422, - 14.422, - 14.422, - 14.422, - 14.424, - 14.426, - 14.429, - 14.437, - 14.448, - 14.455, - 14.493, - 14.573, - 14.606, - 15.43, - 15.313, - 14.587, - 14.524, - 14.447, - 14.422, - 14.413, - 14.399, - 14.395, - 14.394, - 14.39, - 14.388, - 14.389, - 14.389, - 14.388, - 14.39, - 14.394, - 14.395, - 14.399, - 14.413, - 14.422, - 14.447, - 14.524, - 14.587, - 15.313, - 15.264, - 14.528, - 14.456, - 14.393, - 14.377, - 14.366, - 14.349, - 14.346, - 14.345, - 14.341, - 14.339, - 14.34, - 14.34, - 14.339, - 14.341, - 14.345, - 14.346, - 14.349, - 14.366, - 14.377, - 14.393, - 14.456, - 14.528, - 15.264, - 15.313, - 14.587, - 14.524, - 14.447, - 14.422, - 14.413, - 14.399, - 14.395, - 14.394, - 14.39, - 14.388, - 14.389, - 14.389, - 14.388, - 14.39, - 14.394, - 14.395, - 14.399, - 14.413, - 14.422, - 14.447, - 14.524, - 14.587, - 15.313, - 15.43, - 14.606, - 14.573, - 14.493, - 14.455, - 14.448, - 14.437, - 14.429, - 14.426, - 14.424, - 14.422, - 14.422, - 14.422, - 14.422, - 14.424, - 14.426, - 14.429, - 14.437, - 14.448, - 14.455, - 14.493, - 14.573, - 14.606, - 15.43, - 16.812, - 15.404, - 15.259, - 15.168, - 15.109, - 15.087, - 15.07, - 15.059, - 15.055, - 15.051, - 15.049, - 15.047, - 15.047, - 15.049, - 15.051, - 15.055, - 15.059, - 15.07, - 15.087, - 15.109, - 15.168, - 15.259, - 15.404, - 16.812, - 16.812, - ], - [ - 11.994, - 11.873, - 11.788, - 11.733, - 11.712, - 11.695, - 11.683, - 11.678, - 11.675, - 11.672, - 11.67, - 11.67, - 11.672, - 11.675, - 11.678, - 11.683, - 11.695, - 11.712, - 11.733, - 11.788, - 11.873, - 11.994, - 13.178, - 12.023, - 11.358, - 11.328, - 11.255, - 11.221, - 11.214, - 11.204, - 11.196, - 11.194, - 11.191, - 11.189, - 11.189, - 11.189, - 11.189, - 11.191, - 11.194, - 11.196, - 11.204, - 11.214, - 11.221, - 11.255, - 11.328, - 11.358, - 12.023, - 11.93, - 11.341, - 11.283, - 11.215, - 11.193, - 11.185, - 11.171, - 11.166, - 11.165, - 11.162, - 11.16, - 11.16, - 11.16, - 11.16, - 11.162, - 11.165, - 11.166, - 11.171, - 11.185, - 11.193, - 11.215, - 11.283, - 11.341, - 11.93, - 11.881, - 11.284, - 11.222, - 11.168, - 11.154, - 11.144, - 11.128, - 11.124, - 11.124, - 11.12, - 11.118, - 11.118, - 11.118, - 11.118, - 11.12, - 11.124, - 11.124, - 11.128, - 11.144, - 11.154, - 11.168, - 11.222, - 11.284, - 11.881, - 11.93, - 11.341, - 11.283, - 11.215, - 11.193, - 11.185, - 11.171, - 11.166, - 11.165, - 11.162, - 11.16, - 11.16, - 11.16, - 11.16, - 11.162, - 11.165, - 11.166, - 11.171, - 11.185, - 11.193, - 11.215, - 11.283, - 11.341, - 11.93, - 12.023, - 11.358, - 11.328, - 11.255, - 11.221, - 11.214, - 11.204, - 11.196, - 11.194, - 11.191, - 11.189, - 11.189, - 11.189, - 11.189, - 11.191, - 11.194, - 11.196, - 11.204, - 11.214, - 11.221, - 11.255, - 11.328, - 11.358, - 12.023, - 13.178, - 11.994, - 11.873, - 11.788, - 11.733, - 11.712, - 11.695, - 11.683, - 11.678, - 11.675, - 11.672, - 11.67, - 11.67, - 11.672, - 11.675, - 11.678, - 11.683, - 11.695, - 11.712, - 11.733, - 11.788, - 11.873, - 11.994, - 13.178, - 13.178, - ], - [ - 45.261, - 44.52, - 44.323, - 44.162, - 44.037, - 43.973, - 43.944, - 43.922, - 43.9, - 43.881, - 43.871, - 43.871, - 43.881, - 43.9, - 43.922, - 43.944, - 43.973, - 44.037, - 44.162, - 44.323, - 44.52, - 45.261, - 48.529, - 45.281, - 43.171, - 42.914, - 42.863, - 42.755, - 42.67, - 42.636, - 42.62, - 42.603, - 42.588, - 42.581, - 42.58, - 42.58, - 42.581, - 42.588, - 42.603, - 42.62, - 42.636, - 42.67, - 42.755, - 42.863, - 42.914, - 43.171, - 45.281, - 44.753, - 43.038, - 42.82, - 42.723, - 42.601, - 42.538, - 42.526, - 42.517, - 42.495, - 42.476, - 42.471, - 42.475, - 42.475, - 42.471, - 42.476, - 42.495, - 42.517, - 42.526, - 42.538, - 42.601, - 42.723, - 42.82, - 43.038, - 44.753, - 44.736, - 43.082, - 42.828, - 42.692, - 42.561, - 42.506, - 42.505, - 42.499, - 42.475, - 42.454, - 42.449, - 42.453, - 42.453, - 42.449, - 42.454, - 42.475, - 42.499, - 42.505, - 42.506, - 42.561, - 42.692, - 42.828, - 43.082, - 44.736, - 44.753, - 43.038, - 42.82, - 42.723, - 42.601, - 42.538, - 42.526, - 42.517, - 42.495, - 42.476, - 42.471, - 42.475, - 42.475, - 42.471, - 42.476, - 42.495, - 42.517, - 42.526, - 42.538, - 42.601, - 42.723, - 42.82, - 43.038, - 44.753, - 45.281, - 43.171, - 42.914, - 42.863, - 42.755, - 42.67, - 42.636, - 42.62, - 42.603, - 42.588, - 42.581, - 42.58, - 42.58, - 42.581, - 42.588, - 42.603, - 42.62, - 42.636, - 42.67, - 42.755, - 42.863, - 42.914, - 43.171, - 45.281, - 48.529, - 45.261, - 44.52, - 44.323, - 44.162, - 44.037, - 43.973, - 43.944, - 43.922, - 43.9, - 43.881, - 43.871, - 43.871, - 43.881, - 43.9, - 43.922, - 43.944, - 43.973, - 44.037, - 44.162, - 44.323, - 44.52, - 45.261, - 48.529, - 48.529, - ], - [ - 4.5187, - 4.4744, - 4.4501, - 4.433, - 4.4268, - 4.4222, - 4.4187, - 4.4174, - 4.4167, - 4.4159, - 4.4154, - 4.4154, - 4.4159, - 4.4167, - 4.4174, - 4.4187, - 4.4222, - 4.4268, - 4.433, - 4.4501, - 4.4744, - 4.5187, - 4.9177, - 4.5256, - 4.2875, - 4.2783, - 4.2574, - 4.2455, - 4.2434, - 4.2405, - 4.2382, - 4.2374, - 4.2368, - 4.2362, - 4.2361, - 4.2361, - 4.2362, - 4.2368, - 4.2374, - 4.2382, - 4.2405, - 4.2434, - 4.2455, - 4.2574, - 4.2783, - 4.2875, - 4.5256, - 4.4897, - 4.2819, - 4.2649, - 4.2439, - 4.2359, - 4.2338, - 4.2301, - 4.2283, - 4.2281, - 4.2273, - 4.2267, - 4.2267, - 4.2267, - 4.2267, - 4.2273, - 4.2281, - 4.2283, - 4.2301, - 4.2338, - 4.2359, - 4.2439, - 4.2649, - 4.2819, - 4.4897, - 4.4781, - 4.2679, - 4.2471, - 4.2289, - 4.2238, - 4.2214, - 4.2168, - 4.2152, - 4.2153, - 4.2144, - 4.2137, - 4.2138, - 4.2138, - 4.2137, - 4.2144, - 4.2153, - 4.2152, - 4.2168, - 4.2214, - 4.2238, - 4.2289, - 4.2471, - 4.2679, - 4.4781, - 4.4897, - 4.2819, - 4.2649, - 4.2439, - 4.2359, - 4.2338, - 4.2301, - 4.2283, - 4.2281, - 4.2273, - 4.2267, - 4.2267, - 4.2267, - 4.2267, - 4.2273, - 4.2281, - 4.2283, - 4.2301, - 4.2338, - 4.2359, - 4.2439, - 4.2649, - 4.2819, - 4.4897, - 4.5256, - 4.2875, - 4.2783, - 4.2574, - 4.2455, - 4.2434, - 4.2405, - 4.2382, - 4.2374, - 4.2368, - 4.2362, - 4.2361, - 4.2361, - 4.2362, - 4.2368, - 4.2374, - 4.2382, - 4.2405, - 4.2434, - 4.2455, - 4.2574, - 4.2783, - 4.2875, - 4.5256, - 4.9177, - 4.5187, - 4.4744, - 4.4501, - 4.433, - 4.4268, - 4.4222, - 4.4187, - 4.4174, - 4.4167, - 4.4159, - 4.4154, - 4.4154, - 4.4159, - 4.4167, - 4.4174, - 4.4187, - 4.4222, - 4.4268, - 4.433, - 4.4501, - 4.4744, - 4.5187, - 4.9177, - 4.9177, - ], - [ - 1.3486, - 1.3332, - 1.3187, - 1.3103, - 1.3068, - 1.3037, - 1.3015, - 1.3006, - 1.2999, - 1.2991, - 1.2986, - 1.2986, - 1.2991, - 1.2999, - 1.3006, - 1.3015, - 1.3037, - 1.3068, - 1.3103, - 1.3187, - 1.3332, - 1.3486, - 1.506, - 1.3566, - 1.2758, - 1.2693, - 1.2573, - 1.2531, - 1.2522, - 1.2503, - 1.2489, - 1.2484, - 1.2479, - 1.2475, - 1.2474, - 1.2474, - 1.2475, - 1.2479, - 1.2484, - 1.2489, - 1.2503, - 1.2522, - 1.2531, - 1.2573, - 1.2693, - 1.2758, - 1.3566, - 1.345, - 1.2709, - 1.2607, - 1.2507, - 1.248, - 1.2463, - 1.2442, - 1.2434, - 1.2432, - 1.2425, - 1.242, - 1.242, - 1.242, - 1.242, - 1.2425, - 1.2432, - 1.2434, - 1.2442, - 1.2463, - 1.248, - 1.2507, - 1.2607, - 1.2709, - 1.345, - 1.335, - 1.2613, - 1.2528, - 1.2457, - 1.2438, - 1.2416, - 1.239, - 1.2384, - 1.2384, - 1.2377, - 1.2372, - 1.2373, - 1.2373, - 1.2372, - 1.2377, - 1.2384, - 1.2384, - 1.239, - 1.2416, - 1.2438, - 1.2457, - 1.2528, - 1.2613, - 1.335, - 1.345, - 1.2709, - 1.2607, - 1.2507, - 1.248, - 1.2463, - 1.2442, - 1.2434, - 1.2432, - 1.2425, - 1.242, - 1.242, - 1.242, - 1.242, - 1.2425, - 1.2432, - 1.2434, - 1.2442, - 1.2463, - 1.248, - 1.2507, - 1.2607, - 1.2709, - 1.345, - 1.3566, - 1.2758, - 1.2693, - 1.2573, - 1.2531, - 1.2522, - 1.2503, - 1.2489, - 1.2484, - 1.2479, - 1.2475, - 1.2474, - 1.2474, - 1.2475, - 1.2479, - 1.2484, - 1.2489, - 1.2503, - 1.2522, - 1.2531, - 1.2573, - 1.2693, - 1.2758, - 1.3566, - 1.506, - 1.3486, - 1.3332, - 1.3187, - 1.3103, - 1.3068, - 1.3037, - 1.3015, - 1.3006, - 1.2999, - 1.2991, - 1.2986, - 1.2986, - 1.2991, - 1.2999, - 1.3006, - 1.3015, - 1.3037, - 1.3068, - 1.3103, - 1.3187, - 1.3332, - 1.3486, - 1.506, - 1.506, - ], - ] - ) + 15.404, + 15.259, + 15.168, + 15.109, + 15.087, + 15.07, + 15.059, + 15.055, + 15.051, + 15.049, + 15.047, + 15.047, + 15.049, + 15.051, + 15.055, + 15.059, + 15.07, + 15.087, + 15.109, + 15.168, + 15.259, + 15.404, + 16.812, + 15.43, + 14.606, + 14.573, + 14.493, + 14.455, + 14.448, + 14.437, + 14.429, + 14.426, + 14.424, + 14.422, + 14.422, + 14.422, + 14.422, + 14.424, + 14.426, + 14.429, + 14.437, + 14.448, + 14.455, + 14.493, + 14.573, + 14.606, + 15.43, + 15.313, + 14.587, + 14.524, + 14.447, + 14.422, + 14.413, + 14.399, + 14.395, + 14.394, + 14.39, + 14.388, + 14.389, + 14.389, + 14.388, + 14.39, + 14.394, + 14.395, + 14.399, + 14.413, + 14.422, + 14.447, + 14.524, + 14.587, + 15.313, + 15.264, + 14.528, + 14.456, + 14.393, + 14.377, + 14.366, + 14.349, + 14.346, + 14.345, + 14.341, + 14.339, + 14.34, + 14.34, + 14.339, + 14.341, + 14.345, + 14.346, + 14.349, + 14.366, + 14.377, + 14.393, + 14.456, + 14.528, + 15.264, + 15.313, + 14.587, + 14.524, + 14.447, + 14.422, + 14.413, + 14.399, + 14.395, + 14.394, + 14.39, + 14.388, + 14.389, + 14.389, + 14.388, + 14.39, + 14.394, + 14.395, + 14.399, + 14.413, + 14.422, + 14.447, + 14.524, + 14.587, + 15.313, + 15.43, + 14.606, + 14.573, + 14.493, + 14.455, + 14.448, + 14.437, + 14.429, + 14.426, + 14.424, + 14.422, + 14.422, + 14.422, + 14.422, + 14.424, + 14.426, + 14.429, + 14.437, + 14.448, + 14.455, + 14.493, + 14.573, + 14.606, + 15.43, + 16.812, + 15.404, + 15.259, + 15.168, + 15.109, + 15.087, + 15.07, + 15.059, + 15.055, + 15.051, + 15.049, + 15.047, + 15.047, + 15.049, + 15.051, + 15.055, + 15.059, + 15.07, + 15.087, + 15.109, + 15.168, + 15.259, + 15.404, + 16.812, + 16.812, + ], + [ + 11.994, + 11.873, + 11.788, + 11.733, + 11.712, + 11.695, + 11.683, + 11.678, + 11.675, + 11.672, + 11.67, + 11.67, + 11.672, + 11.675, + 11.678, + 11.683, + 11.695, + 11.712, + 11.733, + 11.788, + 11.873, + 11.994, + 13.178, + 12.023, + 11.358, + 11.328, + 11.255, + 11.221, + 11.214, + 11.204, + 11.196, + 11.194, + 11.191, + 11.189, + 11.189, + 11.189, + 11.189, + 11.191, + 11.194, + 11.196, + 11.204, + 11.214, + 11.221, + 11.255, + 11.328, + 11.358, + 12.023, + 11.93, + 11.341, + 11.283, + 11.215, + 11.193, + 11.185, + 11.171, + 11.166, + 11.165, + 11.162, + 11.16, + 11.16, + 11.16, + 11.16, + 11.162, + 11.165, + 11.166, + 11.171, + 11.185, + 11.193, + 11.215, + 11.283, + 11.341, + 11.93, + 11.881, + 11.284, + 11.222, + 11.168, + 11.154, + 11.144, + 11.128, + 11.124, + 11.124, + 11.12, + 11.118, + 11.118, + 11.118, + 11.118, + 11.12, + 11.124, + 11.124, + 11.128, + 11.144, + 11.154, + 11.168, + 11.222, + 11.284, + 11.881, + 11.93, + 11.341, + 11.283, + 11.215, + 11.193, + 11.185, + 11.171, + 11.166, + 11.165, + 11.162, + 11.16, + 11.16, + 11.16, + 11.16, + 11.162, + 11.165, + 11.166, + 11.171, + 11.185, + 11.193, + 11.215, + 11.283, + 11.341, + 11.93, + 12.023, + 11.358, + 11.328, + 11.255, + 11.221, + 11.214, + 11.204, + 11.196, + 11.194, + 11.191, + 11.189, + 11.189, + 11.189, + 11.189, + 11.191, + 11.194, + 11.196, + 11.204, + 11.214, + 11.221, + 11.255, + 11.328, + 11.358, + 12.023, + 13.178, + 11.994, + 11.873, + 11.788, + 11.733, + 11.712, + 11.695, + 11.683, + 11.678, + 11.675, + 11.672, + 11.67, + 11.67, + 11.672, + 11.675, + 11.678, + 11.683, + 11.695, + 11.712, + 11.733, + 11.788, + 11.873, + 11.994, + 13.178, + 13.178, + ], + [ + 45.261, + 44.52, + 44.323, + 44.162, + 44.037, + 43.973, + 43.944, + 43.922, + 43.9, + 43.881, + 43.871, + 43.871, + 43.881, + 43.9, + 43.922, + 43.944, + 43.973, + 44.037, + 44.162, + 44.323, + 44.52, + 45.261, + 48.529, + 45.281, + 43.171, + 42.914, + 42.863, + 42.755, + 42.67, + 42.636, + 42.62, + 42.603, + 42.588, + 42.581, + 42.58, + 42.58, + 42.581, + 42.588, + 42.603, + 42.62, + 42.636, + 42.67, + 42.755, + 42.863, + 42.914, + 43.171, + 45.281, + 44.753, + 43.038, + 42.82, + 42.723, + 42.601, + 42.538, + 42.526, + 42.517, + 42.495, + 42.476, + 42.471, + 42.475, + 42.475, + 42.471, + 42.476, + 42.495, + 42.517, + 42.526, + 42.538, + 42.601, + 42.723, + 42.82, + 43.038, + 44.753, + 44.736, + 43.082, + 42.828, + 42.692, + 42.561, + 42.506, + 42.505, + 42.499, + 42.475, + 42.454, + 42.449, + 42.453, + 42.453, + 42.449, + 42.454, + 42.475, + 42.499, + 42.505, + 42.506, + 42.561, + 42.692, + 42.828, + 43.082, + 44.736, + 44.753, + 43.038, + 42.82, + 42.723, + 42.601, + 42.538, + 42.526, + 42.517, + 42.495, + 42.476, + 42.471, + 42.475, + 42.475, + 42.471, + 42.476, + 42.495, + 42.517, + 42.526, + 42.538, + 42.601, + 42.723, + 42.82, + 43.038, + 44.753, + 45.281, + 43.171, + 42.914, + 42.863, + 42.755, + 42.67, + 42.636, + 42.62, + 42.603, + 42.588, + 42.581, + 42.58, + 42.58, + 42.581, + 42.588, + 42.603, + 42.62, + 42.636, + 42.67, + 42.755, + 42.863, + 42.914, + 43.171, + 45.281, + 48.529, + 45.261, + 44.52, + 44.323, + 44.162, + 44.037, + 43.973, + 43.944, + 43.922, + 43.9, + 43.881, + 43.871, + 43.871, + 43.881, + 43.9, + 43.922, + 43.944, + 43.973, + 44.037, + 44.162, + 44.323, + 44.52, + 45.261, + 48.529, + 48.529, + ], + [ + 4.5187, + 4.4744, + 4.4501, + 4.433, + 4.4268, + 4.4222, + 4.4187, + 4.4174, + 4.4167, + 4.4159, + 4.4154, + 4.4154, + 4.4159, + 4.4167, + 4.4174, + 4.4187, + 4.4222, + 4.4268, + 4.433, + 4.4501, + 4.4744, + 4.5187, + 4.9177, + 4.5256, + 4.2875, + 4.2783, + 4.2574, + 4.2455, + 4.2434, + 4.2405, + 4.2382, + 4.2374, + 4.2368, + 4.2362, + 4.2361, + 4.2361, + 4.2362, + 4.2368, + 4.2374, + 4.2382, + 4.2405, + 4.2434, + 4.2455, + 4.2574, + 4.2783, + 4.2875, + 4.5256, + 4.4897, + 4.2819, + 4.2649, + 4.2439, + 4.2359, + 4.2338, + 4.2301, + 4.2283, + 4.2281, + 4.2273, + 4.2267, + 4.2267, + 4.2267, + 4.2267, + 4.2273, + 4.2281, + 4.2283, + 4.2301, + 4.2338, + 4.2359, + 4.2439, + 4.2649, + 4.2819, + 4.4897, + 4.4781, + 4.2679, + 4.2471, + 4.2289, + 4.2238, + 4.2214, + 4.2168, + 4.2152, + 4.2153, + 4.2144, + 4.2137, + 4.2138, + 4.2138, + 4.2137, + 4.2144, + 4.2153, + 4.2152, + 4.2168, + 4.2214, + 4.2238, + 4.2289, + 4.2471, + 4.2679, + 4.4781, + 4.4897, + 4.2819, + 4.2649, + 4.2439, + 4.2359, + 4.2338, + 4.2301, + 4.2283, + 4.2281, + 4.2273, + 4.2267, + 4.2267, + 4.2267, + 4.2267, + 4.2273, + 4.2281, + 4.2283, + 4.2301, + 4.2338, + 4.2359, + 4.2439, + 4.2649, + 4.2819, + 4.4897, + 4.5256, + 4.2875, + 4.2783, + 4.2574, + 4.2455, + 4.2434, + 4.2405, + 4.2382, + 4.2374, + 4.2368, + 4.2362, + 4.2361, + 4.2361, + 4.2362, + 4.2368, + 4.2374, + 4.2382, + 4.2405, + 4.2434, + 4.2455, + 4.2574, + 4.2783, + 4.2875, + 4.5256, + 4.9177, + 4.5187, + 4.4744, + 4.4501, + 4.433, + 4.4268, + 4.4222, + 4.4187, + 4.4174, + 4.4167, + 4.4159, + 4.4154, + 4.4154, + 4.4159, + 4.4167, + 4.4174, + 4.4187, + 4.4222, + 4.4268, + 4.433, + 4.4501, + 4.4744, + 4.5187, + 4.9177, + 4.9177, + ], + [ + 1.3486, + 1.3332, + 1.3187, + 1.3103, + 1.3068, + 1.3037, + 1.3015, + 1.3006, + 1.2999, + 1.2991, + 1.2986, + 1.2986, + 1.2991, + 1.2999, + 1.3006, + 1.3015, + 1.3037, + 1.3068, + 1.3103, + 1.3187, + 1.3332, + 1.3486, + 1.506, + 1.3566, + 1.2758, + 1.2693, + 1.2573, + 1.2531, + 1.2522, + 1.2503, + 1.2489, + 1.2484, + 1.2479, + 1.2475, + 1.2474, + 1.2474, + 1.2475, + 1.2479, + 1.2484, + 1.2489, + 1.2503, + 1.2522, + 1.2531, + 1.2573, + 1.2693, + 1.2758, + 1.3566, + 1.345, + 1.2709, + 1.2607, + 1.2507, + 1.248, + 1.2463, + 1.2442, + 1.2434, + 1.2432, + 1.2425, + 1.242, + 1.242, + 1.242, + 1.242, + 1.2425, + 1.2432, + 1.2434, + 1.2442, + 1.2463, + 1.248, + 1.2507, + 1.2607, + 1.2709, + 1.345, + 1.335, + 1.2613, + 1.2528, + 1.2457, + 1.2438, + 1.2416, + 1.239, + 1.2384, + 1.2384, + 1.2377, + 1.2372, + 1.2373, + 1.2373, + 1.2372, + 1.2377, + 1.2384, + 1.2384, + 1.239, + 1.2416, + 1.2438, + 1.2457, + 1.2528, + 1.2613, + 1.335, + 1.345, + 1.2709, + 1.2607, + 1.2507, + 1.248, + 1.2463, + 1.2442, + 1.2434, + 1.2432, + 1.2425, + 1.242, + 1.242, + 1.242, + 1.242, + 1.2425, + 1.2432, + 1.2434, + 1.2442, + 1.2463, + 1.248, + 1.2507, + 1.2607, + 1.2709, + 1.345, + 1.3566, + 1.2758, + 1.2693, + 1.2573, + 1.2531, + 1.2522, + 1.2503, + 1.2489, + 1.2484, + 1.2479, + 1.2475, + 1.2474, + 1.2474, + 1.2475, + 1.2479, + 1.2484, + 1.2489, + 1.2503, + 1.2522, + 1.2531, + 1.2573, + 1.2693, + 1.2758, + 1.3566, + 1.506, + 1.3486, + 1.3332, + 1.3187, + 1.3103, + 1.3068, + 1.3037, + 1.3015, + 1.3006, + 1.2999, + 1.2991, + 1.2986, + 1.2986, + 1.2991, + 1.2999, + 1.3006, + 1.3015, + 1.3037, + 1.3068, + 1.3103, + 1.3187, + 1.3332, + 1.3486, + 1.506, + 1.506, + ], + ]) def load_xfull(): - return nd.array( - [ - [0.0, 1.0], - [0.0, 2.0], - [0.0, 3.0], - [0.0, 4.0], - [0.0, 5.0], - [0.0, 6.0], - [0.0, 7.0], - [0.0, 8.0], - [0.0, 9.0], - [0.0, 10.0], - [0.0, 11.0], - [0.0, 12.0], - [0.0, 13.0], - [0.0, 14.0], - [0.0, 15.0], - [0.0, 16.0], - [0.0, 17.0], - [0.0, 18.0], - [0.0, 19.0], - [0.0, 20.0], - [0.0, 21.0], - [0.0, 22.0], - [0.0, 23.0], - [1.0, 0.0], - [1.0, 1.0], - [1.0, 2.0], - [1.0, 3.0], - [1.0, 4.0], - [1.0, 5.0], - [1.0, 6.0], - [1.0, 7.0], - [1.0, 8.0], - [1.0, 9.0], - [1.0, 10.0], - [1.0, 11.0], - [1.0, 12.0], - [1.0, 13.0], - [1.0, 14.0], - [1.0, 15.0], - [1.0, 16.0], - [1.0, 17.0], - [1.0, 18.0], - [1.0, 19.0], - [1.0, 20.0], - [1.0, 21.0], - [1.0, 22.0], - [1.0, 23.0], - [2.0, 0.0], - [2.0, 1.0], - [2.0, 2.0], - [2.0, 3.0], - [2.0, 4.0], - [2.0, 5.0], - [2.0, 6.0], - [2.0, 7.0], - [2.0, 8.0], - [2.0, 9.0], - [2.0, 10.0], - [2.0, 11.0], - [2.0, 12.0], - [2.0, 13.0], - [2.0, 14.0], - [2.0, 15.0], - [2.0, 16.0], - [2.0, 17.0], - [2.0, 18.0], - [2.0, 19.0], - [2.0, 20.0], - [2.0, 21.0], - [2.0, 22.0], - [2.0, 23.0], - [3.0, 0.0], - [3.0, 1.0], - [3.0, 2.0], - [3.0, 3.0], - [3.0, 4.0], - [3.0, 5.0], - [3.0, 6.0], - [3.0, 7.0], - [3.0, 8.0], - [3.0, 9.0], - [3.0, 10.0], - [3.0, 11.0], - [3.0, 12.0], - [3.0, 13.0], - [3.0, 14.0], - [3.0, 15.0], - [3.0, 16.0], - [3.0, 17.0], - [3.0, 18.0], - [3.0, 19.0], - [3.0, 20.0], - [3.0, 21.0], - [3.0, 22.0], - [3.0, 23.0], - [4.0, 0.0], - [4.0, 1.0], - [4.0, 2.0], - [4.0, 3.0], - [4.0, 4.0], - [4.0, 5.0], - [4.0, 6.0], - [4.0, 7.0], - [4.0, 8.0], - [4.0, 9.0], - [4.0, 10.0], - [4.0, 11.0], - [4.0, 12.0], - [4.0, 13.0], - [4.0, 14.0], - [4.0, 15.0], - [4.0, 16.0], - [4.0, 17.0], - [4.0, 18.0], - [4.0, 19.0], - [4.0, 20.0], - [4.0, 21.0], - [4.0, 22.0], - [4.0, 23.0], - [5.0, 0.0], - [5.0, 1.0], - [5.0, 2.0], - [5.0, 3.0], - [5.0, 4.0], - [5.0, 5.0], - [5.0, 6.0], - [5.0, 7.0], - [5.0, 8.0], - [5.0, 9.0], - [5.0, 10.0], - [5.0, 11.0], - [5.0, 12.0], - [5.0, 13.0], - [5.0, 14.0], - [5.0, 15.0], - [5.0, 16.0], - [5.0, 17.0], - [5.0, 18.0], - [5.0, 19.0], - [5.0, 20.0], - [5.0, 21.0], - [5.0, 22.0], - [5.0, 23.0], - [6.0, 0.0], - [6.0, 1.0], - [6.0, 2.0], - [6.0, 3.0], - [6.0, 4.0], - [6.0, 5.0], - [6.0, 6.0], - [6.0, 7.0], - [6.0, 8.0], - [6.0, 9.0], - [6.0, 10.0], - [6.0, 11.0], - [6.0, 12.0], - [6.0, 13.0], - [6.0, 14.0], - [6.0, 15.0], - [6.0, 16.0], - [6.0, 17.0], - [6.0, 18.0], - [6.0, 19.0], - [6.0, 20.0], - [6.0, 21.0], - [6.0, 22.0], - [6.0, 23.0], - [0.0, 0.0], - [0.0, 1.0], - [0.0, 2.0], - [0.0, 3.0], - [0.0, 4.0], - [0.0, 5.0], - [0.0, 6.0], - [0.0, 7.0], - [0.0, 8.0], - [0.0, 9.0], - [0.0, 10.0], - [0.0, 11.0], - [0.0, 12.0], - [0.0, 13.0], - [0.0, 14.0], - [0.0, 15.0], - [0.0, 16.0], - [0.0, 17.0], - [0.0, 18.0], - [0.0, 19.0], - [0.0, 20.0], - [0.0, 21.0], - [0.0, 22.0], - [0.0, 23.0], - [1.0, 0.0], - [1.0, 1.0], - [1.0, 2.0], - [1.0, 3.0], - [1.0, 4.0], - [1.0, 5.0], - [1.0, 6.0], - [1.0, 7.0], - [1.0, 8.0], - [1.0, 9.0], - [1.0, 10.0], - [1.0, 11.0], - [1.0, 12.0], - [1.0, 13.0], - [1.0, 14.0], - [1.0, 15.0], - [1.0, 16.0], - [1.0, 17.0], - [1.0, 18.0], - [1.0, 19.0], - [1.0, 20.0], - [1.0, 21.0], - [1.0, 22.0], - [1.0, 23.0], - [2.0, 0.0], - [2.0, 1.0], - [2.0, 2.0], - [2.0, 3.0], - [2.0, 4.0], - [2.0, 5.0], - [2.0, 6.0], - [2.0, 7.0], - [2.0, 8.0], - [2.0, 9.0], - [2.0, 10.0], - [2.0, 11.0], - [2.0, 12.0], - [2.0, 13.0], - [2.0, 14.0], - [2.0, 15.0], - [2.0, 16.0], - [2.0, 17.0], - [2.0, 18.0], - [2.0, 19.0], - [2.0, 20.0], - [2.0, 21.0], - [2.0, 22.0], - [2.0, 23.0], - [3.0, 0.0], - [3.0, 1.0], - [3.0, 2.0], - [3.0, 3.0], - [3.0, 4.0], - [3.0, 5.0], - [3.0, 6.0], - [3.0, 7.0], - [3.0, 8.0], - [3.0, 9.0], - [3.0, 10.0], - [3.0, 11.0], - [3.0, 12.0], - [3.0, 13.0], - [3.0, 14.0], - [3.0, 15.0], - [3.0, 16.0], - [3.0, 17.0], - [3.0, 18.0], - [3.0, 19.0], - [3.0, 20.0], - [3.0, 21.0], - [3.0, 22.0], - [3.0, 23.0], - [4.0, 0.0], - [4.0, 1.0], - [4.0, 2.0], - [4.0, 3.0], - [4.0, 4.0], - [4.0, 5.0], - [4.0, 6.0], - [4.0, 7.0], - [4.0, 8.0], - [4.0, 9.0], - [4.0, 10.0], - [4.0, 11.0], - [4.0, 12.0], - [4.0, 13.0], - [4.0, 14.0], - [4.0, 15.0], - [4.0, 16.0], - [4.0, 17.0], - [4.0, 18.0], - [4.0, 19.0], - [4.0, 20.0], - [4.0, 21.0], - [4.0, 22.0], - [4.0, 23.0], - [5.0, 0.0], - [5.0, 1.0], - [5.0, 2.0], - [5.0, 3.0], - [5.0, 4.0], - [5.0, 5.0], - [5.0, 6.0], - [5.0, 7.0], - [5.0, 8.0], - [5.0, 9.0], - [5.0, 10.0], - [5.0, 11.0], - [5.0, 12.0], - [5.0, 13.0], - [5.0, 14.0], - [5.0, 15.0], - [5.0, 16.0], - [5.0, 17.0], - [5.0, 18.0], - [5.0, 19.0], - [5.0, 20.0], - [5.0, 21.0], - [5.0, 22.0], - [5.0, 23.0], - [6.0, 0.0], - [6.0, 1.0], - [6.0, 2.0], - [6.0, 3.0], - [6.0, 4.0], - [6.0, 5.0], - [6.0, 6.0], - [6.0, 7.0], - [6.0, 8.0], - [6.0, 9.0], - [6.0, 10.0], - [6.0, 11.0], - [6.0, 12.0], - [6.0, 13.0], - [6.0, 14.0], - [6.0, 15.0], - [6.0, 16.0], - [6.0, 17.0], - [6.0, 18.0], - [6.0, 19.0], - [6.0, 20.0], - [6.0, 21.0], - [6.0, 22.0], - [6.0, 23.0], - [0.0, 0.0], - ] - ).expand_dims(axis=0) + return nd.array([ + [0.0, 1.0], + [0.0, 2.0], + [0.0, 3.0], + [0.0, 4.0], + [0.0, 5.0], + [0.0, 6.0], + [0.0, 7.0], + [0.0, 8.0], + [0.0, 9.0], + [0.0, 10.0], + [0.0, 11.0], + [0.0, 12.0], + [0.0, 13.0], + [0.0, 14.0], + [0.0, 15.0], + [0.0, 16.0], + [0.0, 17.0], + [0.0, 18.0], + [0.0, 19.0], + [0.0, 20.0], + [0.0, 21.0], + [0.0, 22.0], + [0.0, 23.0], + [1.0, 0.0], + [1.0, 1.0], + [1.0, 2.0], + [1.0, 3.0], + [1.0, 4.0], + [1.0, 5.0], + [1.0, 6.0], + [1.0, 7.0], + [1.0, 8.0], + [1.0, 9.0], + [1.0, 10.0], + [1.0, 11.0], + [1.0, 12.0], + [1.0, 13.0], + [1.0, 14.0], + [1.0, 15.0], + [1.0, 16.0], + [1.0, 17.0], + [1.0, 18.0], + [1.0, 19.0], + [1.0, 20.0], + [1.0, 21.0], + [1.0, 22.0], + [1.0, 23.0], + [2.0, 0.0], + [2.0, 1.0], + [2.0, 2.0], + [2.0, 3.0], + [2.0, 4.0], + [2.0, 5.0], + [2.0, 6.0], + [2.0, 7.0], + [2.0, 8.0], + [2.0, 9.0], + [2.0, 10.0], + [2.0, 11.0], + [2.0, 12.0], + [2.0, 13.0], + [2.0, 14.0], + [2.0, 15.0], + [2.0, 16.0], + [2.0, 17.0], + [2.0, 18.0], + [2.0, 19.0], + [2.0, 20.0], + [2.0, 21.0], + [2.0, 22.0], + [2.0, 23.0], + [3.0, 0.0], + [3.0, 1.0], + [3.0, 2.0], + [3.0, 3.0], + [3.0, 4.0], + [3.0, 5.0], + [3.0, 6.0], + [3.0, 7.0], + [3.0, 8.0], + [3.0, 9.0], + [3.0, 10.0], + [3.0, 11.0], + [3.0, 12.0], + [3.0, 13.0], + [3.0, 14.0], + [3.0, 15.0], + [3.0, 16.0], + [3.0, 17.0], + [3.0, 18.0], + [3.0, 19.0], + [3.0, 20.0], + [3.0, 21.0], + [3.0, 22.0], + [3.0, 23.0], + [4.0, 0.0], + [4.0, 1.0], + [4.0, 2.0], + [4.0, 3.0], + [4.0, 4.0], + [4.0, 5.0], + [4.0, 6.0], + [4.0, 7.0], + [4.0, 8.0], + [4.0, 9.0], + [4.0, 10.0], + [4.0, 11.0], + [4.0, 12.0], + [4.0, 13.0], + [4.0, 14.0], + [4.0, 15.0], + [4.0, 16.0], + [4.0, 17.0], + [4.0, 18.0], + [4.0, 19.0], + [4.0, 20.0], + [4.0, 21.0], + [4.0, 22.0], + [4.0, 23.0], + [5.0, 0.0], + [5.0, 1.0], + [5.0, 2.0], + [5.0, 3.0], + [5.0, 4.0], + [5.0, 5.0], + [5.0, 6.0], + [5.0, 7.0], + [5.0, 8.0], + [5.0, 9.0], + [5.0, 10.0], + [5.0, 11.0], + [5.0, 12.0], + [5.0, 13.0], + [5.0, 14.0], + [5.0, 15.0], + [5.0, 16.0], + [5.0, 17.0], + [5.0, 18.0], + [5.0, 19.0], + [5.0, 20.0], + [5.0, 21.0], + [5.0, 22.0], + [5.0, 23.0], + [6.0, 0.0], + [6.0, 1.0], + [6.0, 2.0], + [6.0, 3.0], + [6.0, 4.0], + [6.0, 5.0], + [6.0, 6.0], + [6.0, 7.0], + [6.0, 8.0], + [6.0, 9.0], + [6.0, 10.0], + [6.0, 11.0], + [6.0, 12.0], + [6.0, 13.0], + [6.0, 14.0], + [6.0, 15.0], + [6.0, 16.0], + [6.0, 17.0], + [6.0, 18.0], + [6.0, 19.0], + [6.0, 20.0], + [6.0, 21.0], + [6.0, 22.0], + [6.0, 23.0], + [0.0, 0.0], + [0.0, 1.0], + [0.0, 2.0], + [0.0, 3.0], + [0.0, 4.0], + [0.0, 5.0], + [0.0, 6.0], + [0.0, 7.0], + [0.0, 8.0], + [0.0, 9.0], + [0.0, 10.0], + [0.0, 11.0], + [0.0, 12.0], + [0.0, 13.0], + [0.0, 14.0], + [0.0, 15.0], + [0.0, 16.0], + [0.0, 17.0], + [0.0, 18.0], + [0.0, 19.0], + [0.0, 20.0], + [0.0, 21.0], + [0.0, 22.0], + [0.0, 23.0], + [1.0, 0.0], + [1.0, 1.0], + [1.0, 2.0], + [1.0, 3.0], + [1.0, 4.0], + [1.0, 5.0], + [1.0, 6.0], + [1.0, 7.0], + [1.0, 8.0], + [1.0, 9.0], + [1.0, 10.0], + [1.0, 11.0], + [1.0, 12.0], + [1.0, 13.0], + [1.0, 14.0], + [1.0, 15.0], + [1.0, 16.0], + [1.0, 17.0], + [1.0, 18.0], + [1.0, 19.0], + [1.0, 20.0], + [1.0, 21.0], + [1.0, 22.0], + [1.0, 23.0], + [2.0, 0.0], + [2.0, 1.0], + [2.0, 2.0], + [2.0, 3.0], + [2.0, 4.0], + [2.0, 5.0], + [2.0, 6.0], + [2.0, 7.0], + [2.0, 8.0], + [2.0, 9.0], + [2.0, 10.0], + [2.0, 11.0], + [2.0, 12.0], + [2.0, 13.0], + [2.0, 14.0], + [2.0, 15.0], + [2.0, 16.0], + [2.0, 17.0], + [2.0, 18.0], + [2.0, 19.0], + [2.0, 20.0], + [2.0, 21.0], + [2.0, 22.0], + [2.0, 23.0], + [3.0, 0.0], + [3.0, 1.0], + [3.0, 2.0], + [3.0, 3.0], + [3.0, 4.0], + [3.0, 5.0], + [3.0, 6.0], + [3.0, 7.0], + [3.0, 8.0], + [3.0, 9.0], + [3.0, 10.0], + [3.0, 11.0], + [3.0, 12.0], + [3.0, 13.0], + [3.0, 14.0], + [3.0, 15.0], + [3.0, 16.0], + [3.0, 17.0], + [3.0, 18.0], + [3.0, 19.0], + [3.0, 20.0], + [3.0, 21.0], + [3.0, 22.0], + [3.0, 23.0], + [4.0, 0.0], + [4.0, 1.0], + [4.0, 2.0], + [4.0, 3.0], + [4.0, 4.0], + [4.0, 5.0], + [4.0, 6.0], + [4.0, 7.0], + [4.0, 8.0], + [4.0, 9.0], + [4.0, 10.0], + [4.0, 11.0], + [4.0, 12.0], + [4.0, 13.0], + [4.0, 14.0], + [4.0, 15.0], + [4.0, 16.0], + [4.0, 17.0], + [4.0, 18.0], + [4.0, 19.0], + [4.0, 20.0], + [4.0, 21.0], + [4.0, 22.0], + [4.0, 23.0], + [5.0, 0.0], + [5.0, 1.0], + [5.0, 2.0], + [5.0, 3.0], + [5.0, 4.0], + [5.0, 5.0], + [5.0, 6.0], + [5.0, 7.0], + [5.0, 8.0], + [5.0, 9.0], + [5.0, 10.0], + [5.0, 11.0], + [5.0, 12.0], + [5.0, 13.0], + [5.0, 14.0], + [5.0, 15.0], + [5.0, 16.0], + [5.0, 17.0], + [5.0, 18.0], + [5.0, 19.0], + [5.0, 20.0], + [5.0, 21.0], + [5.0, 22.0], + [5.0, 23.0], + [6.0, 0.0], + [6.0, 1.0], + [6.0, 2.0], + [6.0, 3.0], + [6.0, 4.0], + [6.0, 5.0], + [6.0, 6.0], + [6.0, 7.0], + [6.0, 8.0], + [6.0, 9.0], + [6.0, 10.0], + [6.0, 11.0], + [6.0, 12.0], + [6.0, 13.0], + [6.0, 14.0], + [6.0, 15.0], + [6.0, 16.0], + [6.0, 17.0], + [6.0, 18.0], + [6.0, 19.0], + [6.0, 20.0], + [6.0, 21.0], + [6.0, 22.0], + [6.0, 23.0], + [0.0, 0.0], + ]).expand_dims(axis=0) def load_ytrain(): - return nd.array( + return nd.array([ + [ + 3.392694091796875000e02, + 3.198630065917968750e02, + 3.210045776367187500e02, + 3.175798950195312500e02, + 3.130137023925781250e02, + 4.615182800292968750e02, + 5.757534179687500000e02, + 5.871689453125000000e02, + 6.362899780273437500e02, + 6.785616455078125000e02, + 7.677511596679687500e02, + 9.501369628906250000e02, + 1.041723754882812500e03, + 1.016529663085937500e03, + 9.112100219726562500e02, + 1.019965759277343750e03, + 1.019977172851562500e03, + 1.055456665039062500e03, + 1.033710083007812500e03, + 7.803538818359375000e02, + 7.322830810546875000e02, + 4.934931640625000000e02, + 3.872488708496093750e02, + 3.655593566894531250e02, + 3.747031860351562500e02, + 3.518378906250000000e02, + 3.621347045898437500e02, + 3.427054748535156250e02, + 3.461187133789062500e02, + 4.992009277343750000e02, + 6.488584594726562500e02, + 6.145890502929687500e02, + 6.888356323242187500e02, + 7.624086914062500000e02, + 7.517237548828125000e02, + 9.593036499023437500e02, + 9.810502319335937500e02, + 1.038287719726562500e03, + 9.249543457031250000e02, + 1.033710083007812500e03, + 1.079497680664062500e03, + 1.050879028320312500e03, + 9.673059082031250000e02, + 7.490639038085937500e02, + 7.319063720703125000e02, + 4.854908752441406250e02, + 3.826940612792968750e02, + 3.644178161621093750e02, + 3.667009277343750000e02, + 3.506963500976562500e02, + 3.438470458984375000e02, + 3.267123413085937500e02, + 3.347031860351562500e02, + 4.923515930175781250e02, + 6.385730590820312500e02, + 6.180136718750000000e02, + 6.766895141601562500e02, + 6.880822143554687500e02, + 7.643150634765625000e02, + 1.010810485839843750e03, + 1.053173461914062500e03, + 1.042876708984375000e03, + 9.570091552734375000e02, + 1.054303710937500000e03, + 1.118424682617187500e03, + 1.089805908203125000e03, + 9.661757812500000000e02, + 7.852968139648437500e02, + 7.608903808593750000e02, + 4.912100524902343750e02, + 3.803995361328125000e02, + 3.586986389160156250e02, + 3.667009277343750000e02, + 3.461187133789062500e02, + 3.609817199707031250e02, + 3.449771728515625000e02, + 3.404109497070312500e02, + 4.843721313476562500e02, + 6.560958862304687500e02, + 6.214497680664062500e02, + 6.937899780273437500e02, + 7.357191772460937500e02, + 7.421917724609375000e02, + 9.673173217773437500e02, + 9.375456542968750000e02, + 8.986301269531250000e02, + 9.787671508789062500e02, + 1.026849365234375000e03, + 9.684703369140625000e02, + 9.066552734375000000e02, + 8.253767089843750000e02, + 7.277055053710937500e02, + 7.002625732421875000e02, + 4.809360656738281250e02, + 3.712785339355468750e02, + 3.518378906250000000e02, + 3.552625427246093750e02, + 3.426940612792968750e02, + 3.438470458984375000e02, + 3.324201049804687500e02, + 3.278538818359375000e02, + 4.923401794433593750e02, + 6.351483764648437500e02, + 6.054566040039062500e02, + 6.796917724609375000e02, + 6.983789672851562500e02, + 7.273287963867187500e02, + 9.192237548828125000e02, + 9.512785644531250000e02, + 9.547260131835937500e02, + 9.169406127929687500e02, + 1.023401855468750000e03, + 1.011940612792968750e03, + 1.018824218750000000e03, + 9.146575317382812500e02, + 8.070548095703125000e02, + 7.444863281250000000e02, + 5.072031860351562500e02, + 3.998287658691406250e02, + 3.575570678710937500e02, + 3.598401794433593750e02, + 3.438356018066406250e02, + 3.472716979980468750e02, + 3.381278686523437500e02, + 3.426940612792968750e02, + 6.206963500976562500e02, + 5.837442626953125000e02, + 6.111643676757812500e02, + 7.471689453125000000e02, + 7.593721313476562500e02, + 8.070548095703125000e02, + 9.604451904296875000e02, + 1.025707763671875000e03, + 9.741781005859375000e02, + 9.364155273437500000e02, + 1.019965759277343750e03, + 1.085216918945312500e03, + 9.970662231445312500e02, + 8.929109497070312500e02, + 7.597374267578125000e02, + 7.505822143554687500e02, + 5.083447570800781250e02, + 3.986872253417968750e02, + 3.689954223632812500e02, + 3.769862976074218750e02, + 3.506849365234375000e02, + 3.552739868164062500e02, + 3.404223632812500000e02, + 3.312785339355468750e02, + 5.711757812500000000e02, + 5.403310546875000000e02, + 6.317237548828125000e02, + 7.059703369140625000e02, + 7.517579956054687500e02, + 8.024657592773437500e02, + 1.046301391601562500e03, + 1.109269409179687500e03, + 1.121860717773437500e03, + 1.015388122558593750e03, + 1.094383544921875000e03, + 1.151620971679687500e03, + 1.062317382812500000e03, + 1.077191772460937500e03, + 8.036187133789062500e02, + 7.216096191406250000e02, + 5.003424682617187500e02, + 3.964041137695312500e02, + 3.735616455078125000e02, + ], + [ + 2.263374519348144531e01, + 2.932098770141601562e01, + 2.932098770141601562e01, + 2.949245452880859375e01, + 2.880658531188964844e01, + 2.657750320434570312e01, + 2.263374519348144531e01, + 1.886145401000976562e01, + 2.434842300415039062e01, + 3.137860107421875000e01, + 2.400548744201660156e01, + 2.897805213928222656e01, + 3.377914810180664062e01, + 3.360768127441406250e01, + 3.240740585327148438e01, + 3.275034332275390625e01, + 2.880658531188964844e01, + 2.743484306335449219e01, + 3.755144119262695312e01, + 3.635116577148437500e01, + 3.703703689575195312e01, + 4.098079681396484375e01, + 3.275034332275390625e01, + 2.897805213928222656e01, + 3.086419677734375000e01, + 2.726337432861328125e01, + 1.920438957214355469e01, + 2.777777862548828125e01, + 2.846364974975585938e01, + 2.331961631774902344e01, + 1.406035709381103516e01, + 2.640603637695312500e01, + 2.537722969055175781e01, + 2.572016525268554688e01, + 3.240740585327148438e01, + 2.743484306335449219e01, + 2.897805213928222656e01, + 2.726337432861328125e01, + 2.006172752380371094e01, + 3.069272994995117188e01, + 3.000685882568359375e01, + 2.503429412841796875e01, + 3.000685882568359375e01, + 3.858024597167968750e01, + 3.789437484741210938e01, + 3.240740585327148438e01, + 3.480795669555664062e01, + 2.589163208007812500e01, + 2.263374519348144531e01, + 1.954732513427734375e01, + 1.920438957214355469e01, + 1.817558288574218750e01, + 2.006172752380371094e01, + 1.543209838867187500e01, + 1.406035709381103516e01, + 2.280521202087402344e01, + 3.275034332275390625e01, + 3.412208557128906250e01, + 3.360768127441406250e01, + 3.446501922607421875e01, + 3.515089035034179688e01, + 3.395061874389648438e01, + 2.846364974975585938e01, + 2.949245452880859375e01, + 3.017832565307617188e01, + 2.777777862548828125e01, + 3.446501922607421875e01, + 4.526749038696289062e01, + 3.943758392333984375e01, + 4.818244171142578125e01, + 3.737997436523437500e01, + 3.480795669555664062e01, + 2.400548744201660156e01, + 2.623456764221191406e01, + 2.760630989074707031e01, + 2.846364974975585938e01, + 2.932098770141601562e01, + 1.851851844787597656e01, + 2.383401870727539062e01, + 2.914951896667480469e01, + 2.897805213928222656e01, + 2.709190750122070312e01, + 3.343621444702148438e01, + 3.292181015014648438e01, + 3.772290802001953125e01, + 3.858024597167968750e01, + 3.326474761962890625e01, + 2.589163208007812500e01, + 3.172153663635253906e01, + 3.377914810180664062e01, + 4.406721496582031250e01, + 4.749657058715820312e01, + 4.406721496582031250e01, + 3.446501922607421875e01, + 3.960905456542968750e01, + 3.463648986816406250e01, + 3.189300346374511719e01, + 3.034979438781738281e01, + 3.069272994995117188e01, + 2.880658531188964844e01, + 2.897805213928222656e01, + 2.143346977233886719e01, + 2.331961631774902344e01, + 2.520576095581054688e01, + 2.109053421020507812e01, + 2.777777862548828125e01, + 3.000685882568359375e01, + 3.017832565307617188e01, + 3.343621444702148438e01, + 3.223593902587890625e01, + 2.503429412841796875e01, + 3.446501922607421875e01, + 3.343621444702148438e01, + 2.812071418762207031e01, + 2.897805213928222656e01, + 3.995198822021484375e01, + 3.377914810180664062e01, + 3.515089035034179688e01, + 3.703703689575195312e01, + 2.897805213928222656e01, + 2.366255187988281250e01, + 2.812071418762207031e01, + 2.760630989074707031e01, + 2.143346977233886719e01, + 2.743484306335449219e01, + 2.589163208007812500e01, + 2.177640533447265625e01, + 2.057613182067871094e01, + 3.292181015014648438e01, + 2.486282539367675781e01, + 2.589163208007812500e01, + 4.200960159301757812e01, + 3.086419677734375000e01, + 3.412208557128906250e01, + 3.840877914428710938e01, + 3.155006790161132812e01, + 3.600823211669921875e01, + 3.669410324096679688e01, + 3.446501922607421875e01, + 3.720850372314453125e01, + 4.406721496582031250e01, + 3.326474761962890625e01, + 3.532236099243164062e01, + 3.858024597167968750e01, + 3.275034332275390625e01, + 2.194787406921386719e01, + 2.897805213928222656e01, + 2.863511657714843750e01, + 2.469135856628417969e01, + 2.331961631774902344e01, + 2.554869651794433594e01, + 2.143346977233886719e01, + 2.486282539367675781e01, + 2.743484306335449219e01, + 2.383401870727539062e01, + 3.069272994995117188e01, + 2.794924545288085938e01, + 2.812071418762207031e01, + 2.434842300415039062e01, + 1.989026069641113281e01, + 1.971879196166992188e01, + 2.194787406921386719e01, + 2.486282539367675781e01, + 3.360768127441406250e01, + 3.343621444702148438e01, + 3.206447219848632812e01, + 3.034979438781738281e01, + 2.349108314514160156e01, + ], + [ + 1.214824981689453125e02, + 1.192518844604492188e02, + 1.103294448852539062e02, + 1.139327392578125000e02, + 1.137611541748046875e02, + 1.456760406494140625e02, + 1.597460479736328125e02, + 1.443033599853515625e02, + 1.691832580566406250e02, + 1.798215484619140625e02, + 1.918325347900390625e02, + 1.990391235351562500e02, + 1.932052154541015625e02, + 2.007549743652343750e02, + 2.069320526123046875e02, + 2.084763183593750000e02, + 2.134523010253906250e02, + 2.148249816894531250e02, + 2.132807159423828125e02, + 2.143102264404296875e02, + 2.086479034423828125e02, + 1.935483856201171875e02, + 1.743308105468750000e02, + 1.717570343017578125e02, + 1.662662963867187500e02, + 1.657515411376953125e02, + 1.580301971435546875e02, + 1.626630096435546875e02, + 1.640356903076171875e02, + 1.841111907958984375e02, + 2.000686340332031250e02, + 1.853122863769531250e02, + 2.057309570312500000e02, + 2.074468078613281250e02, + 2.151681518554687500e02, + 2.095058288574218750e02, + 2.088194885253906250e02, + 2.156829071044921875e02, + 2.189430389404296875e02, + 2.215168151855468750e02, + 2.287234039306640625e02, + 2.292381591796875000e02, + 2.230610809326171875e02, + 2.177419281005859375e02, + 2.115648651123046875e02, + 1.921757049560546875e02, + 1.741592254638671875e02, + 1.535689697265625000e02, + 1.590597076416015625e02, + 1.606039733886718750e02, + 1.568291015625000000e02, + 1.571722717285156250e02, + 1.575154418945312500e02, + 1.911461944580078125e02, + 1.866849670410156250e02, + 1.885724029541015625e02, + 2.033287506103515625e02, + 2.115648651123046875e02, + 2.316403503417968750e02, + 2.280370635986328125e02, + 2.302676696777343750e02, + 2.395332946777343750e02, + 2.343857269287109375e02, + 2.278654785156250000e02, + 2.285518188476562500e02, + 2.268359680175781250e02, + 2.184282836914062500e02, + 2.112216949462890625e02, + 2.155113220214843750e02, + 2.148249816894531250e02, + 1.988675384521484375e02, + 1.842827758789062500e02, + 1.726149597167968750e02, + 1.806794738769531250e02, + 1.762182617187500000e02, + 1.774193572998046875e02, + 1.835964355468750000e02, + 1.964653472900390625e02, + 2.240905914306640625e02, + 1.781056976318359375e02, + 1.878860626220703125e02, + 1.969801025390625000e02, + 2.091626586914062500e02, + 2.098489990234375000e02, + 2.163692474365234375e02, + 2.210020599365234375e02, + 2.251201171875000000e02, + 2.242621765136718750e02, + 2.196293792724609375e02, + 2.276938934326171875e02, + 2.225463256835937500e02, + 2.210020599365234375e02, + 2.047014465332031250e02, + 1.944063110351562500e02, + 1.722717895507812500e02, + 1.523678741455078125e02, + 1.559711761474609375e02, + 1.551132507324218750e02, + 1.513383636474609375e02, + 1.544269104003906250e02, + 1.530542144775390625e02, + 1.871997222900390625e02, + 1.799931335449218750e02, + 1.717570343017578125e02, + 1.935483856201171875e02, + 1.969801025390625000e02, + 2.067604675292968750e02, + 2.228894958496093750e02, + 2.105353393554687500e02, + 2.149965667724609375e02, + 2.024708251953125000e02, + 2.409059753417968750e02, + 2.491420745849609375e02, + 2.283802337646484375e02, + 2.213452301025390625e02, + 2.198009643554687500e02, + 2.007549743652343750e02, + 1.859986267089843750e02, + 1.739876403808593750e02, + 1.583733673095703125e02, + 1.376115264892578125e02, + 1.377831115722656250e02, + 1.362388458251953125e02, + 1.353809204101562500e02, + 1.431022644042968750e02, + 1.695264282226562500e02, + 1.822237548828125000e02, + 1.703843536376953125e02, + 1.896019287109375000e02, + 1.983527832031250000e02, + 2.048730316162109375e02, + 2.052162017822265625e02, + 2.062457122802734375e02, + 2.081331481933593750e02, + 2.047014465332031250e02, + 2.149965667724609375e02, + 2.047014465332031250e02, + 2.055593719482421875e02, + 2.227179107666015625e02, + 2.113932800292968750e02, + 1.998970489501953125e02, + 1.817089843750000000e02, + 1.657515411376953125e02, + 1.499656829833984375e02, + 1.364104309082031250e02, + 1.312628631591796875e02, + 1.283459167480468750e02, + 1.250857925415039062e02, + 1.264584732055664062e02, + 1.654083709716796875e02, + 1.798215484619140625e02, + 1.643788604736328125e02, + 1.904598541259765625e02, + 1.921757049560546875e02, + 2.069320526123046875e02, + 2.215168151855468750e02, + 2.185998687744140625e02, + 2.246053466796875000e02, + 2.259780426025390625e02, + 2.304392547607421875e02, + 2.283802337646484375e02, + 2.264927978515625000e02, + 2.222031555175781250e02, + 2.115648651123046875e02, + 2.010981445312500000e02, + 1.944063110351562500e02, + 1.676389770507812500e02, + 1.460192108154296875e02, + ], + [ + 6.946183013916015625e01, + 6.226533126831054688e01, + 6.007509231567382812e01, + 5.913642120361328125e01, + 5.788485717773437500e01, + 6.289111328125000000e01, + 4.787234115600585938e01, + 5.287859725952148438e01, + 5.287859725952148438e01, + 6.038798522949218750e01, + 6.570713043212890625e01, + 7.196495819091796875e01, + 7.352941131591796875e01, + 6.007509231567382812e01, + 5.851063919067382812e01, + 6.695870208740234375e01, + 6.289111328125000000e01, + 7.259073638916015625e01, + 9.355444335937500000e01, + 8.573216247558593750e01, + 9.793492126464843750e01, + 1.013767242431640625e02, + 1.004380493164062500e02, + 7.884856414794921875e01, + 6.727159118652343750e01, + 6.351689529418945312e01, + 6.351689529418945312e01, + 6.226533126831054688e01, + 6.289111328125000000e01, + 6.821026611328125000e01, + 4.630788421630859375e01, + 5.162703323364257812e01, + 5.037546920776367188e01, + 6.070087432861328125e01, + 6.821026611328125000e01, + 7.446808624267578125e01, + 7.415519714355468750e01, + 6.758448028564453125e01, + 6.539424133300781250e01, + 6.602002716064453125e01, + 6.226533126831054688e01, + 7.133917236328125000e01, + 8.854818725585937500e01, + 8.823529052734375000e01, + 1.041927413940429688e02, + 1.029411773681640625e02, + 9.699624633789062500e01, + 7.634542846679687500e01, + 6.758448028564453125e01, + 6.445556640625000000e01, + 6.351689529418945312e01, + 6.101376724243164062e01, + 5.976220321655273438e01, + 6.789736938476562500e01, + 5.225281524658203125e01, + 4.724655914306640625e01, + 5.131414413452148438e01, + 6.508135223388671875e01, + 7.478097534179687500e01, + 7.384230041503906250e01, + 6.163954925537109375e01, + 6.476846313476562500e01, + 6.414267730712890625e01, + 5.694618225097656250e01, + 6.508135223388671875e01, + 6.914893341064453125e01, + 7.947434234619140625e01, + 8.604505920410156250e01, + 1.001251602172851562e02, + 9.793492126464843750e01, + 9.167709350585937500e01, + 7.790988922119140625e01, + 6.789736938476562500e01, + 6.633291625976562500e01, + 6.351689529418945312e01, + 6.226533126831054688e01, + 6.195244216918945312e01, + 6.476846313476562500e01, + 5.225281524658203125e01, + 5.068836212158203125e01, + 5.413016128540039062e01, + 5.757196426391601562e01, + 7.352941131591796875e01, + 7.321652221679687500e01, + 7.196495819091796875e01, + 6.163954925537109375e01, + 5.882352828979492188e01, + 6.414267730712890625e01, + 6.476846313476562500e01, + 7.790988922119140625e01, + 9.824781036376953125e01, + 8.604505920410156250e01, + 1.026282882690429688e02, + 1.091989974975585938e02, + 1.073216552734375000e02, + 8.541927337646484375e01, + 7.415519714355468750e01, + 6.852315521240234375e01, + 6.289111328125000000e01, + 6.070087432861328125e01, + 6.101376724243164062e01, + 6.414267730712890625e01, + 5.256570816040039062e01, + 4.943679428100585938e01, + 6.070087432861328125e01, + 7.571965026855468750e01, + 8.792240142822265625e01, + 8.698372650146484375e01, + 7.822277832031250000e01, + 7.133917236328125000e01, + 7.227784729003906250e01, + 6.602002716064453125e01, + 6.883604431152343750e01, + 6.977471923828125000e01, + 7.978723144531250000e01, + 9.042552947998046875e01, + 1.057571945190429688e02, + 1.007509384155273438e02, + 9.762202453613281250e01, + 8.479349517822265625e01, + 7.165206146240234375e01, + 6.821026611328125000e01, + 6.226533126831054688e01, + 6.195244216918945312e01, + 7.478097534179687500e01, + 5.882352828979492188e01, + 4.974968719482421875e01, + 5.381727218627929688e01, + 6.007509231567382812e01, + 7.259073638916015625e01, + 7.603253936767578125e01, + 8.573216247558593750e01, + 8.698372650146484375e01, + 8.041301727294921875e01, + 7.853566741943359375e01, + 8.197747039794921875e01, + 7.603253936767578125e01, + 8.948686218261718750e01, + 8.948686218261718750e01, + 9.230287933349609375e01, + 9.230287933349609375e01, + 9.511889648437500000e01, + 9.543179321289062500e01, + 8.823529052734375000e01, + 8.166458129882812500e01, + 7.133917236328125000e01, + 6.195244216918945312e01, + 6.038798522949218750e01, + 5.976220321655273438e01, + 6.883604431152343750e01, + 4.849812316894531250e01, + 5.319149017333984375e01, + 6.382978820800781250e01, + 7.634542846679687500e01, + 9.261576843261718750e01, + 1.001251602172851562e02, + 8.573216247558593750e01, + 6.539424133300781250e01, + 5.600751113891601562e01, + 5.694618225097656250e01, + 5.538172531127929688e01, + 6.382978820800781250e01, + 7.790988922119140625e01, + 8.197747039794921875e01, + 1.029411773681640625e02, + 1.135794754028320312e02, + 1.004380493164062500e02, + 8.479349517822265625e01, + ], + [ + 4.069791030883789062e01, + 4.158940505981445312e01, + 3.891365432739257812e01, + 3.942435073852539062e01, + 4.057055664062500000e01, + 6.334054946899414062e01, + 5.866785430908203125e01, + 6.113219451904296875e01, + 7.438232421875000000e01, + 8.779418945312500000e01, + 9.533239746093750000e01, + 9.737519073486328125e01, + 9.609780883789062500e01, + 9.724783325195312500e01, + 9.839658355712890625e01, + 1.000573120117187500e02, + 1.004406509399414062e02, + 1.027394256591796875e02, + 9.839658355712890625e01, + 1.037621002197265625e02, + 9.520503997802734375e01, + 7.690779113769531250e01, + 4.783494567871093750e01, + 4.464849853515625000e01, + 4.490448379516601562e01, + 4.630667495727539062e01, + 4.362837600708007812e01, + 4.528655242919921875e01, + 4.452114105224609375e01, + 6.675624084472656250e01, + 7.067881011962890625e01, + 6.830870819091796875e01, + 7.859780883789062500e01, + 9.124427032470703125e01, + 9.673586273193359375e01, + 9.367167663574218750e01, + 9.852394104003906250e01, + 9.699057769775390625e01, + 9.647988128662109375e01, + 1.008239974975585938e02, + 1.017180328369140625e02, + 1.015906753540039062e02, + 1.017193069458007812e02, + 9.545848083496093750e01, + 9.341441345214843750e01, + 7.792536926269531250e01, + 5.165945053100585938e01, + 4.707080841064453125e01, + 4.719816589355468750e01, + 4.643275451660156250e01, + 4.579597473144531250e01, + 4.656138610839843750e01, + 4.732552337646484375e01, + 7.882450103759765625e01, + 7.118950653076171875e01, + 7.361691284179687500e01, + 8.600611114501953125e01, + 1.018466644287109375e02, + 9.929190063476562500e01, + 1.051681137084960938e02, + 1.092536926269531250e02, + 1.017180328369140625e02, + 1.055514526367187500e02, + 1.087442703247070312e02, + 1.073382568359375000e02, + 1.014620513916015625e02, + 1.061895065307617188e02, + 9.954534149169921875e01, + 9.418109893798828125e01, + 7.562786865234375000e01, + 4.923586273193359375e01, + 4.630540084838867188e01, + 4.643275451660156250e01, + 4.643275451660156250e01, + 4.477585220336914062e01, + 4.541263198852539062e01, + 4.388181304931640625e01, + 6.921293640136718750e01, + 7.013372039794921875e01, + 6.910977935791015625e01, + 8.690015411376953125e01, + 1.012073364257812500e02, + 1.023560867309570312e02, + 1.006953659057617188e02, + 9.967396545410156250e01, + 9.980132293701171875e01, + 1.022287292480468750e02, + 1.095096817016601562e02, + 1.096370315551757812e02, + 1.024847183227539062e02, + 1.031240463256835938e02, + 1.026120758056640625e02, + 9.277508544921875000e01, + 7.511716461181640625e01, + 4.872771453857421875e01, + 4.273560714721679688e01, + 4.311894989013671875e01, + 4.107997894287109375e01, + 3.980641937255859375e01, + 4.006113052368164062e01, + 3.853158569335937500e01, + 6.669383239746093750e01, + 6.282093811035156250e01, + 6.205807495117187500e01, + 8.000127410888671875e01, + 9.022160339355468750e01, + 9.839531707763671875e01, + 9.941798400878906250e01, + 1.023573608398437500e02, + 1.008239974975585938e02, + 1.035061111450195312e02, + 1.044001541137695312e02, + 1.088716278076171875e02, + 1.124477844238281250e02, + 1.075929718017578125e02, + 1.008239974975585938e02, + 9.622644042968750000e01, + 9.073229980468750000e01, + 6.257386779785156250e01, + 4.311894989013671875e01, + 4.299032211303710938e01, + 4.273815536499023438e01, + 4.095262527465820312e01, + 4.082526779174804688e01, + 4.146204757690429688e01, + 7.409832000732421875e01, + 6.358507537841796875e01, + 6.754330444335937500e01, + 8.396331787109375000e01, + 9.405374145507812500e01, + 1.013346939086914062e02, + 1.063155899047851562e02, + 1.123204269409179688e02, + 1.139798812866210938e02, + 1.114251174926757812e02, + 1.038894577026367188e02, + 1.175560379028320312e02, + 1.093823242187500000e02, + 1.097656631469726562e02, + 1.035061111450195312e02, + 9.890728759765625000e01, + 9.341568756103515625e01, + 6.959373474121093750e01, + 4.503183746337890625e01, + 4.515919494628906250e01, + 4.388308715820312500e01, + 4.069791030883789062e01, + 4.082526779174804688e01, + 4.069791030883789062e01, + 6.576541137695312500e01, + 6.329724884033203125e01, + 6.716250610351562500e01, + 8.204534149169921875e01, + 9.405374145507812500e01, + 1.003120193481445312e02, + 1.045287857055664062e02, + 1.082335739135742188e02, + 1.067002029418945312e02, + 1.054215469360351562e02, + 1.070835418701171875e02, + 1.100216522216796875e02, + 1.037608261108398438e02, + 9.865257263183593750e01, + 9.775852966308593750e01, + 9.711793518066406250e01, + 8.137290191650390625e01, + 4.974656295776367188e01, + 4.464849853515625000e01, + ], + [ + 1.140194625854492188e02, + 1.181511993408203125e02, + 1.222866744995117188e02, + 1.147717056274414062e02, + 1.102619781494140625e02, + 1.170284423828125000e02, + 1.421856231689453125e02, + 1.713136291503906250e02, + 1.703555450439453125e02, + 1.906886291503906250e02, + 2.291916198730468750e02, + 2.488061370849609375e02, + 2.491841278076171875e02, + 2.295696105957031250e02, + 2.412574920654296875e02, + 2.488061370849609375e02, + 2.601235046386718750e02, + 2.408869781494140625e02, + 2.397492523193359375e02, + 2.078667602539062500e02, + 1.358046417236328125e02, + 1.249139251708984375e02, + 1.219124221801757812e02, + 1.155202102661132812e02, + 1.162761993408203125e02, + 1.196519470214843750e02, + 1.245359268188476562e02, + 1.189034423828125000e02, + 1.162761993408203125e02, + 1.410591278076171875e02, + 1.515793457031250000e02, + 2.093974609375000000e02, + 2.220247039794921875e02, + 2.397567291259765625e02, + 2.518263397216796875e02, + 2.555950622558593750e02, + 2.435254516601562500e02, + 2.242926635742187500e02, + 1.980613708496093750e02, + 2.208907165527343750e02, + 2.337163238525390625e02, + 2.310778503417968750e02, + 1.957934112548828125e02, + 1.893824920654296875e02, + 1.403068847656250000e02, + 1.264146728515625000e02, + 1.207821884155273438e02, + 1.155239486694335938e02, + 1.166504516601562500e02, + 1.189071884155273438e02, + 1.252881698608398438e02, + 1.192814407348632812e02, + 1.158982009887695312e02, + 1.414371185302734375e02, + 1.583345794677734375e02, + 2.112799377441406250e02, + 2.442814331054687500e02, + 2.499401245117187500e02, + 2.529528503417968750e02, + 2.495583801269531250e02, + 2.552208099365234375e02, + 2.488061370849609375e02, + 2.465456542968750000e02, + 2.503181152343750000e02, + 2.552208099365234375e02, + 2.559693145751953125e02, + 2.484281463623046875e02, + 2.325860748291015625e02, + 1.423989562988281250e02, + 1.241654205322265625e02, + 1.192739486694335938e02, + 1.155202102661132812e02, + 1.158944625854492188e02, + 1.185254516601562500e02, + 1.256624221801757812e02, + 1.189034423828125000e02, + 1.155239486694335938e02, + 1.403106231689453125e02, + 1.538323364257812500e02, + 2.082634735107421875e02, + 2.152357788085937500e02, + 2.340943145751953125e02, + 2.457896728515625000e02, + 2.461676635742187500e02, + 2.337163238525390625e02, + 2.363622741699218750e02, + 2.261751556396484375e02, + 2.431474609375000000e02, + 2.427694549560546875e02, + 2.205127258300781250e02, + 1.759880218505859375e02, + 1.863398132324218750e02, + 1.354266510009765625e02, + 1.200299377441406250e02, + 1.158982009887695312e02, + 1.110142211914062500e02, + 1.098877258300781250e02, + 1.125187149047851562e02, + 1.181549377441406250e02, + 1.113884735107421875e02, + 1.065119781494140625e02, + 1.335516510009765625e02, + 1.425636291503906250e02, + 1.694386291503906250e02, + 1.718562927246093750e02, + 1.916429595947265625e02, + 2.276871185302734375e02, + 2.446556854248046875e02, + 2.371145172119140625e02, + 2.333420715332031250e02, + 2.186302337646484375e02, + 2.431474609375000000e02, + 2.593675231933593750e02, + 2.529565887451171875e02, + 2.537088317871093750e02, + 2.352245483398437500e02, + 1.369311370849609375e02, + 1.219049377441406250e02, + 1.173989486694335938e02, + 1.132709579467773438e02, + 1.125187149047851562e02, + 1.151459579467773438e02, + 1.204079360961914062e02, + 1.140194625854492188e02, + 1.102619781494140625e02, + 1.327994079589843750e02, + 1.433121185302734375e02, + 1.943188629150390625e02, + 1.980651245117187500e02, + 2.193824920654296875e02, + 2.363622741699218750e02, + 2.540868225097656250e02, + 2.544648132324218750e02, + 2.593712463378906250e02, + 2.537125701904296875e02, + 2.457896728515625000e02, + 2.578592834472656250e02, + 2.601235046386718750e02, + 2.533308410644531250e02, + 2.446594238281250000e02, + 1.469086761474609375e02, + 1.222829360961914062e02, + 1.185291900634765625e02, + 1.147717056274414062e02, + 1.136414642333984375e02, + 1.151459579467773438e02, + 1.215344314575195312e02, + 1.136414642333984375e02, + 1.102619781494140625e02, + 1.331736602783203125e02, + 1.436901245117187500e02, + 1.984767913818359375e02, + 1.948652648925781250e02, + 2.276796417236328125e02, + 2.597492370605468750e02, + 2.631399841308593750e02, + 2.593675231933593750e02, + 2.604977416992187500e02, + 2.805651245117187500e02, + 2.882223205566406250e02, + 2.730014953613281250e02, + 2.786601867675781250e02, + 2.714932556152343750e02, + 2.461676635742187500e02, + 1.472866821289062500e02, + 1.256661682128906250e02, + 1.222829360961914062e02, + 1.166504516601562500e02, + ], + [ + 1.792779235839843750e02, + 1.680313415527343750e02, + 1.694005432128906250e02, + 1.666689300537109375e02, + 1.663317413330078125e02, + 1.772343292236328125e02, + 2.585524597167968750e02, + 2.466008148193359375e02, + 2.790497131347656250e02, + 2.776839294433593750e02, + 2.957867736816406250e02, + 3.036410217285156250e02, + 2.995436096191406250e02, + 3.166212463378906250e02, + 3.152520446777343750e02, + 3.149114379882812500e02, + 3.258412780761718750e02, + 3.009093933105468750e02, + 3.087636108398437500e02, + 3.152520446777343750e02, + 3.179870605468750000e02, + 3.080824279785156250e02, + 2.722173156738281250e02, + 2.232629394531250000e02, + 2.017779235839843750e02, + 1.908719329833984375e02, + 1.860967254638671875e02, + 1.802997283935546875e02, + 1.782561340332031250e02, + 1.860933227539062500e02, + 2.691450805664062500e02, + 2.848569335937500000e02, + 3.053474121093750000e02, + 3.067132263183593750e02, + 3.442847290039062500e02, + 3.296015014648437500e02, + 3.268699035644531250e02, + 3.381403198242187500e02, + 3.203746643066406250e02, + 3.319856872558593750e02, + 3.405279235839843750e02, + 3.094482421875000000e02, + 3.073978271484375000e02, + 3.220844726562500000e02, + 3.265258789062500000e02, + 3.121764221191406250e02, + 2.705075073242187500e02, + 2.285047760009765625e02, + 2.080279235839843750e02, + 1.952997283935546875e02, + 1.932561340332031250e02, + 1.891689300537109375e02, + 1.867813415527343750e02, + 1.939373321533203125e02, + 2.598126831054687500e02, + 2.913453674316406250e02, + 3.084230346679687500e02, + 3.128610229492187500e02, + 3.299421081542968750e02, + 3.436069335937500000e02, + 3.336954956054687500e02, + 3.504325561523437500e02, + 3.330143127441406250e02, + 3.220810546875000000e02, + 3.391621398925781250e02, + 3.183310546875000000e02, + 3.050068054199218750e02, + 3.261818847656250000e02, + 3.299421081542968750e02, + 3.125204467773437500e02, + 2.691450805664062500e02, + 2.244005432128906250e02, + 2.059877319335937500e02, + 1.949625396728515625e02, + 1.925749359130859375e02, + 1.888283386230468750e02, + 1.884877319335937500e02, + 2.018937377929687500e02, + 2.729019165039062500e02, + 2.940769653320312500e02, + 3.244788818359375000e02, + 3.173024597167968750e02, + 3.381403198242187500e02, + 3.272104797363281250e02, + 3.466791687011718750e02, + 3.354053039550781250e02, + 3.289169006347656250e02, + 3.401873168945312500e02, + 3.425749206542968750e02, + 3.149114379882812500e02, + 3.039850158691406250e02, + 3.227656555175781250e02, + 3.237908630371093750e02, + 3.128610229492187500e02, + 2.708480834960937500e02, + 2.315735626220703125e02, + 1.884809265136718750e02, + 1.803031311035156250e02, + 1.813249359130859375e02, + 1.809809265136718750e02, + 1.799591217041015625e02, + 2.236273803710937500e02, + 2.558242492675781250e02, + 2.780245361328125000e02, + 3.012465820312500000e02, + 3.026158142089843750e02, + 3.309673156738281250e02, + 3.227690734863281250e02, + 3.282322998046875000e02, + 3.265258789062500000e02, + 3.357459106445312500e02, + 3.347207031250000000e02, + 3.381403198242187500e02, + 3.149148559570312500e02, + 3.162806396484375000e02, + 3.237908630371093750e02, + 3.097922363281250000e02, + 3.036410217285156250e02, + 2.612874755859375000e02, + 2.131403198242187500e02, + 1.905313415527343750e02, + 1.799625396728515625e02, + 1.782561340332031250e02, + 1.799591217041015625e02, + 1.779155273437500000e02, + 2.352384185791015625e02, + 2.548024597167968750e02, + 2.619686584472656250e02, + 2.992030029296875000e02, + 2.961273803710937500e02, + 3.138862304687500000e02, + 3.343800964355468750e02, + 3.330143127441406250e02, + 3.227690734863281250e02, + 3.152520446777343750e02, + 3.234502868652343750e02, + 3.323330993652343750e02, + 3.145708312988281250e02, + 3.002247924804687500e02, + 3.142302551269531250e02, + 3.193528747558593750e02, + 3.036376037597656250e02, + 2.551464538574218750e02, + 2.176907348632812500e02, + 1.946185302734375000e02, + 1.799625396728515625e02, + 1.802997283935546875e02, + 1.772343292236328125e02, + 1.809809265136718750e02, + 2.389918212890625000e02, + 2.565054626464843750e02, + 2.534298400878906250e02, + 2.937363891601562500e02, + 2.971525878906250000e02, + 3.091076354980468750e02, + 3.309638977050781250e02, + 3.336954956054687500e02, + 3.323297119140625000e02, + 3.251566772460937500e02, + 3.504359741210937500e02, + 3.449693603515625000e02, + 3.149114379882812500e02, + 3.104734191894531250e02, + 3.268630676269531250e02, + 3.268664855957031250e02, + 3.094482421875000000e02, + 2.691450805664062500e02, + 2.232697601318359375e02, + ], + [ + 3.947381896972656250e02, + 3.778378295898437500e02, + 3.702280273437500000e02, + 3.660050659179687500e02, + 5.198479614257812500e02, + 5.274493408203125000e02, + 5.730996704101562500e02, + 5.959121704101562500e02, + 6.664611206054687500e02, + 7.358361206054687500e02, + 7.853547363281250000e02, + 8.268750000000000000e02, + 8.395861206054687500e02, + 8.141723022460937500e02, + 7.938428955078125000e02, + 8.760134887695312500e02, + 8.827871704101562500e02, + 8.395861206054687500e02, + 7.845185546875000000e02, + 7.489273681640625000e02, + 7.396115112304687500e02, + 5.734966430664062500e02, + 4.420692443847656250e02, + 4.124830932617187500e02, + 4.260135192871093750e02, + 4.116385192871093750e02, + 4.048817443847656250e02, + 4.031925659179687500e02, + 5.739357910156250000e02, + 5.967567749023437500e02, + 6.246536865234375000e02, + 6.195861206054687500e02, + 6.803969726562500000e02, + 7.751942749023437500e02, + 8.345017089843750000e02, + 7.743496704101562500e02, + 8.057009887695312500e02, + 8.124746704101562500e02, + 8.209459228515625000e02, + 8.539949340820312500e02, + 8.624661865234375000e02, + 8.192482910156250000e02, + 7.955321044921875000e02, + 7.692736206054687500e02, + 7.235134887695312500e02, + 5.726942749023437500e02, + 4.370101318359375000e02, + 4.133446044921875000e02, + 4.243243103027343750e02, + 4.133361511230468750e02, + 4.175675659179687500e02, + 4.226351318359375000e02, + 6.288598022460937500e02, + 6.711232910156250000e02, + 6.356503295898437500e02, + 6.880321044921875000e02, + 7.989189453125000000e02, + 8.972044067382812500e02, + 9.039780273437500000e02, + 8.709290771484375000e02, + 8.870354614257812500e02, + 8.734797363281250000e02, + 8.522973022460937500e02, + 8.921199340820312500e02, + 8.904138793945312500e02, + 8.607685546875000000e02, + 8.175591430664062500e02, + 7.299240112304687500e02, + 7.065709228515625000e02, + 5.743750000000000000e02, + 4.446114807128906250e02, + 3.998057556152343750e02, + 4.184121704101562500e02, + 3.998057556152343750e02, + 4.048733215332031250e02, + 4.065709533691406250e02, + 5.190033569335937500e02, + 5.959121704101562500e02, + 6.178969726562500000e02, + 6.588767089843750000e02, + 7.070354614257812500e02, + 7.896115112304687500e02, + 7.921536865234375000e02, + 7.870607910156250000e02, + 8.260303955078125000e02, + 7.921452636718750000e02, + 8.031588134765625000e02, + 8.683868408203125000e02, + 8.311148681640625000e02, + 7.896030273437500000e02, + 7.565625000000000000e02, + 7.112838134765625000e02, + 6.664696044921875000e02, + 5.337584228515625000e02, + 3.998057556152343750e02, + 3.803716125488281250e02, + 3.854391784667968750e02, + 3.676858215332031250e02, + 3.634628295898437500e02, + 3.685473022460937500e02, + 4.919510192871093750e02, + 5.536571044921875000e02, + 5.578800659179687500e02, + 5.790033569335937500e02, + 6.474830932617187500e02, + 7.463935546875000000e02, + 7.879053955078125000e02, + 8.090878295898437500e02, + 8.073901977539062500e02, + 7.692651977539062500e02, + 8.277280273437500000e02, + 8.700844726562500000e02, + 8.794088134765625000e02, + 8.345101318359375000e02, + 8.667060546875000000e02, + 7.523226318359375000e02, + 7.353800659179687500e02, + 6.554982910156250000e02, + 5.232178955078125000e02, + 3.938935852050781250e02, + 4.116469726562500000e02, + 3.896790466308593750e02, + 3.820608215332031250e02, + 3.854391784667968750e02, + 5.274408569335937500e02, + 6.001351318359375000e02, + 6.212753295898437500e02, + 6.322634887695312500e02, + 7.045017089843750000e02, + 8.387500000000000000e02, + 8.997381591796875000e02, + 8.912669067382812500e02, + 8.988851318359375000e02, + 8.548310546875000000e02, + 8.912669067382812500e02, + 9.107601318359375000e02, + 9.132939453125000000e02, + 8.497550659179687500e02, + 9.056672363281250000e02, + 8.285726318359375000e02, + 7.514780273437500000e02, + 6.580405273437500000e02, + 5.257601318359375000e02, + 3.837500000000000000e02, + 3.981250000000000000e02, + 3.964358215332031250e02, + 3.719172363281250000e02, + 3.829053955078125000e02, + 5.206756591796875000e02, + 5.705574340820312500e02, + 5.553463134765625000e02, + 6.233530273437500000e02, + 6.994172363281250000e02, + 8.336571044921875000e02, + 9.226182250976562500e02, + 9.056672363281250000e02, + 9.276942749023437500e02, + 8.683952636718750000e02, + 8.929560546875000000e02, + 1.004788879394531250e03, + 9.675169067382812500e02, + 9.285473022460937500e02, + 8.827955932617187500e02, + 8.014611206054687500e02, + 7.599493408203125000e02, + 5.938006591796875000e02, + 4.463006896972656250e02, + 3.896790466308593750e02, + ], + [ + 4.006647109985351562e01, + 3.545051574707031250e01, + 2.289512634277343750e01, + 1.772525787353515625e01, + 1.070901012420654297e01, + 9.231905937194824219e00, + 7.016248226165771484e00, + 1.070901012420654297e01, + 1.920236396789550781e01, + 3.766617584228515625e01, + 4.523633575439453125e01, + 5.409896469116210938e01, + 6.776218414306640625e01, + 6.720827484130859375e01, + 5.816100311279296875e01, + 5.243722152709960938e01, + 5.649925994873046875e01, + 5.243722152709960938e01, + 6.037666320800781250e01, + 6.573117065429687500e01, + 6.831610107421875000e01, + 6.185376739501953125e01, + 5.391432952880859375e01, + 4.911373519897460938e01, + 4.689807891845703125e01, + 2.861890602111816406e01, + 2.437223052978515625e01, + 1.901772499084472656e01, + 1.052437210083007812e01, + 8.124076843261718750e00, + 7.385524272918701172e00, + 1.107828617095947266e01, + 2.621861076354980469e01, + 4.671343994140625000e01, + 5.539143371582031250e01, + 5.889955520629882812e01, + 7.200886535644531250e01, + 6.776218414306640625e01, + 6.093057632446289062e01, + 5.742245101928710938e01, + 5.391432952880859375e01, + 5.612998580932617188e01, + 5.723781204223632812e01, + 6.794682312011718750e01, + 6.443869781494140625e01, + 6.277695846557617188e01, + 6.517725372314453125e01, + 5.225258636474609375e01, + 5.040620422363281250e01, + 3.489660263061523438e01, + 2.381831550598144531e01, + 1.790989685058593750e01, + 1.181683921813964844e01, + 8.493352890014648438e00, + 8.493352890014648438e00, + 1.827917289733886719e01, + 3.175775527954101562e01, + 5.003692626953125000e01, + 5.649925994873046875e01, + 6.739290618896484375e01, + 7.090103149414062500e01, + 6.517725372314453125e01, + 5.889955520629882812e01, + 5.760708999633789062e01, + 5.760708999633789062e01, + 5.594534683227539062e01, + 5.889955520629882812e01, + 7.348596954345703125e01, + 7.256277465820312500e01, + 6.517725372314453125e01, + 5.982274627685546875e01, + 5.631462478637695312e01, + 4.689807891845703125e01, + 3.101920318603515625e01, + 2.511078262329101562e01, + 2.344903945922851562e01, + 1.144756317138671875e01, + 7.570162296295166016e00, + 7.200886249542236328e00, + 1.347858238220214844e01, + 2.566469764709472656e01, + 4.209748840332031250e01, + 4.745199584960937500e01, + 6.425405883789062500e01, + 7.920974731445312500e01, + 6.905464935302734375e01, + 5.539143371582031250e01, + 5.280649948120117188e01, + 5.059084320068359375e01, + 5.077547836303710938e01, + 5.631462478637695312e01, + 6.462333679199218750e01, + 7.163958740234375000e01, + 5.631462478637695312e01, + 5.668389892578125000e01, + 4.412850952148437500e01, + 4.080502319335937500e01, + 3.157311630249023438e01, + 2.234121131896972656e01, + 2.012555313110351562e01, + 1.532496261596679688e01, + 9.047266960144042969e00, + 9.970458030700683594e00, + 1.366322040557861328e01, + 2.677252578735351562e01, + 4.375923156738281250e01, + 5.483751678466796875e01, + 6.351551055908203125e01, + 7.994830322265625000e01, + 7.330133056640625000e01, + 6.240768051147460938e01, + 6.351551055908203125e01, + 5.834564208984375000e01, + 5.631462478637695312e01, + 7.200886535644531250e01, + 7.477843475341796875e01, + 7.071639251708984375e01, + 6.517725372314453125e01, + 5.409896469116210938e01, + 4.966765213012695312e01, + 4.560561370849609375e01, + 3.526588058471679688e01, + 2.307976341247558594e01, + 2.455686759948730469e01, + 1.366322040557861328e01, + 8.862628936767578125e00, + 7.016248226165771484e00, + 1.403249645233154297e01, + 2.437223052978515625e01, + 4.394387054443359375e01, + 5.132939529418945312e01, + 6.277695846557617188e01, + 7.293205261230468750e01, + 6.591580200195312500e01, + 6.351551055908203125e01, + 6.296159362792968750e01, + 6.333087158203125000e01, + 6.628507995605468750e01, + 7.182422637939453125e01, + 7.754800415039062500e01, + 7.644017791748046875e01, + 6.259231948852539062e01, + 5.280649948120117188e01, + 5.409896469116210938e01, + 4.911373519897460938e01, + 3.120384025573730469e01, + 2.474150657653808594e01, + 2.049483108520507812e01, + 1.310930538177490234e01, + 9.601181983947753906e00, + 8.493352890014648438e00, + 1.643279266357421875e01, + 5.096011734008789062e01, + 4.837518310546875000e01, + 4.966765213012695312e01, + 6.333087158203125000e01, + 7.440915679931640625e01, + 7.274741363525390625e01, + 6.683899688720703125e01, + 6.517725372314453125e01, + 6.831610107421875000e01, + 6.702363586425781250e01, + 7.607089996337890625e01, + 7.662481689453125000e01, + 7.828656005859375000e01, + 6.499261474609375000e01, + 6.093057632446289062e01, + 5.040620422363281250e01, + ], [ - [ - 3.392694091796875000e02, - 3.198630065917968750e02, - 3.210045776367187500e02, - 3.175798950195312500e02, - 3.130137023925781250e02, - 4.615182800292968750e02, - 5.757534179687500000e02, - 5.871689453125000000e02, - 6.362899780273437500e02, - 6.785616455078125000e02, - 7.677511596679687500e02, - 9.501369628906250000e02, - 1.041723754882812500e03, - 1.016529663085937500e03, - 9.112100219726562500e02, - 1.019965759277343750e03, - 1.019977172851562500e03, - 1.055456665039062500e03, - 1.033710083007812500e03, - 7.803538818359375000e02, - 7.322830810546875000e02, - 4.934931640625000000e02, - 3.872488708496093750e02, - 3.655593566894531250e02, - 3.747031860351562500e02, - 3.518378906250000000e02, - 3.621347045898437500e02, - 3.427054748535156250e02, - 3.461187133789062500e02, - 4.992009277343750000e02, - 6.488584594726562500e02, - 6.145890502929687500e02, - 6.888356323242187500e02, - 7.624086914062500000e02, - 7.517237548828125000e02, - 9.593036499023437500e02, - 9.810502319335937500e02, - 1.038287719726562500e03, - 9.249543457031250000e02, - 1.033710083007812500e03, - 1.079497680664062500e03, - 1.050879028320312500e03, - 9.673059082031250000e02, - 7.490639038085937500e02, - 7.319063720703125000e02, - 4.854908752441406250e02, - 3.826940612792968750e02, - 3.644178161621093750e02, - 3.667009277343750000e02, - 3.506963500976562500e02, - 3.438470458984375000e02, - 3.267123413085937500e02, - 3.347031860351562500e02, - 4.923515930175781250e02, - 6.385730590820312500e02, - 6.180136718750000000e02, - 6.766895141601562500e02, - 6.880822143554687500e02, - 7.643150634765625000e02, - 1.010810485839843750e03, - 1.053173461914062500e03, - 1.042876708984375000e03, - 9.570091552734375000e02, - 1.054303710937500000e03, - 1.118424682617187500e03, - 1.089805908203125000e03, - 9.661757812500000000e02, - 7.852968139648437500e02, - 7.608903808593750000e02, - 4.912100524902343750e02, - 3.803995361328125000e02, - 3.586986389160156250e02, - 3.667009277343750000e02, - 3.461187133789062500e02, - 3.609817199707031250e02, - 3.449771728515625000e02, - 3.404109497070312500e02, - 4.843721313476562500e02, - 6.560958862304687500e02, - 6.214497680664062500e02, - 6.937899780273437500e02, - 7.357191772460937500e02, - 7.421917724609375000e02, - 9.673173217773437500e02, - 9.375456542968750000e02, - 8.986301269531250000e02, - 9.787671508789062500e02, - 1.026849365234375000e03, - 9.684703369140625000e02, - 9.066552734375000000e02, - 8.253767089843750000e02, - 7.277055053710937500e02, - 7.002625732421875000e02, - 4.809360656738281250e02, - 3.712785339355468750e02, - 3.518378906250000000e02, - 3.552625427246093750e02, - 3.426940612792968750e02, - 3.438470458984375000e02, - 3.324201049804687500e02, - 3.278538818359375000e02, - 4.923401794433593750e02, - 6.351483764648437500e02, - 6.054566040039062500e02, - 6.796917724609375000e02, - 6.983789672851562500e02, - 7.273287963867187500e02, - 9.192237548828125000e02, - 9.512785644531250000e02, - 9.547260131835937500e02, - 9.169406127929687500e02, - 1.023401855468750000e03, - 1.011940612792968750e03, - 1.018824218750000000e03, - 9.146575317382812500e02, - 8.070548095703125000e02, - 7.444863281250000000e02, - 5.072031860351562500e02, - 3.998287658691406250e02, - 3.575570678710937500e02, - 3.598401794433593750e02, - 3.438356018066406250e02, - 3.472716979980468750e02, - 3.381278686523437500e02, - 3.426940612792968750e02, - 6.206963500976562500e02, - 5.837442626953125000e02, - 6.111643676757812500e02, - 7.471689453125000000e02, - 7.593721313476562500e02, - 8.070548095703125000e02, - 9.604451904296875000e02, - 1.025707763671875000e03, - 9.741781005859375000e02, - 9.364155273437500000e02, - 1.019965759277343750e03, - 1.085216918945312500e03, - 9.970662231445312500e02, - 8.929109497070312500e02, - 7.597374267578125000e02, - 7.505822143554687500e02, - 5.083447570800781250e02, - 3.986872253417968750e02, - 3.689954223632812500e02, - 3.769862976074218750e02, - 3.506849365234375000e02, - 3.552739868164062500e02, - 3.404223632812500000e02, - 3.312785339355468750e02, - 5.711757812500000000e02, - 5.403310546875000000e02, - 6.317237548828125000e02, - 7.059703369140625000e02, - 7.517579956054687500e02, - 8.024657592773437500e02, - 1.046301391601562500e03, - 1.109269409179687500e03, - 1.121860717773437500e03, - 1.015388122558593750e03, - 1.094383544921875000e03, - 1.151620971679687500e03, - 1.062317382812500000e03, - 1.077191772460937500e03, - 8.036187133789062500e02, - 7.216096191406250000e02, - 5.003424682617187500e02, - 3.964041137695312500e02, - 3.735616455078125000e02, - ], - [ - 2.263374519348144531e01, - 2.932098770141601562e01, - 2.932098770141601562e01, - 2.949245452880859375e01, - 2.880658531188964844e01, - 2.657750320434570312e01, - 2.263374519348144531e01, - 1.886145401000976562e01, - 2.434842300415039062e01, - 3.137860107421875000e01, - 2.400548744201660156e01, - 2.897805213928222656e01, - 3.377914810180664062e01, - 3.360768127441406250e01, - 3.240740585327148438e01, - 3.275034332275390625e01, - 2.880658531188964844e01, - 2.743484306335449219e01, - 3.755144119262695312e01, - 3.635116577148437500e01, - 3.703703689575195312e01, - 4.098079681396484375e01, - 3.275034332275390625e01, - 2.897805213928222656e01, - 3.086419677734375000e01, - 2.726337432861328125e01, - 1.920438957214355469e01, - 2.777777862548828125e01, - 2.846364974975585938e01, - 2.331961631774902344e01, - 1.406035709381103516e01, - 2.640603637695312500e01, - 2.537722969055175781e01, - 2.572016525268554688e01, - 3.240740585327148438e01, - 2.743484306335449219e01, - 2.897805213928222656e01, - 2.726337432861328125e01, - 2.006172752380371094e01, - 3.069272994995117188e01, - 3.000685882568359375e01, - 2.503429412841796875e01, - 3.000685882568359375e01, - 3.858024597167968750e01, - 3.789437484741210938e01, - 3.240740585327148438e01, - 3.480795669555664062e01, - 2.589163208007812500e01, - 2.263374519348144531e01, - 1.954732513427734375e01, - 1.920438957214355469e01, - 1.817558288574218750e01, - 2.006172752380371094e01, - 1.543209838867187500e01, - 1.406035709381103516e01, - 2.280521202087402344e01, - 3.275034332275390625e01, - 3.412208557128906250e01, - 3.360768127441406250e01, - 3.446501922607421875e01, - 3.515089035034179688e01, - 3.395061874389648438e01, - 2.846364974975585938e01, - 2.949245452880859375e01, - 3.017832565307617188e01, - 2.777777862548828125e01, - 3.446501922607421875e01, - 4.526749038696289062e01, - 3.943758392333984375e01, - 4.818244171142578125e01, - 3.737997436523437500e01, - 3.480795669555664062e01, - 2.400548744201660156e01, - 2.623456764221191406e01, - 2.760630989074707031e01, - 2.846364974975585938e01, - 2.932098770141601562e01, - 1.851851844787597656e01, - 2.383401870727539062e01, - 2.914951896667480469e01, - 2.897805213928222656e01, - 2.709190750122070312e01, - 3.343621444702148438e01, - 3.292181015014648438e01, - 3.772290802001953125e01, - 3.858024597167968750e01, - 3.326474761962890625e01, - 2.589163208007812500e01, - 3.172153663635253906e01, - 3.377914810180664062e01, - 4.406721496582031250e01, - 4.749657058715820312e01, - 4.406721496582031250e01, - 3.446501922607421875e01, - 3.960905456542968750e01, - 3.463648986816406250e01, - 3.189300346374511719e01, - 3.034979438781738281e01, - 3.069272994995117188e01, - 2.880658531188964844e01, - 2.897805213928222656e01, - 2.143346977233886719e01, - 2.331961631774902344e01, - 2.520576095581054688e01, - 2.109053421020507812e01, - 2.777777862548828125e01, - 3.000685882568359375e01, - 3.017832565307617188e01, - 3.343621444702148438e01, - 3.223593902587890625e01, - 2.503429412841796875e01, - 3.446501922607421875e01, - 3.343621444702148438e01, - 2.812071418762207031e01, - 2.897805213928222656e01, - 3.995198822021484375e01, - 3.377914810180664062e01, - 3.515089035034179688e01, - 3.703703689575195312e01, - 2.897805213928222656e01, - 2.366255187988281250e01, - 2.812071418762207031e01, - 2.760630989074707031e01, - 2.143346977233886719e01, - 2.743484306335449219e01, - 2.589163208007812500e01, - 2.177640533447265625e01, - 2.057613182067871094e01, - 3.292181015014648438e01, - 2.486282539367675781e01, - 2.589163208007812500e01, - 4.200960159301757812e01, - 3.086419677734375000e01, - 3.412208557128906250e01, - 3.840877914428710938e01, - 3.155006790161132812e01, - 3.600823211669921875e01, - 3.669410324096679688e01, - 3.446501922607421875e01, - 3.720850372314453125e01, - 4.406721496582031250e01, - 3.326474761962890625e01, - 3.532236099243164062e01, - 3.858024597167968750e01, - 3.275034332275390625e01, - 2.194787406921386719e01, - 2.897805213928222656e01, - 2.863511657714843750e01, - 2.469135856628417969e01, - 2.331961631774902344e01, - 2.554869651794433594e01, - 2.143346977233886719e01, - 2.486282539367675781e01, - 2.743484306335449219e01, - 2.383401870727539062e01, - 3.069272994995117188e01, - 2.794924545288085938e01, - 2.812071418762207031e01, - 2.434842300415039062e01, - 1.989026069641113281e01, - 1.971879196166992188e01, - 2.194787406921386719e01, - 2.486282539367675781e01, - 3.360768127441406250e01, - 3.343621444702148438e01, - 3.206447219848632812e01, - 3.034979438781738281e01, - 2.349108314514160156e01, - ], - [ - 1.214824981689453125e02, - 1.192518844604492188e02, - 1.103294448852539062e02, - 1.139327392578125000e02, - 1.137611541748046875e02, - 1.456760406494140625e02, - 1.597460479736328125e02, - 1.443033599853515625e02, - 1.691832580566406250e02, - 1.798215484619140625e02, - 1.918325347900390625e02, - 1.990391235351562500e02, - 1.932052154541015625e02, - 2.007549743652343750e02, - 2.069320526123046875e02, - 2.084763183593750000e02, - 2.134523010253906250e02, - 2.148249816894531250e02, - 2.132807159423828125e02, - 2.143102264404296875e02, - 2.086479034423828125e02, - 1.935483856201171875e02, - 1.743308105468750000e02, - 1.717570343017578125e02, - 1.662662963867187500e02, - 1.657515411376953125e02, - 1.580301971435546875e02, - 1.626630096435546875e02, - 1.640356903076171875e02, - 1.841111907958984375e02, - 2.000686340332031250e02, - 1.853122863769531250e02, - 2.057309570312500000e02, - 2.074468078613281250e02, - 2.151681518554687500e02, - 2.095058288574218750e02, - 2.088194885253906250e02, - 2.156829071044921875e02, - 2.189430389404296875e02, - 2.215168151855468750e02, - 2.287234039306640625e02, - 2.292381591796875000e02, - 2.230610809326171875e02, - 2.177419281005859375e02, - 2.115648651123046875e02, - 1.921757049560546875e02, - 1.741592254638671875e02, - 1.535689697265625000e02, - 1.590597076416015625e02, - 1.606039733886718750e02, - 1.568291015625000000e02, - 1.571722717285156250e02, - 1.575154418945312500e02, - 1.911461944580078125e02, - 1.866849670410156250e02, - 1.885724029541015625e02, - 2.033287506103515625e02, - 2.115648651123046875e02, - 2.316403503417968750e02, - 2.280370635986328125e02, - 2.302676696777343750e02, - 2.395332946777343750e02, - 2.343857269287109375e02, - 2.278654785156250000e02, - 2.285518188476562500e02, - 2.268359680175781250e02, - 2.184282836914062500e02, - 2.112216949462890625e02, - 2.155113220214843750e02, - 2.148249816894531250e02, - 1.988675384521484375e02, - 1.842827758789062500e02, - 1.726149597167968750e02, - 1.806794738769531250e02, - 1.762182617187500000e02, - 1.774193572998046875e02, - 1.835964355468750000e02, - 1.964653472900390625e02, - 2.240905914306640625e02, - 1.781056976318359375e02, - 1.878860626220703125e02, - 1.969801025390625000e02, - 2.091626586914062500e02, - 2.098489990234375000e02, - 2.163692474365234375e02, - 2.210020599365234375e02, - 2.251201171875000000e02, - 2.242621765136718750e02, - 2.196293792724609375e02, - 2.276938934326171875e02, - 2.225463256835937500e02, - 2.210020599365234375e02, - 2.047014465332031250e02, - 1.944063110351562500e02, - 1.722717895507812500e02, - 1.523678741455078125e02, - 1.559711761474609375e02, - 1.551132507324218750e02, - 1.513383636474609375e02, - 1.544269104003906250e02, - 1.530542144775390625e02, - 1.871997222900390625e02, - 1.799931335449218750e02, - 1.717570343017578125e02, - 1.935483856201171875e02, - 1.969801025390625000e02, - 2.067604675292968750e02, - 2.228894958496093750e02, - 2.105353393554687500e02, - 2.149965667724609375e02, - 2.024708251953125000e02, - 2.409059753417968750e02, - 2.491420745849609375e02, - 2.283802337646484375e02, - 2.213452301025390625e02, - 2.198009643554687500e02, - 2.007549743652343750e02, - 1.859986267089843750e02, - 1.739876403808593750e02, - 1.583733673095703125e02, - 1.376115264892578125e02, - 1.377831115722656250e02, - 1.362388458251953125e02, - 1.353809204101562500e02, - 1.431022644042968750e02, - 1.695264282226562500e02, - 1.822237548828125000e02, - 1.703843536376953125e02, - 1.896019287109375000e02, - 1.983527832031250000e02, - 2.048730316162109375e02, - 2.052162017822265625e02, - 2.062457122802734375e02, - 2.081331481933593750e02, - 2.047014465332031250e02, - 2.149965667724609375e02, - 2.047014465332031250e02, - 2.055593719482421875e02, - 2.227179107666015625e02, - 2.113932800292968750e02, - 1.998970489501953125e02, - 1.817089843750000000e02, - 1.657515411376953125e02, - 1.499656829833984375e02, - 1.364104309082031250e02, - 1.312628631591796875e02, - 1.283459167480468750e02, - 1.250857925415039062e02, - 1.264584732055664062e02, - 1.654083709716796875e02, - 1.798215484619140625e02, - 1.643788604736328125e02, - 1.904598541259765625e02, - 1.921757049560546875e02, - 2.069320526123046875e02, - 2.215168151855468750e02, - 2.185998687744140625e02, - 2.246053466796875000e02, - 2.259780426025390625e02, - 2.304392547607421875e02, - 2.283802337646484375e02, - 2.264927978515625000e02, - 2.222031555175781250e02, - 2.115648651123046875e02, - 2.010981445312500000e02, - 1.944063110351562500e02, - 1.676389770507812500e02, - 1.460192108154296875e02, - ], - [ - 6.946183013916015625e01, - 6.226533126831054688e01, - 6.007509231567382812e01, - 5.913642120361328125e01, - 5.788485717773437500e01, - 6.289111328125000000e01, - 4.787234115600585938e01, - 5.287859725952148438e01, - 5.287859725952148438e01, - 6.038798522949218750e01, - 6.570713043212890625e01, - 7.196495819091796875e01, - 7.352941131591796875e01, - 6.007509231567382812e01, - 5.851063919067382812e01, - 6.695870208740234375e01, - 6.289111328125000000e01, - 7.259073638916015625e01, - 9.355444335937500000e01, - 8.573216247558593750e01, - 9.793492126464843750e01, - 1.013767242431640625e02, - 1.004380493164062500e02, - 7.884856414794921875e01, - 6.727159118652343750e01, - 6.351689529418945312e01, - 6.351689529418945312e01, - 6.226533126831054688e01, - 6.289111328125000000e01, - 6.821026611328125000e01, - 4.630788421630859375e01, - 5.162703323364257812e01, - 5.037546920776367188e01, - 6.070087432861328125e01, - 6.821026611328125000e01, - 7.446808624267578125e01, - 7.415519714355468750e01, - 6.758448028564453125e01, - 6.539424133300781250e01, - 6.602002716064453125e01, - 6.226533126831054688e01, - 7.133917236328125000e01, - 8.854818725585937500e01, - 8.823529052734375000e01, - 1.041927413940429688e02, - 1.029411773681640625e02, - 9.699624633789062500e01, - 7.634542846679687500e01, - 6.758448028564453125e01, - 6.445556640625000000e01, - 6.351689529418945312e01, - 6.101376724243164062e01, - 5.976220321655273438e01, - 6.789736938476562500e01, - 5.225281524658203125e01, - 4.724655914306640625e01, - 5.131414413452148438e01, - 6.508135223388671875e01, - 7.478097534179687500e01, - 7.384230041503906250e01, - 6.163954925537109375e01, - 6.476846313476562500e01, - 6.414267730712890625e01, - 5.694618225097656250e01, - 6.508135223388671875e01, - 6.914893341064453125e01, - 7.947434234619140625e01, - 8.604505920410156250e01, - 1.001251602172851562e02, - 9.793492126464843750e01, - 9.167709350585937500e01, - 7.790988922119140625e01, - 6.789736938476562500e01, - 6.633291625976562500e01, - 6.351689529418945312e01, - 6.226533126831054688e01, - 6.195244216918945312e01, - 6.476846313476562500e01, - 5.225281524658203125e01, - 5.068836212158203125e01, - 5.413016128540039062e01, - 5.757196426391601562e01, - 7.352941131591796875e01, - 7.321652221679687500e01, - 7.196495819091796875e01, - 6.163954925537109375e01, - 5.882352828979492188e01, - 6.414267730712890625e01, - 6.476846313476562500e01, - 7.790988922119140625e01, - 9.824781036376953125e01, - 8.604505920410156250e01, - 1.026282882690429688e02, - 1.091989974975585938e02, - 1.073216552734375000e02, - 8.541927337646484375e01, - 7.415519714355468750e01, - 6.852315521240234375e01, - 6.289111328125000000e01, - 6.070087432861328125e01, - 6.101376724243164062e01, - 6.414267730712890625e01, - 5.256570816040039062e01, - 4.943679428100585938e01, - 6.070087432861328125e01, - 7.571965026855468750e01, - 8.792240142822265625e01, - 8.698372650146484375e01, - 7.822277832031250000e01, - 7.133917236328125000e01, - 7.227784729003906250e01, - 6.602002716064453125e01, - 6.883604431152343750e01, - 6.977471923828125000e01, - 7.978723144531250000e01, - 9.042552947998046875e01, - 1.057571945190429688e02, - 1.007509384155273438e02, - 9.762202453613281250e01, - 8.479349517822265625e01, - 7.165206146240234375e01, - 6.821026611328125000e01, - 6.226533126831054688e01, - 6.195244216918945312e01, - 7.478097534179687500e01, - 5.882352828979492188e01, - 4.974968719482421875e01, - 5.381727218627929688e01, - 6.007509231567382812e01, - 7.259073638916015625e01, - 7.603253936767578125e01, - 8.573216247558593750e01, - 8.698372650146484375e01, - 8.041301727294921875e01, - 7.853566741943359375e01, - 8.197747039794921875e01, - 7.603253936767578125e01, - 8.948686218261718750e01, - 8.948686218261718750e01, - 9.230287933349609375e01, - 9.230287933349609375e01, - 9.511889648437500000e01, - 9.543179321289062500e01, - 8.823529052734375000e01, - 8.166458129882812500e01, - 7.133917236328125000e01, - 6.195244216918945312e01, - 6.038798522949218750e01, - 5.976220321655273438e01, - 6.883604431152343750e01, - 4.849812316894531250e01, - 5.319149017333984375e01, - 6.382978820800781250e01, - 7.634542846679687500e01, - 9.261576843261718750e01, - 1.001251602172851562e02, - 8.573216247558593750e01, - 6.539424133300781250e01, - 5.600751113891601562e01, - 5.694618225097656250e01, - 5.538172531127929688e01, - 6.382978820800781250e01, - 7.790988922119140625e01, - 8.197747039794921875e01, - 1.029411773681640625e02, - 1.135794754028320312e02, - 1.004380493164062500e02, - 8.479349517822265625e01, - ], - [ - 4.069791030883789062e01, - 4.158940505981445312e01, - 3.891365432739257812e01, - 3.942435073852539062e01, - 4.057055664062500000e01, - 6.334054946899414062e01, - 5.866785430908203125e01, - 6.113219451904296875e01, - 7.438232421875000000e01, - 8.779418945312500000e01, - 9.533239746093750000e01, - 9.737519073486328125e01, - 9.609780883789062500e01, - 9.724783325195312500e01, - 9.839658355712890625e01, - 1.000573120117187500e02, - 1.004406509399414062e02, - 1.027394256591796875e02, - 9.839658355712890625e01, - 1.037621002197265625e02, - 9.520503997802734375e01, - 7.690779113769531250e01, - 4.783494567871093750e01, - 4.464849853515625000e01, - 4.490448379516601562e01, - 4.630667495727539062e01, - 4.362837600708007812e01, - 4.528655242919921875e01, - 4.452114105224609375e01, - 6.675624084472656250e01, - 7.067881011962890625e01, - 6.830870819091796875e01, - 7.859780883789062500e01, - 9.124427032470703125e01, - 9.673586273193359375e01, - 9.367167663574218750e01, - 9.852394104003906250e01, - 9.699057769775390625e01, - 9.647988128662109375e01, - 1.008239974975585938e02, - 1.017180328369140625e02, - 1.015906753540039062e02, - 1.017193069458007812e02, - 9.545848083496093750e01, - 9.341441345214843750e01, - 7.792536926269531250e01, - 5.165945053100585938e01, - 4.707080841064453125e01, - 4.719816589355468750e01, - 4.643275451660156250e01, - 4.579597473144531250e01, - 4.656138610839843750e01, - 4.732552337646484375e01, - 7.882450103759765625e01, - 7.118950653076171875e01, - 7.361691284179687500e01, - 8.600611114501953125e01, - 1.018466644287109375e02, - 9.929190063476562500e01, - 1.051681137084960938e02, - 1.092536926269531250e02, - 1.017180328369140625e02, - 1.055514526367187500e02, - 1.087442703247070312e02, - 1.073382568359375000e02, - 1.014620513916015625e02, - 1.061895065307617188e02, - 9.954534149169921875e01, - 9.418109893798828125e01, - 7.562786865234375000e01, - 4.923586273193359375e01, - 4.630540084838867188e01, - 4.643275451660156250e01, - 4.643275451660156250e01, - 4.477585220336914062e01, - 4.541263198852539062e01, - 4.388181304931640625e01, - 6.921293640136718750e01, - 7.013372039794921875e01, - 6.910977935791015625e01, - 8.690015411376953125e01, - 1.012073364257812500e02, - 1.023560867309570312e02, - 1.006953659057617188e02, - 9.967396545410156250e01, - 9.980132293701171875e01, - 1.022287292480468750e02, - 1.095096817016601562e02, - 1.096370315551757812e02, - 1.024847183227539062e02, - 1.031240463256835938e02, - 1.026120758056640625e02, - 9.277508544921875000e01, - 7.511716461181640625e01, - 4.872771453857421875e01, - 4.273560714721679688e01, - 4.311894989013671875e01, - 4.107997894287109375e01, - 3.980641937255859375e01, - 4.006113052368164062e01, - 3.853158569335937500e01, - 6.669383239746093750e01, - 6.282093811035156250e01, - 6.205807495117187500e01, - 8.000127410888671875e01, - 9.022160339355468750e01, - 9.839531707763671875e01, - 9.941798400878906250e01, - 1.023573608398437500e02, - 1.008239974975585938e02, - 1.035061111450195312e02, - 1.044001541137695312e02, - 1.088716278076171875e02, - 1.124477844238281250e02, - 1.075929718017578125e02, - 1.008239974975585938e02, - 9.622644042968750000e01, - 9.073229980468750000e01, - 6.257386779785156250e01, - 4.311894989013671875e01, - 4.299032211303710938e01, - 4.273815536499023438e01, - 4.095262527465820312e01, - 4.082526779174804688e01, - 4.146204757690429688e01, - 7.409832000732421875e01, - 6.358507537841796875e01, - 6.754330444335937500e01, - 8.396331787109375000e01, - 9.405374145507812500e01, - 1.013346939086914062e02, - 1.063155899047851562e02, - 1.123204269409179688e02, - 1.139798812866210938e02, - 1.114251174926757812e02, - 1.038894577026367188e02, - 1.175560379028320312e02, - 1.093823242187500000e02, - 1.097656631469726562e02, - 1.035061111450195312e02, - 9.890728759765625000e01, - 9.341568756103515625e01, - 6.959373474121093750e01, - 4.503183746337890625e01, - 4.515919494628906250e01, - 4.388308715820312500e01, - 4.069791030883789062e01, - 4.082526779174804688e01, - 4.069791030883789062e01, - 6.576541137695312500e01, - 6.329724884033203125e01, - 6.716250610351562500e01, - 8.204534149169921875e01, - 9.405374145507812500e01, - 1.003120193481445312e02, - 1.045287857055664062e02, - 1.082335739135742188e02, - 1.067002029418945312e02, - 1.054215469360351562e02, - 1.070835418701171875e02, - 1.100216522216796875e02, - 1.037608261108398438e02, - 9.865257263183593750e01, - 9.775852966308593750e01, - 9.711793518066406250e01, - 8.137290191650390625e01, - 4.974656295776367188e01, - 4.464849853515625000e01, - ], - [ - 1.140194625854492188e02, - 1.181511993408203125e02, - 1.222866744995117188e02, - 1.147717056274414062e02, - 1.102619781494140625e02, - 1.170284423828125000e02, - 1.421856231689453125e02, - 1.713136291503906250e02, - 1.703555450439453125e02, - 1.906886291503906250e02, - 2.291916198730468750e02, - 2.488061370849609375e02, - 2.491841278076171875e02, - 2.295696105957031250e02, - 2.412574920654296875e02, - 2.488061370849609375e02, - 2.601235046386718750e02, - 2.408869781494140625e02, - 2.397492523193359375e02, - 2.078667602539062500e02, - 1.358046417236328125e02, - 1.249139251708984375e02, - 1.219124221801757812e02, - 1.155202102661132812e02, - 1.162761993408203125e02, - 1.196519470214843750e02, - 1.245359268188476562e02, - 1.189034423828125000e02, - 1.162761993408203125e02, - 1.410591278076171875e02, - 1.515793457031250000e02, - 2.093974609375000000e02, - 2.220247039794921875e02, - 2.397567291259765625e02, - 2.518263397216796875e02, - 2.555950622558593750e02, - 2.435254516601562500e02, - 2.242926635742187500e02, - 1.980613708496093750e02, - 2.208907165527343750e02, - 2.337163238525390625e02, - 2.310778503417968750e02, - 1.957934112548828125e02, - 1.893824920654296875e02, - 1.403068847656250000e02, - 1.264146728515625000e02, - 1.207821884155273438e02, - 1.155239486694335938e02, - 1.166504516601562500e02, - 1.189071884155273438e02, - 1.252881698608398438e02, - 1.192814407348632812e02, - 1.158982009887695312e02, - 1.414371185302734375e02, - 1.583345794677734375e02, - 2.112799377441406250e02, - 2.442814331054687500e02, - 2.499401245117187500e02, - 2.529528503417968750e02, - 2.495583801269531250e02, - 2.552208099365234375e02, - 2.488061370849609375e02, - 2.465456542968750000e02, - 2.503181152343750000e02, - 2.552208099365234375e02, - 2.559693145751953125e02, - 2.484281463623046875e02, - 2.325860748291015625e02, - 1.423989562988281250e02, - 1.241654205322265625e02, - 1.192739486694335938e02, - 1.155202102661132812e02, - 1.158944625854492188e02, - 1.185254516601562500e02, - 1.256624221801757812e02, - 1.189034423828125000e02, - 1.155239486694335938e02, - 1.403106231689453125e02, - 1.538323364257812500e02, - 2.082634735107421875e02, - 2.152357788085937500e02, - 2.340943145751953125e02, - 2.457896728515625000e02, - 2.461676635742187500e02, - 2.337163238525390625e02, - 2.363622741699218750e02, - 2.261751556396484375e02, - 2.431474609375000000e02, - 2.427694549560546875e02, - 2.205127258300781250e02, - 1.759880218505859375e02, - 1.863398132324218750e02, - 1.354266510009765625e02, - 1.200299377441406250e02, - 1.158982009887695312e02, - 1.110142211914062500e02, - 1.098877258300781250e02, - 1.125187149047851562e02, - 1.181549377441406250e02, - 1.113884735107421875e02, - 1.065119781494140625e02, - 1.335516510009765625e02, - 1.425636291503906250e02, - 1.694386291503906250e02, - 1.718562927246093750e02, - 1.916429595947265625e02, - 2.276871185302734375e02, - 2.446556854248046875e02, - 2.371145172119140625e02, - 2.333420715332031250e02, - 2.186302337646484375e02, - 2.431474609375000000e02, - 2.593675231933593750e02, - 2.529565887451171875e02, - 2.537088317871093750e02, - 2.352245483398437500e02, - 1.369311370849609375e02, - 1.219049377441406250e02, - 1.173989486694335938e02, - 1.132709579467773438e02, - 1.125187149047851562e02, - 1.151459579467773438e02, - 1.204079360961914062e02, - 1.140194625854492188e02, - 1.102619781494140625e02, - 1.327994079589843750e02, - 1.433121185302734375e02, - 1.943188629150390625e02, - 1.980651245117187500e02, - 2.193824920654296875e02, - 2.363622741699218750e02, - 2.540868225097656250e02, - 2.544648132324218750e02, - 2.593712463378906250e02, - 2.537125701904296875e02, - 2.457896728515625000e02, - 2.578592834472656250e02, - 2.601235046386718750e02, - 2.533308410644531250e02, - 2.446594238281250000e02, - 1.469086761474609375e02, - 1.222829360961914062e02, - 1.185291900634765625e02, - 1.147717056274414062e02, - 1.136414642333984375e02, - 1.151459579467773438e02, - 1.215344314575195312e02, - 1.136414642333984375e02, - 1.102619781494140625e02, - 1.331736602783203125e02, - 1.436901245117187500e02, - 1.984767913818359375e02, - 1.948652648925781250e02, - 2.276796417236328125e02, - 2.597492370605468750e02, - 2.631399841308593750e02, - 2.593675231933593750e02, - 2.604977416992187500e02, - 2.805651245117187500e02, - 2.882223205566406250e02, - 2.730014953613281250e02, - 2.786601867675781250e02, - 2.714932556152343750e02, - 2.461676635742187500e02, - 1.472866821289062500e02, - 1.256661682128906250e02, - 1.222829360961914062e02, - 1.166504516601562500e02, - ], - [ - 1.792779235839843750e02, - 1.680313415527343750e02, - 1.694005432128906250e02, - 1.666689300537109375e02, - 1.663317413330078125e02, - 1.772343292236328125e02, - 2.585524597167968750e02, - 2.466008148193359375e02, - 2.790497131347656250e02, - 2.776839294433593750e02, - 2.957867736816406250e02, - 3.036410217285156250e02, - 2.995436096191406250e02, - 3.166212463378906250e02, - 3.152520446777343750e02, - 3.149114379882812500e02, - 3.258412780761718750e02, - 3.009093933105468750e02, - 3.087636108398437500e02, - 3.152520446777343750e02, - 3.179870605468750000e02, - 3.080824279785156250e02, - 2.722173156738281250e02, - 2.232629394531250000e02, - 2.017779235839843750e02, - 1.908719329833984375e02, - 1.860967254638671875e02, - 1.802997283935546875e02, - 1.782561340332031250e02, - 1.860933227539062500e02, - 2.691450805664062500e02, - 2.848569335937500000e02, - 3.053474121093750000e02, - 3.067132263183593750e02, - 3.442847290039062500e02, - 3.296015014648437500e02, - 3.268699035644531250e02, - 3.381403198242187500e02, - 3.203746643066406250e02, - 3.319856872558593750e02, - 3.405279235839843750e02, - 3.094482421875000000e02, - 3.073978271484375000e02, - 3.220844726562500000e02, - 3.265258789062500000e02, - 3.121764221191406250e02, - 2.705075073242187500e02, - 2.285047760009765625e02, - 2.080279235839843750e02, - 1.952997283935546875e02, - 1.932561340332031250e02, - 1.891689300537109375e02, - 1.867813415527343750e02, - 1.939373321533203125e02, - 2.598126831054687500e02, - 2.913453674316406250e02, - 3.084230346679687500e02, - 3.128610229492187500e02, - 3.299421081542968750e02, - 3.436069335937500000e02, - 3.336954956054687500e02, - 3.504325561523437500e02, - 3.330143127441406250e02, - 3.220810546875000000e02, - 3.391621398925781250e02, - 3.183310546875000000e02, - 3.050068054199218750e02, - 3.261818847656250000e02, - 3.299421081542968750e02, - 3.125204467773437500e02, - 2.691450805664062500e02, - 2.244005432128906250e02, - 2.059877319335937500e02, - 1.949625396728515625e02, - 1.925749359130859375e02, - 1.888283386230468750e02, - 1.884877319335937500e02, - 2.018937377929687500e02, - 2.729019165039062500e02, - 2.940769653320312500e02, - 3.244788818359375000e02, - 3.173024597167968750e02, - 3.381403198242187500e02, - 3.272104797363281250e02, - 3.466791687011718750e02, - 3.354053039550781250e02, - 3.289169006347656250e02, - 3.401873168945312500e02, - 3.425749206542968750e02, - 3.149114379882812500e02, - 3.039850158691406250e02, - 3.227656555175781250e02, - 3.237908630371093750e02, - 3.128610229492187500e02, - 2.708480834960937500e02, - 2.315735626220703125e02, - 1.884809265136718750e02, - 1.803031311035156250e02, - 1.813249359130859375e02, - 1.809809265136718750e02, - 1.799591217041015625e02, - 2.236273803710937500e02, - 2.558242492675781250e02, - 2.780245361328125000e02, - 3.012465820312500000e02, - 3.026158142089843750e02, - 3.309673156738281250e02, - 3.227690734863281250e02, - 3.282322998046875000e02, - 3.265258789062500000e02, - 3.357459106445312500e02, - 3.347207031250000000e02, - 3.381403198242187500e02, - 3.149148559570312500e02, - 3.162806396484375000e02, - 3.237908630371093750e02, - 3.097922363281250000e02, - 3.036410217285156250e02, - 2.612874755859375000e02, - 2.131403198242187500e02, - 1.905313415527343750e02, - 1.799625396728515625e02, - 1.782561340332031250e02, - 1.799591217041015625e02, - 1.779155273437500000e02, - 2.352384185791015625e02, - 2.548024597167968750e02, - 2.619686584472656250e02, - 2.992030029296875000e02, - 2.961273803710937500e02, - 3.138862304687500000e02, - 3.343800964355468750e02, - 3.330143127441406250e02, - 3.227690734863281250e02, - 3.152520446777343750e02, - 3.234502868652343750e02, - 3.323330993652343750e02, - 3.145708312988281250e02, - 3.002247924804687500e02, - 3.142302551269531250e02, - 3.193528747558593750e02, - 3.036376037597656250e02, - 2.551464538574218750e02, - 2.176907348632812500e02, - 1.946185302734375000e02, - 1.799625396728515625e02, - 1.802997283935546875e02, - 1.772343292236328125e02, - 1.809809265136718750e02, - 2.389918212890625000e02, - 2.565054626464843750e02, - 2.534298400878906250e02, - 2.937363891601562500e02, - 2.971525878906250000e02, - 3.091076354980468750e02, - 3.309638977050781250e02, - 3.336954956054687500e02, - 3.323297119140625000e02, - 3.251566772460937500e02, - 3.504359741210937500e02, - 3.449693603515625000e02, - 3.149114379882812500e02, - 3.104734191894531250e02, - 3.268630676269531250e02, - 3.268664855957031250e02, - 3.094482421875000000e02, - 2.691450805664062500e02, - 2.232697601318359375e02, - ], - [ - 3.947381896972656250e02, - 3.778378295898437500e02, - 3.702280273437500000e02, - 3.660050659179687500e02, - 5.198479614257812500e02, - 5.274493408203125000e02, - 5.730996704101562500e02, - 5.959121704101562500e02, - 6.664611206054687500e02, - 7.358361206054687500e02, - 7.853547363281250000e02, - 8.268750000000000000e02, - 8.395861206054687500e02, - 8.141723022460937500e02, - 7.938428955078125000e02, - 8.760134887695312500e02, - 8.827871704101562500e02, - 8.395861206054687500e02, - 7.845185546875000000e02, - 7.489273681640625000e02, - 7.396115112304687500e02, - 5.734966430664062500e02, - 4.420692443847656250e02, - 4.124830932617187500e02, - 4.260135192871093750e02, - 4.116385192871093750e02, - 4.048817443847656250e02, - 4.031925659179687500e02, - 5.739357910156250000e02, - 5.967567749023437500e02, - 6.246536865234375000e02, - 6.195861206054687500e02, - 6.803969726562500000e02, - 7.751942749023437500e02, - 8.345017089843750000e02, - 7.743496704101562500e02, - 8.057009887695312500e02, - 8.124746704101562500e02, - 8.209459228515625000e02, - 8.539949340820312500e02, - 8.624661865234375000e02, - 8.192482910156250000e02, - 7.955321044921875000e02, - 7.692736206054687500e02, - 7.235134887695312500e02, - 5.726942749023437500e02, - 4.370101318359375000e02, - 4.133446044921875000e02, - 4.243243103027343750e02, - 4.133361511230468750e02, - 4.175675659179687500e02, - 4.226351318359375000e02, - 6.288598022460937500e02, - 6.711232910156250000e02, - 6.356503295898437500e02, - 6.880321044921875000e02, - 7.989189453125000000e02, - 8.972044067382812500e02, - 9.039780273437500000e02, - 8.709290771484375000e02, - 8.870354614257812500e02, - 8.734797363281250000e02, - 8.522973022460937500e02, - 8.921199340820312500e02, - 8.904138793945312500e02, - 8.607685546875000000e02, - 8.175591430664062500e02, - 7.299240112304687500e02, - 7.065709228515625000e02, - 5.743750000000000000e02, - 4.446114807128906250e02, - 3.998057556152343750e02, - 4.184121704101562500e02, - 3.998057556152343750e02, - 4.048733215332031250e02, - 4.065709533691406250e02, - 5.190033569335937500e02, - 5.959121704101562500e02, - 6.178969726562500000e02, - 6.588767089843750000e02, - 7.070354614257812500e02, - 7.896115112304687500e02, - 7.921536865234375000e02, - 7.870607910156250000e02, - 8.260303955078125000e02, - 7.921452636718750000e02, - 8.031588134765625000e02, - 8.683868408203125000e02, - 8.311148681640625000e02, - 7.896030273437500000e02, - 7.565625000000000000e02, - 7.112838134765625000e02, - 6.664696044921875000e02, - 5.337584228515625000e02, - 3.998057556152343750e02, - 3.803716125488281250e02, - 3.854391784667968750e02, - 3.676858215332031250e02, - 3.634628295898437500e02, - 3.685473022460937500e02, - 4.919510192871093750e02, - 5.536571044921875000e02, - 5.578800659179687500e02, - 5.790033569335937500e02, - 6.474830932617187500e02, - 7.463935546875000000e02, - 7.879053955078125000e02, - 8.090878295898437500e02, - 8.073901977539062500e02, - 7.692651977539062500e02, - 8.277280273437500000e02, - 8.700844726562500000e02, - 8.794088134765625000e02, - 8.345101318359375000e02, - 8.667060546875000000e02, - 7.523226318359375000e02, - 7.353800659179687500e02, - 6.554982910156250000e02, - 5.232178955078125000e02, - 3.938935852050781250e02, - 4.116469726562500000e02, - 3.896790466308593750e02, - 3.820608215332031250e02, - 3.854391784667968750e02, - 5.274408569335937500e02, - 6.001351318359375000e02, - 6.212753295898437500e02, - 6.322634887695312500e02, - 7.045017089843750000e02, - 8.387500000000000000e02, - 8.997381591796875000e02, - 8.912669067382812500e02, - 8.988851318359375000e02, - 8.548310546875000000e02, - 8.912669067382812500e02, - 9.107601318359375000e02, - 9.132939453125000000e02, - 8.497550659179687500e02, - 9.056672363281250000e02, - 8.285726318359375000e02, - 7.514780273437500000e02, - 6.580405273437500000e02, - 5.257601318359375000e02, - 3.837500000000000000e02, - 3.981250000000000000e02, - 3.964358215332031250e02, - 3.719172363281250000e02, - 3.829053955078125000e02, - 5.206756591796875000e02, - 5.705574340820312500e02, - 5.553463134765625000e02, - 6.233530273437500000e02, - 6.994172363281250000e02, - 8.336571044921875000e02, - 9.226182250976562500e02, - 9.056672363281250000e02, - 9.276942749023437500e02, - 8.683952636718750000e02, - 8.929560546875000000e02, - 1.004788879394531250e03, - 9.675169067382812500e02, - 9.285473022460937500e02, - 8.827955932617187500e02, - 8.014611206054687500e02, - 7.599493408203125000e02, - 5.938006591796875000e02, - 4.463006896972656250e02, - 3.896790466308593750e02, - ], - [ - 4.006647109985351562e01, - 3.545051574707031250e01, - 2.289512634277343750e01, - 1.772525787353515625e01, - 1.070901012420654297e01, - 9.231905937194824219e00, - 7.016248226165771484e00, - 1.070901012420654297e01, - 1.920236396789550781e01, - 3.766617584228515625e01, - 4.523633575439453125e01, - 5.409896469116210938e01, - 6.776218414306640625e01, - 6.720827484130859375e01, - 5.816100311279296875e01, - 5.243722152709960938e01, - 5.649925994873046875e01, - 5.243722152709960938e01, - 6.037666320800781250e01, - 6.573117065429687500e01, - 6.831610107421875000e01, - 6.185376739501953125e01, - 5.391432952880859375e01, - 4.911373519897460938e01, - 4.689807891845703125e01, - 2.861890602111816406e01, - 2.437223052978515625e01, - 1.901772499084472656e01, - 1.052437210083007812e01, - 8.124076843261718750e00, - 7.385524272918701172e00, - 1.107828617095947266e01, - 2.621861076354980469e01, - 4.671343994140625000e01, - 5.539143371582031250e01, - 5.889955520629882812e01, - 7.200886535644531250e01, - 6.776218414306640625e01, - 6.093057632446289062e01, - 5.742245101928710938e01, - 5.391432952880859375e01, - 5.612998580932617188e01, - 5.723781204223632812e01, - 6.794682312011718750e01, - 6.443869781494140625e01, - 6.277695846557617188e01, - 6.517725372314453125e01, - 5.225258636474609375e01, - 5.040620422363281250e01, - 3.489660263061523438e01, - 2.381831550598144531e01, - 1.790989685058593750e01, - 1.181683921813964844e01, - 8.493352890014648438e00, - 8.493352890014648438e00, - 1.827917289733886719e01, - 3.175775527954101562e01, - 5.003692626953125000e01, - 5.649925994873046875e01, - 6.739290618896484375e01, - 7.090103149414062500e01, - 6.517725372314453125e01, - 5.889955520629882812e01, - 5.760708999633789062e01, - 5.760708999633789062e01, - 5.594534683227539062e01, - 5.889955520629882812e01, - 7.348596954345703125e01, - 7.256277465820312500e01, - 6.517725372314453125e01, - 5.982274627685546875e01, - 5.631462478637695312e01, - 4.689807891845703125e01, - 3.101920318603515625e01, - 2.511078262329101562e01, - 2.344903945922851562e01, - 1.144756317138671875e01, - 7.570162296295166016e00, - 7.200886249542236328e00, - 1.347858238220214844e01, - 2.566469764709472656e01, - 4.209748840332031250e01, - 4.745199584960937500e01, - 6.425405883789062500e01, - 7.920974731445312500e01, - 6.905464935302734375e01, - 5.539143371582031250e01, - 5.280649948120117188e01, - 5.059084320068359375e01, - 5.077547836303710938e01, - 5.631462478637695312e01, - 6.462333679199218750e01, - 7.163958740234375000e01, - 5.631462478637695312e01, - 5.668389892578125000e01, - 4.412850952148437500e01, - 4.080502319335937500e01, - 3.157311630249023438e01, - 2.234121131896972656e01, - 2.012555313110351562e01, - 1.532496261596679688e01, - 9.047266960144042969e00, - 9.970458030700683594e00, - 1.366322040557861328e01, - 2.677252578735351562e01, - 4.375923156738281250e01, - 5.483751678466796875e01, - 6.351551055908203125e01, - 7.994830322265625000e01, - 7.330133056640625000e01, - 6.240768051147460938e01, - 6.351551055908203125e01, - 5.834564208984375000e01, - 5.631462478637695312e01, - 7.200886535644531250e01, - 7.477843475341796875e01, - 7.071639251708984375e01, - 6.517725372314453125e01, - 5.409896469116210938e01, - 4.966765213012695312e01, - 4.560561370849609375e01, - 3.526588058471679688e01, - 2.307976341247558594e01, - 2.455686759948730469e01, - 1.366322040557861328e01, - 8.862628936767578125e00, - 7.016248226165771484e00, - 1.403249645233154297e01, - 2.437223052978515625e01, - 4.394387054443359375e01, - 5.132939529418945312e01, - 6.277695846557617188e01, - 7.293205261230468750e01, - 6.591580200195312500e01, - 6.351551055908203125e01, - 6.296159362792968750e01, - 6.333087158203125000e01, - 6.628507995605468750e01, - 7.182422637939453125e01, - 7.754800415039062500e01, - 7.644017791748046875e01, - 6.259231948852539062e01, - 5.280649948120117188e01, - 5.409896469116210938e01, - 4.911373519897460938e01, - 3.120384025573730469e01, - 2.474150657653808594e01, - 2.049483108520507812e01, - 1.310930538177490234e01, - 9.601181983947753906e00, - 8.493352890014648438e00, - 1.643279266357421875e01, - 5.096011734008789062e01, - 4.837518310546875000e01, - 4.966765213012695312e01, - 6.333087158203125000e01, - 7.440915679931640625e01, - 7.274741363525390625e01, - 6.683899688720703125e01, - 6.517725372314453125e01, - 6.831610107421875000e01, - 6.702363586425781250e01, - 7.607089996337890625e01, - 7.662481689453125000e01, - 7.828656005859375000e01, - 6.499261474609375000e01, - 6.093057632446289062e01, - 5.040620422363281250e01, - ], - [ - 7.309510040283203125e01, - 7.324276733398437500e01, - 7.250443267822265625e01, - 7.250443267822265625e01, - 7.132309722900390625e01, - 7.265209960937500000e01, - 7.279976654052734375e01, - 7.545776367187500000e01, - 7.649143218994140625e01, - 7.870643615722656250e01, - 8.165977478027343750e01, - 8.254577636718750000e01, - 8.210277557373046875e01, - 8.313644409179687500e01, - 8.313644409179687500e01, - 8.298877716064453125e01, - 8.254577636718750000e01, - 8.225044250488281250e01, - 8.062610626220703125e01, - 8.106910705566406250e01, - 7.959243774414062500e01, - 7.974010467529296875e01, - 7.841110229492187500e01, - 8.033077239990234375e01, - 8.062610626220703125e01, - 8.254577636718750000e01, - 8.136444091796875000e01, - 8.092144012451171875e01, - 8.092144012451171875e01, - 8.033077239990234375e01, - 8.151210784912109375e01, - 8.136444091796875000e01, - 8.121677398681640625e01, - 8.225044250488281250e01, - 8.343177795410156250e01, - 8.313644409179687500e01, - 8.254577636718750000e01, - 8.210277557373046875e01, - 8.343177795410156250e01, - 8.254577636718750000e01, - 8.225044250488281250e01, - 8.151210784912109375e01, - 8.121677398681640625e01, - 8.062610626220703125e01, - 7.974010467529296875e01, - 7.885410308837890625e01, - 7.900177001953125000e01, - 8.018310546875000000e01, - 8.225044250488281250e01, - 8.225044250488281250e01, - 8.269344329833984375e01, - 8.239810943603515625e01, - 8.033077239990234375e01, - 8.136444091796875000e01, - 8.225044250488281250e01, - 8.284111022949218750e01, - 8.269344329833984375e01, - 8.239810943603515625e01, - 8.225044250488281250e01, - 8.328411102294921875e01, - 8.357944488525390625e01, - 8.343177795410156250e01, - 8.594211578369140625e01, - 8.417011260986328125e01, - 8.357944488525390625e01, - 8.269344329833984375e01, - 8.018310546875000000e01, - 8.106910705566406250e01, - 7.841110229492187500e01, - 7.988777160644531250e01, - 7.914943695068359375e01, - 7.988777160644531250e01, - 7.959243774414062500e01, - 8.136444091796875000e01, - 8.195510864257812500e01, - 8.239810943603515625e01, - 8.165977478027343750e01, - 8.180744171142578125e01, - 8.047843933105468750e01, - 8.121677398681640625e01, - 8.151210784912109375e01, - 8.298877716064453125e01, - 8.180744171142578125e01, - 8.313644409179687500e01, - 8.180744171142578125e01, - 8.328411102294921875e01, - 8.328411102294921875e01, - 8.313644409179687500e01, - 8.165977478027343750e01, - 8.151210784912109375e01, - 7.929710388183593750e01, - 7.900177001953125000e01, - 7.826343536376953125e01, - 7.796810150146484375e01, - 7.826343536376953125e01, - 7.708209991455078125e01, - 7.309510040283203125e01, - 7.368576812744140625e01, - 7.324276733398437500e01, - 7.250443267822265625e01, - 7.294743347167968750e01, - 7.206143188476562500e01, - 7.265209960937500000e01, - 7.442410278320312500e01, - 7.575309753417968750e01, - 7.870643615722656250e01, - 8.062610626220703125e01, - 8.106910705566406250e01, - 8.180744171142578125e01, - 8.313644409179687500e01, - 8.254577636718750000e01, - 8.180744171142578125e01, - 8.328411102294921875e01, - 8.254577636718750000e01, - 7.959243774414062500e01, - 7.782043457031250000e01, - 7.841110229492187500e01, - 7.767276763916015625e01, - 7.767276763916015625e01, - 7.722976684570312500e01, - 7.516242980957031250e01, - 7.531009674072265625e01, - 7.619609832763671875e01, - 7.457176971435546875e01, - 7.442410278320312500e01, - 7.398110198974609375e01, - 7.324276733398437500e01, - 7.368576812744140625e01, - 7.398110198974609375e01, - 7.708209991455078125e01, - 7.841110229492187500e01, - 7.885410308837890625e01, - 7.988777160644531250e01, - 7.929710388183593750e01, - 7.900177001953125000e01, - 8.033077239990234375e01, - 7.900177001953125000e01, - 8.062610626220703125e01, - 7.885410308837890625e01, - 7.841110229492187500e01, - 7.988777160644531250e01, - 8.003543853759765625e01, - 7.914943695068359375e01, - 8.018310546875000000e01, - 7.914943695068359375e01, - 7.870643615722656250e01, - 7.560543060302734375e01, - 7.457176971435546875e01, - 7.442410278320312500e01, - 7.442410278320312500e01, - 7.368576812744140625e01, - 7.471943664550781250e01, - 7.412876892089843750e01, - 7.634376525878906250e01, - 8.018310546875000000e01, - 7.855876922607421875e01, - 7.959243774414062500e01, - 7.959243774414062500e01, - 7.855876922607421875e01, - 8.047843933105468750e01, - 7.944477081298828125e01, - 7.900177001953125000e01, - 7.796810150146484375e01, - 7.811576843261718750e01, - 7.974010467529296875e01, - 7.914943695068359375e01, - 7.988777160644531250e01, - 7.974010467529296875e01, - ], - ] - ) + 7.309510040283203125e01, + 7.324276733398437500e01, + 7.250443267822265625e01, + 7.250443267822265625e01, + 7.132309722900390625e01, + 7.265209960937500000e01, + 7.279976654052734375e01, + 7.545776367187500000e01, + 7.649143218994140625e01, + 7.870643615722656250e01, + 8.165977478027343750e01, + 8.254577636718750000e01, + 8.210277557373046875e01, + 8.313644409179687500e01, + 8.313644409179687500e01, + 8.298877716064453125e01, + 8.254577636718750000e01, + 8.225044250488281250e01, + 8.062610626220703125e01, + 8.106910705566406250e01, + 7.959243774414062500e01, + 7.974010467529296875e01, + 7.841110229492187500e01, + 8.033077239990234375e01, + 8.062610626220703125e01, + 8.254577636718750000e01, + 8.136444091796875000e01, + 8.092144012451171875e01, + 8.092144012451171875e01, + 8.033077239990234375e01, + 8.151210784912109375e01, + 8.136444091796875000e01, + 8.121677398681640625e01, + 8.225044250488281250e01, + 8.343177795410156250e01, + 8.313644409179687500e01, + 8.254577636718750000e01, + 8.210277557373046875e01, + 8.343177795410156250e01, + 8.254577636718750000e01, + 8.225044250488281250e01, + 8.151210784912109375e01, + 8.121677398681640625e01, + 8.062610626220703125e01, + 7.974010467529296875e01, + 7.885410308837890625e01, + 7.900177001953125000e01, + 8.018310546875000000e01, + 8.225044250488281250e01, + 8.225044250488281250e01, + 8.269344329833984375e01, + 8.239810943603515625e01, + 8.033077239990234375e01, + 8.136444091796875000e01, + 8.225044250488281250e01, + 8.284111022949218750e01, + 8.269344329833984375e01, + 8.239810943603515625e01, + 8.225044250488281250e01, + 8.328411102294921875e01, + 8.357944488525390625e01, + 8.343177795410156250e01, + 8.594211578369140625e01, + 8.417011260986328125e01, + 8.357944488525390625e01, + 8.269344329833984375e01, + 8.018310546875000000e01, + 8.106910705566406250e01, + 7.841110229492187500e01, + 7.988777160644531250e01, + 7.914943695068359375e01, + 7.988777160644531250e01, + 7.959243774414062500e01, + 8.136444091796875000e01, + 8.195510864257812500e01, + 8.239810943603515625e01, + 8.165977478027343750e01, + 8.180744171142578125e01, + 8.047843933105468750e01, + 8.121677398681640625e01, + 8.151210784912109375e01, + 8.298877716064453125e01, + 8.180744171142578125e01, + 8.313644409179687500e01, + 8.180744171142578125e01, + 8.328411102294921875e01, + 8.328411102294921875e01, + 8.313644409179687500e01, + 8.165977478027343750e01, + 8.151210784912109375e01, + 7.929710388183593750e01, + 7.900177001953125000e01, + 7.826343536376953125e01, + 7.796810150146484375e01, + 7.826343536376953125e01, + 7.708209991455078125e01, + 7.309510040283203125e01, + 7.368576812744140625e01, + 7.324276733398437500e01, + 7.250443267822265625e01, + 7.294743347167968750e01, + 7.206143188476562500e01, + 7.265209960937500000e01, + 7.442410278320312500e01, + 7.575309753417968750e01, + 7.870643615722656250e01, + 8.062610626220703125e01, + 8.106910705566406250e01, + 8.180744171142578125e01, + 8.313644409179687500e01, + 8.254577636718750000e01, + 8.180744171142578125e01, + 8.328411102294921875e01, + 8.254577636718750000e01, + 7.959243774414062500e01, + 7.782043457031250000e01, + 7.841110229492187500e01, + 7.767276763916015625e01, + 7.767276763916015625e01, + 7.722976684570312500e01, + 7.516242980957031250e01, + 7.531009674072265625e01, + 7.619609832763671875e01, + 7.457176971435546875e01, + 7.442410278320312500e01, + 7.398110198974609375e01, + 7.324276733398437500e01, + 7.368576812744140625e01, + 7.398110198974609375e01, + 7.708209991455078125e01, + 7.841110229492187500e01, + 7.885410308837890625e01, + 7.988777160644531250e01, + 7.929710388183593750e01, + 7.900177001953125000e01, + 8.033077239990234375e01, + 7.900177001953125000e01, + 8.062610626220703125e01, + 7.885410308837890625e01, + 7.841110229492187500e01, + 7.988777160644531250e01, + 8.003543853759765625e01, + 7.914943695068359375e01, + 8.018310546875000000e01, + 7.914943695068359375e01, + 7.870643615722656250e01, + 7.560543060302734375e01, + 7.457176971435546875e01, + 7.442410278320312500e01, + 7.442410278320312500e01, + 7.368576812744140625e01, + 7.471943664550781250e01, + 7.412876892089843750e01, + 7.634376525878906250e01, + 8.018310546875000000e01, + 7.855876922607421875e01, + 7.959243774414062500e01, + 7.959243774414062500e01, + 7.855876922607421875e01, + 8.047843933105468750e01, + 7.944477081298828125e01, + 7.900177001953125000e01, + 7.796810150146484375e01, + 7.811576843261718750e01, + 7.974010467529296875e01, + 7.914943695068359375e01, + 7.988777160644531250e01, + 7.974010467529296875e01, + ], + ]) diff --git a/test/mx/model/renewal/test_predictor.py b/test/mx/model/renewal/test_predictor.py index e05be630dd..f10600c7af 100644 --- a/test/mx/model/renewal/test_predictor.py +++ b/test/mx/model/renewal/test_predictor.py @@ -43,12 +43,10 @@ [[[0, 0, 0, 0, 0, 0, 0]]], ), ( - [ - [ - [[3, 1, 2, 3, 1, 1, 1], [3, 5, 4, 1, 1, 1, 1]], - [[3, 1, 2, 3, 1, 1, 1], [3, 5, 4, 1, 1, 1, 1]], - ] - ], + [[ + [[3, 1, 2, 3, 1, 1, 1], [3, 5, 4, 1, 1, 1, 1]], + [[3, 1, 2, 3, 1, 1, 1], [3, 5, 4, 1, 1, 1, 1]], + ]], [[[0, 0, 3, 5, 0, 4, 0], [0, 0, 3, 5, 0, 4, 0]]], ), ( @@ -71,42 +69,38 @@ def test_output_transform(input, expected): def test_predictor_smoke_test(): train_ds = ListDataset( - [ - { - "target": [ - 100.0, - 63.0, - 83.0, - 126.0, - 115.0, - 92.0, - 57.0, - 95.0, - 94.0, - 92.0, - 142.0, - 35.0, - 116.0, - 78.0, - 64.0, - 141.0, - ], - "start": "2018-01-07 00:00:00", - "feat_static_cat": [0], - } - ], + [{ + "target": [ + 100.0, + 63.0, + 83.0, + 126.0, + 115.0, + 92.0, + 57.0, + 95.0, + 94.0, + 92.0, + 142.0, + 35.0, + 116.0, + 78.0, + 64.0, + 141.0, + ], + "start": "2018-01-07 00:00:00", + "feat_static_cat": [0], + }], freq="1m", ) test_ds = ListDataset( - [ - { - "target": [100.0, 63.0, 83.0, 126.0, 115.0, 92.0, 57.0, 95.0] - + [0] * 15, - "start": "2018-01-07 00:00:00", - "feat_static_cat": [1], - } - ], + [{ + "target": [100.0, 63.0, 83.0, 126.0, 115.0, 92.0, 57.0, 95.0] + + [0] * 15, + "start": "2018-01-07 00:00:00", + "feat_static_cat": [1], + }], freq="1m", ) diff --git a/test/mx/model/seq2seq/test_forking_sequence_splitter.py b/test/mx/model/seq2seq/test_forking_sequence_splitter.py index 00d1df560f..66d60b51db 100644 --- a/test/mx/model/seq2seq/test_forking_sequence_splitter.py +++ b/test/mx/model/seq2seq/test_forking_sequence_splitter.py @@ -44,34 +44,30 @@ def test_forking_sequence_splitter() -> None: enc_len = 5 dec_len = 3 - trans = transform.Chain( - [ - transform.AddAgeFeature( - target_field=FieldName.TARGET, - output_field="age", - pred_length=dec_len, - ), - ForkingSequenceSplitter( - instance_sampler=ValidationSplitSampler(min_future=dec_len), - enc_len=enc_len, - dec_len=dec_len, - encoder_series_fields=["age"], - ), - ] - ) + trans = transform.Chain([ + transform.AddAgeFeature( + target_field=FieldName.TARGET, + output_field="age", + pred_length=dec_len, + ), + ForkingSequenceSplitter( + instance_sampler=ValidationSplitSampler(min_future=dec_len), + enc_len=enc_len, + dec_len=dec_len, + encoder_series_fields=["age"], + ), + ]) out = trans(ds, is_train=True) transformed_data = next(iter(out)) - future_target = np.array( - [ - [13.0, 14.0, 15.0], - [14.0, 15.0, 16.0], - [15.0, 16.0, 17.0], - [16.0, 17.0, 18.0], - [17.0, 18.0, 19.0], - ] - ) + future_target = np.array([ + [13.0, 14.0, 15.0], + [14.0, 15.0, 16.0], + [15.0, 16.0, 17.0], + [16.0, 17.0, 18.0], + [17.0, 18.0, 19.0], + ]) assert ( np.linalg.norm(future_target - transformed_data["future_target"]) < 1e-5 @@ -110,37 +106,35 @@ def make_dataset(N, train_length): num_time_feat_daily_freq = 3 num_age_feat = 1 - trans = transform.Chain( - [ - transform.AddAgeFeature( - target_field=FieldName.TARGET, - output_field=FieldName.FEAT_AGE, - pred_length=10, + trans = transform.Chain([ + transform.AddAgeFeature( + target_field=FieldName.TARGET, + output_field=FieldName.FEAT_AGE, + pred_length=10, + ), + transform.AddTimeFeatures( + start_field=FieldName.START, + target_field=FieldName.TARGET, + output_field=FieldName.FEAT_TIME, + time_features=time_features_from_frequency_str("D"), + pred_length=10, + ), + ForkingSequenceSplitter( + instance_sampler=( + ValidationSplitSampler(min_future=dec_len) + if is_train + else TSplitSampler() ), - transform.AddTimeFeatures( - start_field=FieldName.START, - target_field=FieldName.TARGET, - output_field=FieldName.FEAT_TIME, - time_features=time_features_from_frequency_str("D"), - pred_length=10, - ), - ForkingSequenceSplitter( - instance_sampler=( - ValidationSplitSampler(min_future=dec_len) - if is_train - else TSplitSampler() - ), - enc_len=enc_len, - dec_len=dec_len, - num_forking=num_forking, - encoder_series_fields=[ - FieldName.FEAT_AGE, - FieldName.FEAT_TIME, - ], - decoder_series_fields=[FieldName.FEAT_TIME], - ), - ] - ) + enc_len=enc_len, + dec_len=dec_len, + num_forking=num_forking, + encoder_series_fields=[ + FieldName.FEAT_AGE, + FieldName.FEAT_TIME, + ], + decoder_series_fields=[FieldName.FEAT_TIME], + ), + ]) out = trans(iter(ds), is_train=is_train) transformed_data = next(iter(out)) diff --git a/test/mx/model/simple_feedforward/test_serde.py b/test/mx/model/simple_feedforward/test_serde.py index 6ebdc5d403..7f4bd03c61 100644 --- a/test/mx/model/simple_feedforward/test_serde.py +++ b/test/mx/model/simple_feedforward/test_serde.py @@ -26,12 +26,10 @@ def test_simplefeedforward_symbol_block_serde(): with tempfile.TemporaryDirectory( prefix="gluonts-predictor-temp-" ) as temp_dir: - dataset = [ - { - "start": pd.Period("2022-01-01", freq="D"), - "target": np.random.normal(size=(200)), - } - ] + dataset = [{ + "start": pd.Period("2022-01-01", freq="D"), + "target": np.random.normal(size=(200)), + }] estimator = SimpleFeedForwardEstimator( prediction_length=10, diff --git a/test/mx/model/tpp/common.py b/test/mx/model/tpp/common.py index 6253eabf51..182ee5cd5b 100644 --- a/test/mx/model/tpp/common.py +++ b/test/mx/model/tpp/common.py @@ -22,13 +22,11 @@ def point_process_dataset(): marks = np.array([0, 1, 2, 0, 1, 2, 2, 2]) lds = ListDataset( - [ - { - "target": np.c_[ia_times, marks].T, - "start": pd.Timestamp("2011-01-01 00:00:00"), - "end": pd.Timestamp("2011-01-01 03:00:00"), - } - ], + [{ + "target": np.c_[ia_times, marks].T, + "start": pd.Timestamp("2011-01-01 00:00:00"), + "end": pd.Timestamp("2011-01-01 03:00:00"), + }], freq="H", one_dim_target=False, use_timestamp=True, diff --git a/test/mx/representation/test_bin.py b/test/mx/representation/test_bin.py index 8558b93e2d..6b3b609d5e 100644 --- a/test/mx/representation/test_bin.py +++ b/test/mx/representation/test_bin.py @@ -20,215 +20,201 @@ binning_cases = [ ( CustomBinning(bin_centers=np.linspace(-1, 10, 5)), - mx.nd.array( + mx.nd.array([ + [-1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], + [3.0] * 10, + [0.0] * 5 + [3.0] * 5, + [2.0] * 8 + [1.5] * 2, + [0.0] * 10, + [1.0] * 10, + ]), + mx.nd.array([ + [1.0] * 10, + [1.0] * 10, + [0.0] * 5 + [1.0] * 5, + [1.0] * 9 + [1.0] * 1, + [0.0] * 10, + [1.0] * 10, + ]), + mx.nd.array([-np.inf, 0.375, 3.125, 5.875, 8.625, np.inf]), + mx.nd.array([ [ - [-1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], - [3.0] * 10, - [0.0] * 5 + [3.0] * 5, - [2.0] * 8 + [1.5] * 2, - [0.0] * 10, - [1.0] * 10, - ] - ), - mx.nd.array( + 1.0, + 2.0, + 2.0, + 3.0, + 3.0, + 4.0, + 4.0, + 4.0, + 5.0, + 5.0, + ], [ - [1.0] * 10, - [1.0] * 10, - [0.0] * 5 + [1.0] * 5, - [1.0] * 9 + [1.0] * 1, - [0.0] * 10, - [1.0] * 10, - ] - ), - mx.nd.array([-np.inf, 0.375, 3.125, 5.875, 8.625, np.inf]), - mx.nd.array( + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + ], + [ + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + ], + [ + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + ], [ - [ - 1.0, - 2.0, - 2.0, - 3.0, - 3.0, - 4.0, - 4.0, - 4.0, - 5.0, - 5.0, - ], - [ - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - ], - [ - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - ], - [ - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - ], - [ - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - ], - [ - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - ], - ] - ), + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + ], + [ + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + ], + ]), ), ( CustomBinning(bin_centers=np.linspace(-10, 10, 8)), - mx.nd.array( + mx.nd.array([ + [-1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], + [3.0] * 10, + [0.0] * 5 + [3.0] * 5, + [2.0] * 8 + [1.5] * 2, + [0.0] * 10, + [1.0] * 10, + ]), + mx.nd.array([ + [1.0] * 10, + [1.0] * 10, + [0.0] * 5 + [1.0] * 5, + [1.0] * 9 + [1.0] * 1, + [0.0] * 10, + [1.0] * 10, + ]), + mx.nd.array([ + -np.inf, + -8.57142857, + -5.71428571, + -2.85714286, + 0.0, + 2.85714286, + 5.71428571, + 8.57142857, + np.inf, + ]), + mx.nd.array([ + [ + 4.0, + 5.0, + 6.0, + 6.0, + 6.0, + 7.0, + 7.0, + 7.0, + 8.0, + 8.0, + ], + [ + 6.0, + 6.0, + 6.0, + 6.0, + 6.0, + 6.0, + 6.0, + 6.0, + 6.0, + 6.0, + ], [ - [-1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], - [3.0] * 10, - [0.0] * 5 + [3.0] * 5, - [2.0] * 8 + [1.5] * 2, - [0.0] * 10, - [1.0] * 10, - ] - ), - mx.nd.array( + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 6.0, + 6.0, + 6.0, + 6.0, + 6.0, + ], [ - [1.0] * 10, - [1.0] * 10, - [0.0] * 5 + [1.0] * 5, - [1.0] * 9 + [1.0] * 1, - [0.0] * 10, - [1.0] * 10, - ] - ), - mx.nd.array( + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + ], [ - -np.inf, - -8.57142857, - -5.71428571, - -2.85714286, - 0.0, - 2.85714286, - 5.71428571, - 8.57142857, - np.inf, - ] - ), - mx.nd.array( + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + ], [ - [ - 4.0, - 5.0, - 6.0, - 6.0, - 6.0, - 7.0, - 7.0, - 7.0, - 8.0, - 8.0, - ], - [ - 6.0, - 6.0, - 6.0, - 6.0, - 6.0, - 6.0, - 6.0, - 6.0, - 6.0, - 6.0, - ], - [ - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 6.0, - 6.0, - 6.0, - 6.0, - 6.0, - ], - [ - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - ], - [ - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - ], - [ - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - ], - ] - ), + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + ], + ]), ), ] diff --git a/test/mx/representation/test_grb.py b/test/mx/representation/test_grb.py index 0f8b32f845..5b8f5e511b 100644 --- a/test/mx/representation/test_grb.py +++ b/test/mx/representation/test_grb.py @@ -24,136 +24,126 @@ is_quantile=True, quantile_scaling_limit=1.0, ), - mx.nd.array( + mx.nd.array([ [ - [ - -0.188679, - 0.377358, - 0.566038, - 0.754717, - 0.943396, - 1.13208, - 1.32075, - 1.50943, - 1.69811, - 1.88679, - ], - [1.0] * 10, - [0.857143] * 5 + [1.14286] * 5, - [1.05263] * 8 + [0.789474] * 2, - [1.0] * 10, - ] - ), + -0.188679, + 0.377358, + 0.566038, + 0.754717, + 0.943396, + 1.13208, + 1.32075, + 1.50943, + 1.69811, + 1.88679, + ], + [1.0] * 10, + [0.857143] * 5 + [1.14286] * 5, + [1.05263] * 8 + [0.789474] * 2, + [1.0] * 10, + ]), + mx.nd.array([ + [-1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], + [3.0] * 10, + [0.0] * 5 + [3.0] * 5, + [2.0] * 8 + [1.5] * 2, + [0.0] * 10, + [1.0] * 10, + ]), + mx.nd.array([ + [1.0] * 10, + [1.0] * 10, + [0.0] * 5 + [1.0] * 5, + [1.0] * 9 + [1.0] * 1, + [0.0] * 10, + [1.0] * 10, + ]), + mx.nd.array([ + -np.inf, + 0.334232, + 0.92857149, + 1.0, + 1.03425997, + 1.47765499, + np.inf, + ]), mx.nd.array( - [ - [-1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], - [3.0] * 10, - [0.0] * 5 + [3.0] * 5, - [2.0] * 8 + [1.5] * 2, - [0.0] * 10, - [1.0] * 10, - ] + [-0.18867899, 0.85714298, 1.0, 1.0, 1.06851995, 1.88679004] ), - mx.nd.array( + mx.nd.array([ [ - [1.0] * 10, - [1.0] * 10, - [0.0] * 5 + [1.0] * 5, - [1.0] * 9 + [1.0] * 1, - [0.0] * 10, - [1.0] * 10, - ] - ), - mx.nd.array( + 1, + 2, + 2, + 2, + 3, + 5, + 5, + 6, + 6, + 6, + ], [ - -np.inf, - 0.334232, - 0.92857149, - 1.0, - 1.03425997, - 1.47765499, - np.inf, - ] - ), - mx.nd.array( - [-0.18867899, 0.85714298, 1.0, 1.0, 1.06851995, 1.88679004] - ), - mx.nd.array( + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + ], [ - [ - 1, - 2, - 2, - 2, - 3, - 5, - 5, - 6, - 6, - 6, - ], - [ - 4, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - ], - [ - 1, - 1, - 1, - 1, - 1, - 4, - 4, - 4, - 4, - 4, - ], - [ - 5, - 5, - 5, - 5, - 5, - 5, - 5, - 5, - 2, - 2, - ], - [ - 7, - 7, - 7, - 7, - 7, - 7, - 7, - 7, - 7, - 7, - ], - [ - 4, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - ], - ] - ), + 1, + 1, + 1, + 1, + 1, + 4, + 4, + 4, + 4, + 4, + ], + [ + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 2, + 2, + ], + [ + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + ], + [ + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + ], + ]), ), ( GlobalRelativeBinning( @@ -161,147 +151,135 @@ is_quantile=True, quantile_scaling_limit=1.0, ), - mx.nd.array( + mx.nd.array([ [ - [ - -0.188679, - 0.377358, - 0.566038, - 0.754717, - 0.943396, - 1.13208, - 1.32075, - 1.50943, - 1.69811, - 1.88679, - ], - [1.0] * 10, - [0.857143] * 5 + [1.14286] * 5, - [1.05263] * 8 + [0.789474] * 2, - [1.0] * 10, - ] - ), - mx.nd.array( + -0.188679, + 0.377358, + 0.566038, + 0.754717, + 0.943396, + 1.13208, + 1.32075, + 1.50943, + 1.69811, + 1.88679, + ], + [1.0] * 10, + [0.857143] * 5 + [1.14286] * 5, + [1.05263] * 8 + [0.789474] * 2, + [1.0] * 10, + ]), + mx.nd.array([ + [-1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], + [3.0] * 10, + [0.0] * 5 + [3.0] * 5, + [2.0] * 8 + [1.5] * 2, + [0.0] * 10, + [1.0] * 10, + ]), + mx.nd.array([ + [1.0] * 10, + [1.0] * 10, + [0.0] * 5 + [1.0] * 5, + [1.0] * 9 + [1.0] * 1, + [0.0] * 10, + [1.0] * 10, + ]), + mx.nd.array([ + -np.inf, + 0.334232, + 0.92857149, + 1.0, + 1.0, + 1.02631497, + 1.097745, + 1.51482505, + np.inf, + ]), + mx.nd.array([ + -0.18867899, + 0.85714298, + 1.0, + 1.0, + 1.0, + 1.05262995, + 1.14286005, + 1.88679004, + ]), + mx.nd.array([ [ - [-1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], - [3.0] * 10, - [0.0] * 5 + [3.0] * 5, - [2.0] * 8 + [1.5] * 2, - [0.0] * 10, - [1.0] * 10, - ] - ), - mx.nd.array( + 1, + 2, + 2, + 2, + 3, + 7, + 7, + 7, + 8, + 8, + ], [ - [1.0] * 10, - [1.0] * 10, - [0.0] * 5 + [1.0] * 5, - [1.0] * 9 + [1.0] * 1, - [0.0] * 10, - [1.0] * 10, - ] - ), - mx.nd.array( + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + ], [ - -np.inf, - 0.334232, - 0.92857149, - 1.0, - 1.0, - 1.02631497, - 1.097745, - 1.51482505, - np.inf, - ] - ), - mx.nd.array( + 1, + 1, + 1, + 1, + 1, + 5, + 5, + 5, + 5, + 5, + ], [ - -0.18867899, - 0.85714298, - 1.0, - 1.0, - 1.0, - 1.05262995, - 1.14286005, - 1.88679004, - ] - ), - mx.nd.array( + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 2, + 2, + ], [ - [ - 1, - 2, - 2, - 2, - 3, - 7, - 7, - 7, - 8, - 8, - ], - [ - 5, - 5, - 5, - 5, - 5, - 5, - 5, - 5, - 5, - 5, - ], - [ - 1, - 1, - 1, - 1, - 1, - 5, - 5, - 5, - 5, - 5, - ], - [ - 6, - 6, - 6, - 6, - 6, - 6, - 6, - 6, - 2, - 2, - ], - [ - 9, - 9, - 9, - 9, - 9, - 9, - 9, - 9, - 9, - 9, - ], - [ - 5, - 5, - 5, - 5, - 5, - 5, - 5, - 5, - 5, - 5, - ], - ] - ), + 9, + 9, + 9, + 9, + 9, + 9, + 9, + 9, + 9, + 9, + ], + [ + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + ], + ]), ), ( GlobalRelativeBinning( @@ -309,124 +287,116 @@ is_quantile=False, quantile_scaling_limit=1.0, ), - mx.nd.array( - [ - [ - -0.188679, - 0.377358, - 0.566038, - 0.754717, - 0.943396, - 1.13208, - 1.32075, - 1.50943, - 1.69811, - 1.88679, - ], - [1.0] * 10, - [0.857143] * 5 + [1.14286] * 5, - [1.05263] * 8 + [0.789474] * 2, - [1.0] * 10, - ] - ), - mx.nd.array( - [ - [-1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], - [3.0] * 10, - [0.0] * 5 + [3.0] * 5, - [2.0] * 8 + [1.5] * 2, - [0.0] * 10, - [1.0] * 10, - ] - ), - mx.nd.array( + mx.nd.array([ [ - [1.0] * 10, - [1.0] * 10, - [0.0] * 5 + [1.0] * 5, - [1.0] * 9 + [1.0] * 1, - [0.0] * 10, - [1.0] * 10, - ] - ), + -0.188679, + 0.377358, + 0.566038, + 0.754717, + 0.943396, + 1.13208, + 1.32075, + 1.50943, + 1.69811, + 1.88679, + ], + [1.0] * 10, + [0.857143] * 5 + [1.14286] * 5, + [1.05263] * 8 + [0.789474] * 2, + [1.0] * 10, + ]), + mx.nd.array([ + [-1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], + [3.0] * 10, + [0.0] * 5 + [3.0] * 5, + [2.0] * 8 + [1.5] * 2, + [0.0] * 10, + [1.0] * 10, + ]), + mx.nd.array([ + [1.0] * 10, + [1.0] * 10, + [0.0] * 5 + [1.0] * 5, + [1.0] * 9 + [1.0] * 1, + [0.0] * 10, + [1.0] * 10, + ]), mx.nd.array([-np.inf, -8.0, -4.0, 0.0, 4.0, 8.0, np.inf]), mx.nd.array([-10.0, -6.0, -2.0, 2.0, 6.0, 10.0]), - mx.nd.array( + mx.nd.array([ [ - [ - 3, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - ], - [ - 4, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - ], - [ - 4, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - ], - [ - 4, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - ], - [ - 7, - 7, - 7, - 7, - 7, - 7, - 7, - 7, - 7, - 7, - ], - [ - 4, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - ], - ] - ), + 3, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + ], + [ + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + ], + [ + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + ], + [ + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + ], + [ + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + ], + [ + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + ], + ]), ), ] diff --git a/test/mx/representation/test_hyb.py b/test/mx/representation/test_hyb.py index 466311b9da..223262cc44 100644 --- a/test/mx/representation/test_hyb.py +++ b/test/mx/representation/test_hyb.py @@ -41,179 +41,171 @@ ), ] ), - mx.nd.array( - [ - [-1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], - [3.0] * 10, - [0.0] * 5 + [3.0] * 5, - [2.0] * 8 + [1.5] * 2, - [0.0] * 10, - [1.0] * 10, - ] - ), - mx.nd.array( - [ - [1.0] * 10, - [1.0] * 10, - [0.0] * 5 + [1.0] * 5, - [1.0] * 9 + [1.0] * 1, - [0.0] * 10, - [1.0] * 10, - ] - ), + mx.nd.array([ + [-1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], + [3.0] * 10, + [0.0] * 5 + [3.0] * 5, + [2.0] * 8 + [1.5] * 2, + [0.0] * 10, + [1.0] * 10, + ]), + mx.nd.array([ + [1.0] * 10, + [1.0] * 10, + [0.0] * 5 + [1.0] * 5, + [1.0] * 9 + [1.0] * 1, + [0.0] * 10, + [1.0] * 10, + ]), [ - mx.nd.array( + mx.nd.array([ + [ + 1.0, + 2.0, + 2.0, + 3.0, + 3.0, + 4.0, + 4.0, + 4.0, + 5.0, + 5.0, + ], + [ + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + ], + [ + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + ], + [ + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + ], + [ + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + ], + [ + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + ], + ]), + mx.nd.array([ + [ + 4.0, + 5.0, + 6.0, + 6.0, + 6.0, + 7.0, + 7.0, + 7.0, + 8.0, + 8.0, + ], + [ + 6.0, + 6.0, + 6.0, + 6.0, + 6.0, + 6.0, + 6.0, + 6.0, + 6.0, + 6.0, + ], + [ + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 6.0, + 6.0, + 6.0, + 6.0, + 6.0, + ], + [ + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + ], [ - [ - 1.0, - 2.0, - 2.0, - 3.0, - 3.0, - 4.0, - 4.0, - 4.0, - 5.0, - 5.0, - ], - [ - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - ], - [ - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - ], - [ - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - ], - [ - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - ], - [ - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - ], - ] - ), - mx.nd.array( + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + ], [ - [ - 4.0, - 5.0, - 6.0, - 6.0, - 6.0, - 7.0, - 7.0, - 7.0, - 8.0, - 8.0, - ], - [ - 6.0, - 6.0, - 6.0, - 6.0, - 6.0, - 6.0, - 6.0, - 6.0, - 6.0, - 6.0, - ], - [ - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 6.0, - 6.0, - 6.0, - 6.0, - 6.0, - ], - [ - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - ], - [ - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - ], - [ - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - ], - ] - ), + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + ], + ]), ], ), ( @@ -233,113 +225,105 @@ ), ] ), - mx.nd.array( - [ - [-1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], - [3.0] * 10, - [0.0] * 5 + [3.0] * 5, - [2.0] * 8 + [1.5] * 2, - [0.0] * 10, - [1.0] * 10, - ] - ), - mx.nd.array( - [ - [1.0] * 10, - [1.0] * 10, - [0.0] * 5 + [1.0] * 5, - [1.0] * 9 + [1.0] * 1, - [0.0] * 10, - [1.0] * 10, - ] - ), + mx.nd.array([ + [-1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], + [3.0] * 10, + [0.0] * 5 + [3.0] * 5, + [2.0] * 8 + [1.5] * 2, + [0.0] * 10, + [1.0] * 10, + ]), + mx.nd.array([ + [1.0] * 10, + [1.0] * 10, + [0.0] * 5 + [1.0] * 5, + [1.0] * 9 + [1.0] * 1, + [0.0] * 10, + [1.0] * 10, + ]), [ - mx.nd.array( + mx.nd.array([ + [ + 1.0, + 2.0, + 2.0, + 3.0, + 3.0, + 4.0, + 4.0, + 4.0, + 5.0, + 5.0, + ], + [ + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + ], + [ + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + ], [ - [ - 1.0, - 2.0, - 2.0, - 3.0, - 3.0, - 4.0, - 4.0, - 4.0, - 5.0, - 5.0, - ], - [ - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - ], - [ - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - ], - [ - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - ], - [ - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - ], - [ - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - ], - ] - ), - mx.nd.array( + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + ], [ - [1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0, 5.0, 6.0], - [6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 6.0, 6.0, 6.0, 6.0, 6.0], - [6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 1.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0], - ] - ), + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + ], + [ + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + ], + ]), + mx.nd.array([ + [1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0, 5.0, 6.0], + [6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 6.0, 6.0, 6.0, 6.0, 6.0], + [6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0], + ]), ], ), ( @@ -353,37 +337,31 @@ ), ] ), - mx.nd.array( - [ - [-1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], - [3.0] * 10, - [0.0] * 5 + [3.0] * 5, - [2.0] * 8 + [1.5] * 2, - [0.0] * 10, - [1.0] * 10, - ] - ), - mx.nd.array( - [ - [1.0] * 10, - [1.0] * 10, - [0.0] * 5 + [1.0] * 5, - [1.0] * 9 + [1.0] * 1, - [0.0] * 10, - [1.0] * 10, - ] - ), + mx.nd.array([ + [-1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], + [3.0] * 10, + [0.0] * 5 + [3.0] * 5, + [2.0] * 8 + [1.5] * 2, + [0.0] * 10, + [1.0] * 10, + ]), + mx.nd.array([ + [1.0] * 10, + [1.0] * 10, + [0.0] * 5 + [1.0] * 5, + [1.0] * 9 + [1.0] * 1, + [0.0] * 10, + [1.0] * 10, + ]), [ - mx.nd.array( - [ - [1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0, 5.0, 6.0], - [6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 6.0, 6.0, 6.0, 6.0, 6.0], - [6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 1.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0], - ] - ), + mx.nd.array([ + [1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0, 5.0, 6.0], + [6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 6.0, 6.0, 6.0, 6.0, 6.0], + [6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0], + ]), ], ), ] diff --git a/test/mx/representation/test_lab.py b/test/mx/representation/test_lab.py index 98f83fa763..76cdc3e25b 100644 --- a/test/mx/representation/test_lab.py +++ b/test/mx/representation/test_lab.py @@ -20,289 +20,269 @@ la_binning_cases = [ ( LocalAbsoluteBinning(num_bins=6, is_quantile=True), - mx.nd.array( + mx.nd.array([ + [-1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], + [3.0] * 10, + [0.0] * 5 + [3.0] * 5, + [2.0] * 8 + [1.5] * 2, + [0.0] * 10, + [1.0] * 10, + ]), + mx.nd.array([ + [1.0] * 10, + [1.0] * 10, + [0.0] * 5 + [1.0] * 5, + [1.0] * 9 + [1.0] * 1, + [0.0] * 10, + [1.0] * 10, + ]), + mx.nd.array([ [ - [-1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], - [3.0] * 10, - [0.0] * 5 + [3.0] * 5, - [2.0] * 8 + [1.5] * 2, - [0.0] * 10, - [1.0] * 10, - ] - ), - mx.nd.array( + -np.inf, + 0.9, + 3.7, + 5.5, + 7.3, + 9.1, + np.inf, + ], [ - [1.0] * 10, - [1.0] * 10, - [0.0] * 5 + [1.0] * 5, - [1.0] * 9 + [1.0] * 1, - [0.0] * 10, - [1.0] * 10, - ] - ), - mx.nd.array( + -np.inf, + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + np.inf, + ], [ - [ - -np.inf, - 0.9, - 3.7, - 5.5, - 7.3, - 9.1, - np.inf, - ], - [ - -np.inf, - 3.0, - 3.0, - 3.0, - 3.0, - 3.0, - np.inf, - ], - [ - -np.inf, - 3.0, - 3.0, - 3.0, - 3.0, - 3.0, - np.inf, - ], - [ - -np.inf, - 1.7, - 1.95, - 2.0, - 2.0, - 2.0, - np.inf, - ], - [ - -1.0, - -1.0, - -1.0, - -1.0, - -1.0, - -1.0, - -1.0, - ], - [ - -np.inf, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - np.inf, - ], - ] - ), - mx.nd.array( + -np.inf, + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + np.inf, + ], [ - [ - -1.0, - 2.8, - 4.6, - 6.4, - 8.2, - 10.0, - ], - [ - 3.0, - 3.0, - 3.0, - 3.0, - 3.0, - 3.0, - ], - [ - 3.0, - 3.0, - 3.0, - 3.0, - 3.0, - 3.0, - ], - [ - 1.5, - 1.9, - 2.0, - 2.0, - 2.0, - 2.0, - ], - [ - -1.0, - -1.0, - -1.0, - -1.0, - -1.0, - -1.0, - ], - [ - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - ], - ] - ), - mx.nd.array( + -np.inf, + 1.7, + 1.95, + 2.0, + 2.0, + 2.0, + np.inf, + ], [ - [1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0, 5.0, 6.0], - [6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 6.0, 6.0, 6.0, 6.0, 6.0], - [6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 1.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0], - ] - ), + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + ], + [ + -np.inf, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + np.inf, + ], + ]), + mx.nd.array([ + [ + -1.0, + 2.8, + 4.6, + 6.4, + 8.2, + 10.0, + ], + [ + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + ], + [ + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + ], + [ + 1.5, + 1.9, + 2.0, + 2.0, + 2.0, + 2.0, + ], + [ + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + ], + [ + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + ], + ]), + mx.nd.array([ + [1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0, 5.0, 6.0], + [6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 6.0, 6.0, 6.0, 6.0, 6.0], + [6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0], + ]), ), ( LocalAbsoluteBinning(num_bins=6, is_quantile=False), - mx.nd.array( + mx.nd.array([ + [-1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], + [3.0] * 10, + [0.0] * 5 + [3.0] * 5, + [2.0] * 8 + [1.5] * 2, + [0.0] * 10, + [1.0] * 10, + ]), + mx.nd.array([ + [1.0] * 10, + [1.0] * 10, + [0.0] * 5 + [1.0] * 5, + [1.0] * 9 + [1.0] * 1, + [0.0] * 10, + [1.0] * 10, + ]), + mx.nd.array([ + [ + -np.inf, + 0.1, + 2.3, + 4.5, + 6.7, + 8.9, + np.inf, + ], + [ + -np.inf, + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + np.inf, + ], + [ + -np.inf, + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + np.inf, + ], + [ + -np.inf, + 1.55, + 1.65, + 1.75, + 1.85, + 1.95, + np.inf, + ], + [ + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + ], + [ + -np.inf, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + np.inf, + ], + ]), + mx.nd.array([ + [ + -1.0, + 1.2, + 3.4, + 5.6, + 7.8, + 10.0, + ], [ - [-1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], - [3.0] * 10, - [0.0] * 5 + [3.0] * 5, - [2.0] * 8 + [1.5] * 2, - [0.0] * 10, - [1.0] * 10, - ] - ), - mx.nd.array( + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + ], [ - [1.0] * 10, - [1.0] * 10, - [0.0] * 5 + [1.0] * 5, - [1.0] * 9 + [1.0] * 1, - [0.0] * 10, - [1.0] * 10, - ] - ), - mx.nd.array( + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + ], [ - [ - -np.inf, - 0.1, - 2.3, - 4.5, - 6.7, - 8.9, - np.inf, - ], - [ - -np.inf, - 3.0, - 3.0, - 3.0, - 3.0, - 3.0, - np.inf, - ], - [ - -np.inf, - 3.0, - 3.0, - 3.0, - 3.0, - 3.0, - np.inf, - ], - [ - -np.inf, - 1.55, - 1.65, - 1.75, - 1.85, - 1.95, - np.inf, - ], - [ - -1.0, - -1.0, - -1.0, - -1.0, - -1.0, - -1.0, - -1.0, - ], - [ - -np.inf, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - np.inf, - ], - ] - ), - mx.nd.array( + 1.5, + 1.6, + 1.7, + 1.8, + 1.9, + 2.0, + ], [ - [ - -1.0, - 1.2, - 3.4, - 5.6, - 7.8, - 10.0, - ], - [ - 3.0, - 3.0, - 3.0, - 3.0, - 3.0, - 3.0, - ], - [ - 3.0, - 3.0, - 3.0, - 3.0, - 3.0, - 3.0, - ], - [ - 1.5, - 1.6, - 1.7, - 1.8, - 1.9, - 2.0, - ], - [ - -1.0, - -1.0, - -1.0, - -1.0, - -1.0, - -1.0, - ], - [ - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - ], - ] - ), - mx.nd.array( + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + ], [ - [1.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0, 5.0, 6.0, 6.0], - [6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 6.0, 6.0, 6.0, 6.0, 6.0], - [6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 1.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0], - ] - ), + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + ], + ]), + mx.nd.array([ + [1.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0, 5.0, 6.0, 6.0], + [6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 6.0, 6.0, 6.0, 6.0, 6.0], + [6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0], + ]), ), ] diff --git a/test/mx/representation/test_mean.py b/test/mx/representation/test_mean.py index 677232c873..a114eec4df 100644 --- a/test/mx/representation/test_mean.py +++ b/test/mx/representation/test_mean.py @@ -20,44 +20,36 @@ mean_cases = [ ( MeanScaling(), - mx.nd.array( - [ - [1.0] * 50, - [0.0] * 25 + [3.0] * 25, - [2.0] * 49 + [1.5] * 1, - [0.0] * 50, - [1.0] * 50, - ] - ), - mx.nd.array( - [ - [1.0] * 50, - [0.0] * 25 + [1.0] * 25, - [0.0] * 49 + [1.0] * 1, - [1.0] * 50, - [0.0] * 50, - ] - ), + mx.nd.array([ + [1.0] * 50, + [0.0] * 25 + [3.0] * 25, + [2.0] * 49 + [1.5] * 1, + [0.0] * 50, + [1.0] * 50, + ]), + mx.nd.array([ + [1.0] * 50, + [0.0] * 25 + [1.0] * 25, + [0.0] * 49 + [1.0] * 1, + [1.0] * 50, + [0.0] * 50, + ]), mx.nd.array([1.0, 3.0, 1.5, 1.00396824, 1.00396824]), ), ( MeanScaling(), - mx.nd.array( - [ - [120.0] * 25 + [150.0] * 25, - [0.0] * 10 + [3.0] * 20 + [61.0] * 20, - [0.0] * 50, - [2e-2] * 10 + [0.0] * 30 + [3e-2] * 10, - ] - ), - mx.nd.array( - [ - [1.0] * 25 + [1.0] * 25, - [0.0] * 10 + [1.0] * 20 + [1.0] * 20, - [0.0] * 50, - [1.0] * 10 + [0.0] * 30 + [1.0] * 10, - ] - ), + mx.nd.array([ + [120.0] * 25 + [150.0] * 25, + [0.0] * 10 + [3.0] * 20 + [61.0] * 20, + [0.0] * 50, + [2e-2] * 10 + [0.0] * 30 + [3e-2] * 10, + ]), + mx.nd.array([ + [1.0] * 25 + [1.0] * 25, + [0.0] * 10 + [1.0] * 20 + [1.0] * 20, + [0.0] * 50, + [1.0] * 10 + [0.0] * 30 + [1.0] * 10, + ]), mx.nd.array([135.0, 32.0, 73.00454712, 2.5e-2]), ), ( diff --git a/test/mx/representation/test_rep.py b/test/mx/representation/test_rep.py index 622d337c1f..9807c8cfdd 100644 --- a/test/mx/representation/test_rep.py +++ b/test/mx/representation/test_rep.py @@ -18,42 +18,34 @@ cases = [ ( - mx.nd.array( - [ - [1.0] * 50, - [0.0] * 25 + [3.0] * 25, - [2.0] * 49 + [1.5] * 1, - [0.0] * 50, - [1.0] * 50, - ] - ), - mx.nd.array( - [ - [1.0] * 50, - [0.0] * 25 + [1.0] * 25, - [0.0] * 49 + [1.0] * 1, - [1.0] * 50, - [0.0] * 50, - ] - ), + mx.nd.array([ + [1.0] * 50, + [0.0] * 25 + [3.0] * 25, + [2.0] * 49 + [1.5] * 1, + [0.0] * 50, + [1.0] * 50, + ]), + mx.nd.array([ + [1.0] * 50, + [0.0] * 25 + [1.0] * 25, + [0.0] * 49 + [1.0] * 1, + [1.0] * 50, + [0.0] * 50, + ]), ), ( - mx.nd.array( - [ - [120.0] * 25 + [150.0] * 25, - [0.0] * 10 + [3.0] * 20 + [61.0] * 20, - [0.0] * 50, - [2e-2] * 10 + [0.0] * 30 + [3e-2] * 10, - ] - ), - mx.nd.array( - [ - [1.0] * 25 + [1.0] * 25, - [0.0] * 10 + [1.0] * 20 + [1.0] * 20, - [0.0] * 50, - [1.0] * 10 + [0.0] * 30 + [1.0] * 10, - ] - ), + mx.nd.array([ + [120.0] * 25 + [150.0] * 25, + [0.0] * 10 + [3.0] * 20 + [61.0] * 20, + [0.0] * 50, + [2e-2] * 10 + [0.0] * 30 + [3e-2] * 10, + ]), + mx.nd.array([ + [1.0] * 25 + [1.0] * 25, + [0.0] * 10 + [1.0] * 20 + [1.0] * 20, + [0.0] * 50, + [1.0] * 10 + [0.0] * 30 + [1.0] * 10, + ]), ), ( mx.nd.random.normal(shape=(5, 30)), diff --git a/test/mx/test_transform_equals.py b/test/mx/test_transform_equals.py index 23356cd966..0cd7dfefa6 100644 --- a/test/mx/test_transform_equals.py +++ b/test/mx/test_transform_equals.py @@ -131,58 +131,54 @@ def test_continuous_time_splitter(): def test_chain(): - chain = transform.Chain( - [ - transform.AddTimeFeatures( - start_field=FieldName.START, - target_field=FieldName.TARGET, - output_field="time_feat", - time_features=[ - time_feature.day_of_week, - time_feature.day_of_month, - time_feature.month_of_year, - ], - pred_length=10, - ), - transform.AddAgeFeature( - target_field=FieldName.TARGET, - output_field="age", - pred_length=10, - log_scale=True, - ), - transform.AddObservedValuesIndicator( - target_field=FieldName.TARGET, output_field="observed_values" - ), - ] - ) + chain = transform.Chain([ + transform.AddTimeFeatures( + start_field=FieldName.START, + target_field=FieldName.TARGET, + output_field="time_feat", + time_features=[ + time_feature.day_of_week, + time_feature.day_of_month, + time_feature.month_of_year, + ], + pred_length=10, + ), + transform.AddAgeFeature( + target_field=FieldName.TARGET, + output_field="age", + pred_length=10, + log_scale=True, + ), + transform.AddObservedValuesIndicator( + target_field=FieldName.TARGET, output_field="observed_values" + ), + ]) assert equals(chain, clone(chain)) assert not equals(chain, clone(chain, {"transformations": []})) - another_chain = transform.Chain( - [ - transform.AddTimeFeatures( - start_field=FieldName.START, - target_field=FieldName.TARGET, - output_field="time_feat", - time_features=[ - time_feature.day_of_week, - time_feature.day_of_month, - time_feature.month_of_year, - ], - pred_length=10, - ), - transform.AddAgeFeature( - target_field=FieldName.TARGET, - output_field="age", - pred_length=10, - log_scale=False, - ), - transform.AddObservedValuesIndicator( - target_field=FieldName.TARGET, output_field="observed_values" - ), - ] - ) + another_chain = transform.Chain([ + transform.AddTimeFeatures( + start_field=FieldName.START, + target_field=FieldName.TARGET, + output_field="time_feat", + time_features=[ + time_feature.day_of_week, + time_feature.day_of_month, + time_feature.month_of_year, + ], + pred_length=10, + ), + transform.AddAgeFeature( + target_field=FieldName.TARGET, + output_field="age", + pred_length=10, + log_scale=False, + ), + transform.AddObservedValuesIndicator( + target_field=FieldName.TARGET, output_field="observed_values" + ), + ]) assert not equals(chain, another_chain) diff --git a/test/nursery/anomaly_detection/supervised_metrics/test_precision_recall.py b/test/nursery/anomaly_detection/supervised_metrics/test_precision_recall.py index ce87ecf94e..109c34f349 100644 --- a/test/nursery/anomaly_detection/supervised_metrics/test_precision_recall.py +++ b/test/nursery/anomaly_detection/supervised_metrics/test_precision_recall.py @@ -373,26 +373,24 @@ def test_buffered_precision_recall(test_case): @pytest.fixture def labels_and_scores() -> List[Tuple[np.array, np.array]]: label1 = np.array([0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0]) - scores1 = np.array( - [ - 0.2, - 0.3, - 0.5, - 0.7, - 4, - 2.5, - 0.3, - 0.2, - 0.7, - 0.3, - 0.2, - 4, - 3, - 8, - 0.2, - 0.1, - ] - ) + scores1 = np.array([ + 0.2, + 0.3, + 0.5, + 0.7, + 4, + 2.5, + 0.3, + 0.2, + 0.7, + 0.3, + 0.2, + 4, + 3, + 8, + 0.2, + 0.1, + ]) label2 = np.array([0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0]) scores2 = np.array( diff --git a/test/nursery/autogluon_tabular/test_autogluon_tabular.py b/test/nursery/autogluon_tabular/test_autogluon_tabular.py index f1cc10cae0..73fb52925c 100644 --- a/test/nursery/autogluon_tabular/test_autogluon_tabular.py +++ b/test/nursery/autogluon_tabular/test_autogluon_tabular.py @@ -134,43 +134,41 @@ def test_get_features_dataframe( @pytest.mark.parametrize( "dataset, freq, prediction_length", - [ - ( - ListDataset( - [ - { - "start": "1750-01-07 00:00:00", - "target": np.array( - [ - 1089.2, - 1078.91, - 1099.88, - 35790.55, - 34096.95, - 34906.95, - ], - ), - }, - { - "start": "1750-01-07 00:00:00", - "target": np.array( - [ - 1099.2, - 1098.91, - 1069.88, - 35990.55, - 34076.95, - 34766.95, - ], - ), - }, - ], - freq="W-TUE", - ), - "W-TUE", - 2, - ) - ], + [( + ListDataset( + [ + { + "start": "1750-01-07 00:00:00", + "target": np.array( + [ + 1089.2, + 1078.91, + 1099.88, + 35790.55, + 34096.95, + 34906.95, + ], + ), + }, + { + "start": "1750-01-07 00:00:00", + "target": np.array( + [ + 1099.2, + 1098.91, + 1069.88, + 35990.55, + 34076.95, + 34766.95, + ], + ), + }, + ], + freq="W-TUE", + ), + "W-TUE", + 2, + )], ) @pytest.mark.parametrize("lag_indices", [[], [1, 2, 5]]) @pytest.mark.parametrize("disable_auto_regression", [False, True]) diff --git a/test/nursery/sagemaker_sdk/test_entry_point_scripts.py b/test/nursery/sagemaker_sdk/test_entry_point_scripts.py index 7944700fbb..5e49244268 100644 --- a/test/nursery/sagemaker_sdk/test_entry_point_scripts.py +++ b/test/nursery/sagemaker_sdk/test_entry_point_scripts.py @@ -91,7 +91,7 @@ def test_train_script(dataset_name, custom_dataset): estimator = estimator_cls.from_hyperparameters( prediction_length=dataset.metadata.prediction_length, freq=dataset.metadata.freq, - **hyperparameters + **hyperparameters, ) serialized = serde.dump_json(estimator) with open(temp_dir_path / "estimator.json", "w") as estimator_file: diff --git a/test/shell/test_nested_params.py b/test/shell/test_nested_params.py index 797c2374df..f00feb3072 100644 --- a/test/shell/test_nested_params.py +++ b/test/shell/test_nested_params.py @@ -15,13 +15,11 @@ def test_nested_params(): - data = decode_nested_parameters( - { - "$env.num_workers": "4", - "$evaluation.quantiles": [0.1, 0.5, 0.9], - "prediction_length": 14, - } - ) + data = decode_nested_parameters({ + "$env.num_workers": "4", + "$evaluation.quantiles": [0.1, 0.5, 0.9], + "prediction_length": 14, + }) hps = data.pop("") assert hps["prediction_length"] == 14 diff --git a/test/time_feature/test_agg_lags.py b/test/time_feature/test_agg_lags.py index dd3b2f2d9b..19b69fa2a9 100644 --- a/test/time_feature/test_agg_lags.py +++ b/test/time_feature/test_agg_lags.py @@ -22,44 +22,36 @@ expected_lags_rolling = { "prediction_length_2": { - "train": np.array( - [ - [0, 0, 0, 1, 1, 1.5, 2, 2.5, 3, 3.5], - [0, 0, 0, 0, 0, 1, 1, 1.5, 2, 2.5], - [0, 0, 0, 0, 0, 0, 0, 1, 1, 1.5], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 1], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - ] - ), - "test": np.array( - [ - [0, 0, 0, 1, 1, 1.5, 2, 2.5, 3, 3.5, 4.5, 5.5], - [0, 0, 0, 0, 0, 1, 1, 1.5, 2, 2.5, 3, 3.5], - [0, 0, 0, 0, 0, 0, 0, 1, 1, 1.5, 2, 2.5], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1.5], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - ] - ), + "train": np.array([ + [0, 0, 0, 1, 1, 1.5, 2, 2.5, 3, 3.5], + [0, 0, 0, 0, 0, 1, 1, 1.5, 2, 2.5], + [0, 0, 0, 0, 0, 0, 0, 1, 1, 1.5], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ]), + "test": np.array([ + [0, 0, 0, 1, 1, 1.5, 2, 2.5, 3, 3.5, 4.5, 5.5], + [0, 0, 0, 0, 0, 1, 1, 1.5, 2, 2.5, 3, 3.5], + [0, 0, 0, 0, 0, 0, 0, 1, 1, 1.5, 2, 2.5], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1.5], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ]), }, "prediction_length_1": { - "train": np.array( - [ - [0, 0, 0, 1, 1, 1.5, 2, 2.5, 3, 3.5], - [0, 0, 0, 0, 0, 1, 1, 1.5, 2, 2.5], - [0, 0, 0, 0, 0, 0, 0, 1, 1, 1.5], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 1], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - ] - ), - "test": np.array( - [ - [0, 0, 0, 1, 1, 1.5, 2, 2.5, 3, 3.5, 4.5], - [0, 0, 0, 0, 0, 1, 1, 1.5, 2, 2.5, 3], - [0, 0, 0, 0, 0, 0, 0, 1, 1, 1.5, 2], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - ] - ), + "train": np.array([ + [0, 0, 0, 1, 1, 1.5, 2, 2.5, 3, 3.5], + [0, 0, 0, 0, 0, 1, 1, 1.5, 2, 2.5], + [0, 0, 0, 0, 0, 0, 0, 1, 1, 1.5], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ]), + "test": np.array([ + [0, 0, 0, 1, 1, 1.5, 2, 2.5, 3, 3.5, 4.5], + [0, 0, 0, 0, 0, 1, 1, 1.5, 2, 2.5, 3], + [0, 0, 0, 0, 0, 0, 0, 1, 1, 1.5, 2], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ]), }, } diff --git a/test/time_feature/test_holiday.py b/test/time_feature/test_holiday.py index 2d55894885..03b444c64f 100644 --- a/test/time_feature/test_holiday.py +++ b/test/time_feature/test_holiday.py @@ -128,13 +128,11 @@ def test_holidays(holiday): test_cases = [ ( pd.date_range(start="2016-12-24", end="2016-12-31", freq="D"), - np.array( - [ - [1, 0, 0, 0, 0, 0, 0, 0], - [0, 1, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 1], - ] - ), + np.array([ + [1, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1], + ]), [CHRISTMAS_EVE, CHRISTMAS_DAY, NEW_YEARS_EVE], ), ( @@ -163,91 +161,89 @@ def test_special_date_feature_set_hourly(): start="2016-12-24", end="2016-12-25", freq="H" ) - reference_features = np.array( + reference_features = np.array([ [ - [ - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 0, - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 1, - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - ], - ] - ) + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 0, + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1, + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ], + ]) sfs = SpecialDateFeatureSet([CHRISTMAS_EVE, CHRISTMAS_DAY, NEW_YEARS_EVE]) computed_features = sfs(date_indices) diff --git a/test/torch/distribution/test_discrete_distribution.py b/test/torch/distribution/test_discrete_distribution.py index 0e89d26275..f1223a71f3 100644 --- a/test/torch/distribution/test_discrete_distribution.py +++ b/test/torch/distribution/test_discrete_distribution.py @@ -62,30 +62,24 @@ def test_rps(values, probs, obs, rps): [ # Duplicate values occur (i) only in the middle (ii) at the extremes ( - torch.tensor( - [ - [-1.0, 0.0, 0.0, 0.0, 2.0, 2.0, 5.0], - [-1.0, -1.0, 0.0, 0.0, 2.0, 5.0, 5.0], - ] - ), - torch.tensor( - [ - [0.1, 0.12, 0.03, 0.15, 0.05, 0.15, 0.4], - [0.15, 0.05, 0.13, 0.12, 0.05, 0.27, 0.23], - ] - ), + torch.tensor([ + [-1.0, 0.0, 0.0, 0.0, 2.0, 2.0, 5.0], + [-1.0, -1.0, 0.0, 0.0, 2.0, 5.0, 5.0], + ]), + torch.tensor([ + [0.1, 0.12, 0.03, 0.15, 0.05, 0.15, 0.4], + [0.15, 0.05, 0.13, 0.12, 0.05, 0.27, 0.23], + ]), ) ], ) @pytest.mark.parametrize( "probs_adjusted", [ - torch.tensor( - [ - [0.1, 0.0, 0.0, 0.3, 0.0, 0.2, 0.4], - [0.0, 0.2, 0.0, 0.25, 0.05, 0.0, 0.5], - ] - ), + torch.tensor([ + [0.1, 0.0, 0.0, 0.3, 0.0, 0.2, 0.4], + [0.0, 0.2, 0.0, 0.25, 0.05, 0.0, 0.5], + ]), ], ) def test_probs_duplicate_values(values, probs, probs_adjusted): diff --git a/test/torch/distribution/test_torch_piecewise_linear.py b/test/torch/distribution/test_torch_piecewise_linear.py index 731596bd24..ae8f804362 100644 --- a/test/torch/distribution/test_torch_piecewise_linear.py +++ b/test/torch/distribution/test_torch_piecewise_linear.py @@ -97,12 +97,12 @@ def test_values( expected_target_crps: List[float], ): target = torch.Tensor(target).reshape(shape=(len(target),)) - expected_target_cdf = np.array(expected_target_cdf).reshape( - (len(expected_target_cdf),) - ) - expected_target_crps = np.array(expected_target_crps).reshape( - (len(expected_target_crps),) - ) + expected_target_cdf = np.array(expected_target_cdf).reshape(( + len(expected_target_cdf), + )) + expected_target_crps = np.array(expected_target_crps).reshape(( + len(expected_target_crps), + )) assert all(np.isclose(distr.cdf(target).numpy(), expected_target_cdf)) assert all(np.isclose(distr.crps(target).numpy(), expected_target_crps)) diff --git a/test/torch/model/test_mqf2_modules.py b/test/torch/model/test_mqf2_modules.py index a8c2685259..c5ed31ab3d 100644 --- a/test/torch/model/test_mqf2_modules.py +++ b/test/torch/model/test_mqf2_modules.py @@ -41,18 +41,16 @@ def test_mqf2_modules( distr_output = MQF2DistributionOutput(prediction_length) - lightning_module = MQF2MultiHorizonLightningModule( - { - "freq": "1H", - "context_length": context_length, - "prediction_length": prediction_length, - "num_feat_dynamic_real": num_feat_dynamic_real, - "num_feat_static_real": num_feat_static_real, - "num_feat_static_cat": num_feat_static_cat, - "cardinality": cardinality, - "distr_output": distr_output, - } - ) + lightning_module = MQF2MultiHorizonLightningModule({ + "freq": "1H", + "context_length": context_length, + "prediction_length": prediction_length, + "num_feat_dynamic_real": num_feat_dynamic_real, + "num_feat_static_real": num_feat_static_real, + "num_feat_static_cat": num_feat_static_cat, + "cardinality": cardinality, + "distr_output": distr_output, + }) model = lightning_module.model feat_static_cat = torch.zeros( diff --git a/test/torch/model/test_tft.py b/test/torch/model/test_tft.py index 5ac1ff029e..ddcfadb7dd 100644 --- a/test/torch/model/test_tft.py +++ b/test/torch/model/test_tft.py @@ -40,19 +40,17 @@ def test_tft_modules( prediction_length = 6 context_length = 12 - lightning_module = TemporalFusionTransformerLightningModule( - { - "context_length": context_length, - "prediction_length": prediction_length, - "d_past_feat_dynamic_real": d_past_feat_dynamic_real, - "c_past_feat_dynamic_cat": c_past_feat_dynamic_cat, - "d_feat_dynamic_real": d_feat_dynamic_real, - "c_feat_dynamic_cat": c_feat_dynamic_cat, - "d_feat_static_real": d_feat_static_real, - "c_feat_static_cat": c_feat_static_cat, - "distr_output": QuantileOutput(quantiles), - } - ) + lightning_module = TemporalFusionTransformerLightningModule({ + "context_length": context_length, + "prediction_length": prediction_length, + "d_past_feat_dynamic_real": d_past_feat_dynamic_real, + "c_past_feat_dynamic_cat": c_past_feat_dynamic_cat, + "d_feat_dynamic_real": d_feat_dynamic_real, + "c_feat_dynamic_cat": c_feat_dynamic_cat, + "d_feat_static_real": d_feat_static_real, + "c_feat_static_cat": c_feat_static_cat, + "distr_output": QuantileOutput(quantiles), + }) model = lightning_module.model feat_static_cat = torch.zeros( diff --git a/test/torch/test_scaler.py b/test/torch/test_scaler.py index 48ffb11330..02fe33e132 100644 --- a/test/torch/test_scaler.py +++ b/test/torch/test_scaler.py @@ -19,88 +19,72 @@ test_cases = [ ( scaler.MeanScaler(), - torch.Tensor( - [ - [1.0] * 50, - [0.0] * 25 + [3.0] * 25, - [2.0] * 49 + [1.5] * 1, - [0.0] * 50, - [1.0] * 50, - ] - ), - torch.Tensor( - [ - [1.0] * 50, - [0.0] * 25 + [1.0] * 25, - [0.0] * 49 + [1.0] * 1, - [1.0] * 50, - [0.0] * 50, - ] - ), + torch.Tensor([ + [1.0] * 50, + [0.0] * 25 + [3.0] * 25, + [2.0] * 49 + [1.5] * 1, + [0.0] * 50, + [1.0] * 50, + ]), + torch.Tensor([ + [1.0] * 50, + [0.0] * 25 + [1.0] * 25, + [0.0] * 49 + [1.0] * 1, + [1.0] * 50, + [0.0] * 50, + ]), torch.Tensor([1.0, 3.0, 1.5, 1e-10, 1.00396824]), ), ( scaler.MeanScaler(default_scale=0.5), - torch.Tensor( - [ - [1.0] * 50, - [0.0] * 25 + [3.0] * 25, - [2.0] * 49 + [1.5] * 1, - [0.0] * 50, - [1.0] * 50, - ] - ), - torch.Tensor( - [ - [0.0] * 50, - [0.0] * 25 + [1.0] * 25, - [0.0] * 49 + [1.0] * 1, - [1.0] * 50, - [0.0] * 50, - ] - ), + torch.Tensor([ + [1.0] * 50, + [0.0] * 25 + [3.0] * 25, + [2.0] * 49 + [1.5] * 1, + [0.0] * 50, + [1.0] * 50, + ]), + torch.Tensor([ + [0.0] * 50, + [0.0] * 25 + [1.0] * 25, + [0.0] * 49 + [1.0] * 1, + [1.0] * 50, + [0.0] * 50, + ]), torch.Tensor([0.5, 3.0, 1.5, 1e-10, 0.5]), ), ( scaler.MeanScaler(keepdim=True), - torch.Tensor( - [ - [1.0] * 50, - [0.0] * 25 + [3.0] * 25, - [2.0] * 49 + [1.5] * 1, - [0.0] * 50, - [1.0] * 50, - ] - ), - torch.Tensor( - [ - [1.0] * 50, - [0.0] * 25 + [1.0] * 25, - [0.0] * 49 + [1.0] * 1, - [1.0] * 50, - [0.0] * 50, - ] - ), + torch.Tensor([ + [1.0] * 50, + [0.0] * 25 + [3.0] * 25, + [2.0] * 49 + [1.5] * 1, + [0.0] * 50, + [1.0] * 50, + ]), + torch.Tensor([ + [1.0] * 50, + [0.0] * 25 + [1.0] * 25, + [0.0] * 49 + [1.0] * 1, + [1.0] * 50, + [0.0] * 50, + ]), torch.Tensor([1.0, 3.0, 1.5, 1e-10, 1.00396824]).unsqueeze(1), ), ( scaler.MeanScaler(), - torch.Tensor( - [ - [120.0] * 25 + [150.0] * 25, - [0.0] * 10 + [3.0] * 20 + [61.0] * 20, - [0.0] * 50, - [2e-2] * 10 + [0.0] * 30 + [3e-2] * 10, - ] - ), - torch.Tensor( - [ - [1.0] * 25 + [1.0] * 25, - [0.0] * 10 + [1.0] * 20 + [1.0] * 20, - [0.0] * 50, - [1.0] * 10 + [0.0] * 30 + [1.0] * 10, - ] - ), + torch.Tensor([ + [120.0] * 25 + [150.0] * 25, + [0.0] * 10 + [3.0] * 20 + [61.0] * 20, + [0.0] * 50, + [2e-2] * 10 + [0.0] * 30 + [3e-2] * 10, + ]), + torch.Tensor([ + [1.0] * 25 + [1.0] * 25, + [0.0] * 10 + [1.0] * 20 + [1.0] * 20, + [0.0] * 50, + [1.0] * 10 + [0.0] * 30 + [1.0] * 10, + ]), torch.Tensor([135.0, 32.0, 73.00454712, 2.5e-2]), ), ( @@ -134,28 +118,22 @@ def test_scaler(s, target, observed, expected_scale): @pytest.mark.parametrize( "target, observed", - [ - ( - torch.Tensor( - [ - [1.0] * 50, - [0.0] * 25 + [3.0] * 25, - [2.0] * 49 + [1.5] * 1, - [0.0] * 50, - [1.0] * 50, - ] - ), - torch.Tensor( - [ - [1.0] * 50, - [0.0] * 25 + [1.0] * 25, - [0.0] * 49 + [1.0] * 1, - [1.0] * 50, - [0.0] * 50, - ] - ), - ) - ], + [( + torch.Tensor([ + [1.0] * 50, + [0.0] * 25 + [3.0] * 25, + [2.0] * 49 + [1.5] * 1, + [0.0] * 50, + [1.0] * 50, + ]), + torch.Tensor([ + [1.0] * 50, + [0.0] * 25 + [1.0] * 25, + [0.0] * 49 + [1.0] * 1, + [1.0] * 50, + [0.0] * 50, + ]), + )], ) def test_nopscaler(target, observed): s = scaler.NOPScaler() diff --git a/test/transform/test_transform.py b/test/transform/test_transform.py index fbcdfb9d16..d783e094f3 100644 --- a/test/transform/test_transform.py +++ b/test/transform/test_transform.py @@ -371,47 +371,45 @@ def test_Transformation(): pred_length = 10 - t = transform.Chain( - [ - transform.AddTimeFeatures( - start_field=FieldName.START, - target_field=FieldName.TARGET, - output_field="time_feat", - time_features=[ - time_feature.day_of_week, - time_feature.day_of_month, - time_feature.month_of_year, - ], - pred_length=pred_length, - ), - transform.AddAgeFeature( - target_field=FieldName.TARGET, - output_field="age", - pred_length=pred_length, - log_scale=True, - ), - transform.AddObservedValuesIndicator( - target_field=FieldName.TARGET, output_field="observed_values" - ), - transform.VstackFeatures( - output_field="dynamic_feat", - input_fields=["age", "time_feat"], - drop_inputs=True, - ), - transform.InstanceSplitter( - target_field=FieldName.TARGET, - is_pad_field=FieldName.IS_PAD, - start_field=FieldName.START, - forecast_start_field=FieldName.FORECAST_START, - instance_sampler=transform.ExpectedNumInstanceSampler( - num_instances=4 - ), - past_length=train_length, - future_length=pred_length, - time_series_fields=["dynamic_feat", "observed_values"], + t = transform.Chain([ + transform.AddTimeFeatures( + start_field=FieldName.START, + target_field=FieldName.TARGET, + output_field="time_feat", + time_features=[ + time_feature.day_of_week, + time_feature.day_of_month, + time_feature.month_of_year, + ], + pred_length=pred_length, + ), + transform.AddAgeFeature( + target_field=FieldName.TARGET, + output_field="age", + pred_length=pred_length, + log_scale=True, + ), + transform.AddObservedValuesIndicator( + target_field=FieldName.TARGET, output_field="observed_values" + ), + transform.VstackFeatures( + output_field="dynamic_feat", + input_fields=["age", "time_feat"], + drop_inputs=True, + ), + transform.InstanceSplitter( + target_field=FieldName.TARGET, + is_pad_field=FieldName.IS_PAD, + start_field=FieldName.START, + forecast_start_field=FieldName.FORECAST_START, + instance_sampler=transform.ExpectedNumInstanceSampler( + num_instances=4 ), - ] - ) + past_length=train_length, + future_length=pred_length, + time_series_fields=["dynamic_feat", "observed_values"], + ), + ]) assert_serializable(t) @@ -440,54 +438,52 @@ def test_multi_dim_transformation(is_train): first_dim[-1] = np.nan second_dim[0] = np.nan - t = transform.Chain( - [ - transform.AddTimeFeatures( - start_field=FieldName.START, - target_field=FieldName.TARGET, - output_field="time_feat", - time_features=[ - time_feature.day_of_week, - time_feature.day_of_month, - time_feature.month_of_year, - ], - pred_length=pred_length, - ), - transform.AddAgeFeature( - target_field=FieldName.TARGET, - output_field="age", - pred_length=pred_length, - log_scale=True, - ), - transform.AddObservedValuesIndicator( - target_field=FieldName.TARGET, - output_field="observed_values", - imputation_method=None, - ), - transform.VstackFeatures( - output_field="dynamic_feat", - input_fields=["age", "time_feat"], - drop_inputs=True, - ), - transform.InstanceSplitter( - target_field=FieldName.TARGET, - is_pad_field=FieldName.IS_PAD, - start_field=FieldName.START, - forecast_start_field=FieldName.FORECAST_START, - instance_sampler=( - transform.ExpectedNumInstanceSampler( - num_instances=4, min_future=pred_length - ) - if is_train - else transform.TestSplitSampler() - ), - past_length=train_length, - future_length=pred_length, - time_series_fields=["dynamic_feat", "observed_values"], - output_NTC=False, + t = transform.Chain([ + transform.AddTimeFeatures( + start_field=FieldName.START, + target_field=FieldName.TARGET, + output_field="time_feat", + time_features=[ + time_feature.day_of_week, + time_feature.day_of_month, + time_feature.month_of_year, + ], + pred_length=pred_length, + ), + transform.AddAgeFeature( + target_field=FieldName.TARGET, + output_field="age", + pred_length=pred_length, + log_scale=True, + ), + transform.AddObservedValuesIndicator( + target_field=FieldName.TARGET, + output_field="observed_values", + imputation_method=None, + ), + transform.VstackFeatures( + output_field="dynamic_feat", + input_fields=["age", "time_feat"], + drop_inputs=True, + ), + transform.InstanceSplitter( + target_field=FieldName.TARGET, + is_pad_field=FieldName.IS_PAD, + start_field=FieldName.START, + forecast_start_field=FieldName.FORECAST_START, + instance_sampler=( + transform.ExpectedNumInstanceSampler( + num_instances=4, min_future=pred_length + ) + if is_train + else transform.TestSplitSampler() ), - ] - ) + past_length=train_length, + future_length=pred_length, + time_series_fields=["dynamic_feat", "observed_values"], + output_NTC=False, + ), + ]) assert_serializable(t) @@ -637,16 +633,14 @@ def make_test_data(): ds = gluonts.dataset.common.ListDataset( # Mimic output from InstanceSplitter - [ - { - "start": "2012-01-01", - "target": multi_dim_target, - "past_target": multi_dim_target, - "future_target": multi_dim_target, - "past_is_pad": past_is_pad, - f"past_{FieldName.OBSERVED_VALUES}": past_observed_target, - } - ], + [{ + "start": "2012-01-01", + "target": multi_dim_target, + "past_target": multi_dim_target, + "future_target": multi_dim_target, + "past_is_pad": past_is_pad, + f"past_{FieldName.OBSERVED_VALUES}": past_observed_target, + }], freq="1D", one_dim_target=False, ) @@ -741,13 +735,11 @@ def point_process_dataset(): marks = np.array([0, 1, 2, 0, 1, 2, 2, 2]) return ListDataset( - [ - { - "target": np.c_[ia_times, marks].T, - "start": pd.Timestamp("2011-01-01 00:00:00"), - "end": pd.Timestamp("2011-01-01 03:00:00"), - } - ], + [{ + "target": np.c_[ia_times, marks].T, + "start": pd.Timestamp("2011-01-01 00:00:00"), + "end": pd.Timestamp("2011-01-01 03:00:00"), + }], freq="H", one_dim_target=False, use_timestamp=True, @@ -885,16 +877,14 @@ def test_ctsplitter_train_samples_correct_times(point_process_dataset): iter_de = splitter(point_process_dataset, is_train=True) - assert all( - [ - ( - pd.Timestamp("2011-01-01 01:15:00") - <= d["forecast_start"] - <= pd.Timestamp("2011-01-01 01:45:00") - ) - for d in iter_de - ] - ) + assert all([ + ( + pd.Timestamp("2011-01-01 01:15:00") + <= d["forecast_start"] + <= pd.Timestamp("2011-01-01 01:45:00") + ) + for d in iter_de + ]) def test_ctsplitter_train_short_intervals(point_process_dataset): From 02e102bc19c8eb4f1aebdccad1a2668944c377e0 Mon Sep 17 00:00:00 2001 From: Jirka Date: Tue, 20 Feb 2024 19:16:50 +0100 Subject: [PATCH 06/10] no jupy --- .github/workflows/style_type_checks.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/style_type_checks.yml b/.github/workflows/style_type_checks.yml index 08c1a3a73f..3ada663e2b 100644 --- a/.github/workflows/style_type_checks.yml +++ b/.github/workflows/style_type_checks.yml @@ -13,7 +13,8 @@ jobs: - name: Install dependencies run: | pip install . - pip install click "black[jupyter]==24.01" "mypy==1.8.0" \ + # install also `black[jupyter]` + pip install click "black==24.01" "mypy==1.8.0" \ types-python-dateutil types-waitress types-PyYAML - name: Style check run: just black From bb32087af4e57ad05368e4818f3f60542e1e7782 Mon Sep 17 00:00:00 2001 From: Jirka Date: Tue, 20 Feb 2024 19:22:05 +0100 Subject: [PATCH 07/10] todo? --- .github/workflows/style_type_checks.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/style_type_checks.yml b/.github/workflows/style_type_checks.yml index 3ada663e2b..f82f0ee4a3 100644 --- a/.github/workflows/style_type_checks.yml +++ b/.github/workflows/style_type_checks.yml @@ -13,7 +13,7 @@ jobs: - name: Install dependencies run: | pip install . - # install also `black[jupyter]` + # todo: install also `black[jupyter]` pip install click "black==24.01" "mypy==1.8.0" \ types-python-dateutil types-waitress types-PyYAML - name: Style check From b727c021baf97f6c6c73eaf97dcf7cf1368cfdc6 Mon Sep 17 00:00:00 2001 From: Jirka Date: Tue, 20 Feb 2024 19:39:14 +0100 Subject: [PATCH 08/10] Empty-Commit From 09af6c8a9e60c3bfb22f537c2513d586bb55fb78 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Thu, 22 Feb 2024 22:12:57 +0100 Subject: [PATCH 09/10] black==24.02 --- .github/workflows/style_type_checks.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/style_type_checks.yml b/.github/workflows/style_type_checks.yml index f82f0ee4a3..48c2178edc 100644 --- a/.github/workflows/style_type_checks.yml +++ b/.github/workflows/style_type_checks.yml @@ -14,7 +14,7 @@ jobs: run: | pip install . # todo: install also `black[jupyter]` - pip install click "black==24.01" "mypy==1.8.0" \ + pip install click "black==24.02" "mypy==1.8.0" \ types-python-dateutil types-waitress types-PyYAML - name: Style check run: just black From 98690f0afda5a3b01538a25e75332a1ac30251c2 Mon Sep 17 00:00:00 2001 From: Jirka Date: Thu, 22 Feb 2024 22:18:04 +0100 Subject: [PATCH 10/10] no preview --- Justfile | 2 +- examples/benchmark_m4.py | 22 +- src/gluonts/core/component.py | 12 +- src/gluonts/dataset/arrow/dec.py | 12 +- src/gluonts/dataset/arrow/file.py | 10 +- src/gluonts/dataset/artificial/recipe.py | 10 +- src/gluonts/dataset/pandas.py | 20 +- src/gluonts/dataset/repository/_lstnet.py | 28 +- src/gluonts/dataset/repository/_m3.py | 30 +- .../dataset/repository/_tsf_datasets.py | 36 +- src/gluonts/dataset/schema/translate.py | 14 +- src/gluonts/ev/metrics.py | 9 +- src/gluonts/evaluation/_base.py | 44 +- src/gluonts/ext/hierarchicalforecast.py | 48 +- src/gluonts/ext/rotbaum/_model.py | 10 +- src/gluonts/ext/rotbaum/_predictor.py | 50 +- src/gluonts/ext/rotbaum/_preprocess.py | 64 +- src/gluonts/model/evaluation.py | 10 +- src/gluonts/model/forecast.py | 40 +- src/gluonts/model/forecast_generator.py | 10 +- src/gluonts/model/npts/_predictor.py | 12 +- src/gluonts/model/trivial/mean.py | 9 +- src/gluonts/mx/block/dropout.py | 12 +- src/gluonts/mx/distribution/iresnet.py | 9 +- .../mx/model/deep_factor/_estimator.py | 28 +- src/gluonts/mx/model/deepvar/_estimator.py | 92 +- src/gluonts/mx/model/gpvar/_estimator.py | 88 +- src/gluonts/mx/model/n_beats/_ensemble.py | 10 +- src/gluonts/mx/model/renewal/_estimator.py | 30 +- src/gluonts/mx/model/renewal/_transform.py | 14 +- .../mx/model/seq2seq/_forking_estimator.py | 18 +- .../mx/model/seq2seq/_seq2seq_estimator.py | 44 +- src/gluonts/mx/model/tft/_estimator.py | 124 +- .../mx/model/tpp/deeptpp/_estimator.py | 26 +- src/gluonts/mx/model/tpp/forecast.py | 20 +- src/gluonts/mx/model/wavenet/_estimator.py | 54 +- .../mx/representation/binning_helpers.py | 12 +- .../SCott/dataset_tools/algo_clustering.py | 66 +- .../SCott/dataset_tools/electricity.py | 48 +- .../SCott/dataset_tools/exchange_rate.py | 36 +- .../SCott/dataset_tools/group_raw_data.py | 192 +- .../nursery/SCott/dataset_tools/synthetic.py | 28 +- .../nursery/SCott/dataset_tools/traffic.py | 72 +- .../nursery/SCott/model/ar/ar_estimator.py | 30 +- .../SCott/model/lstm/lstm_estimator.py | 30 +- src/gluonts/nursery/SCott/preprocess_data.py | 172 +- src/gluonts/nursery/daf/engine/parallel.py | 30 +- src/gluonts/nursery/daf/estimator/modules.py | 10 +- .../nursery/daf/tslib/dataset/timeseries.py | 130 +- src/gluonts/nursery/daf/tslib/metrics/dict.py | 10 +- .../nursery/daf/tslib/nn/attention/posemb.py | 10 +- .../src/meta/datasets/artificial.py | 26 +- .../src/meta/datasets/cheat.py | 108 +- .../src/meta/datasets/datasets.py | 12 +- .../src/meta/datasets/super.py | 20 +- .../src/meta/models/module.py | 9 +- .../few_shot_prediction/src/scripts/data.py | 14 +- .../few_shot_prediction/src/scripts/train.py | 60 +- .../multivariate/datasets/dataset.py | 30 +- .../pts/dataset/repository/_m5.py | 16 +- .../robust-mts-attack/pts/feature/holiday.py | 38 +- .../causal_deepar/causal_deepar_network.py | 38 +- .../pts/model/deepar/deepar_network.py | 24 +- .../pts/model/deepvar/deepvar_estimator.py | 10 +- .../pts/model/deepvar/deepvar_network.py | 14 +- .../pts/model/n_beats/n_beats_ensemble.py | 10 +- .../pts/model/n_beats/n_beats_estimator.py | 30 +- .../pts/model/tempflow/tempflow_estimator.py | 80 +- .../pts/model/tempflow/tempflow_network.py | 14 +- .../pts/model/tft/tft_estimator.py | 210 +- .../pts/model/tft/tft_modules.py | 24 +- .../pts/model/time_grad/epsilon_theta.py | 18 +- .../model/time_grad/time_grad_estimator.py | 80 +- .../pts/model/time_grad/time_grad_network.py | 14 +- .../model/transformer/transformer_network.py | 24 +- .../transformer_tempflow_estimator.py | 90 +- .../transformer_tempflow_network.py | 14 +- .../robust-mts-attack/pts/modules/feature.py | 10 +- .../nursery/robust-mts-attack/read_pickle.py | 30 +- .../nursery/robust-mts-attack/utils.py | 10 +- src/gluonts/nursery/san/_estimator.py | 146 +- .../model/cop_deepar/_estimator.py | 10 +- .../model/cop_deepar/_network.py | 36 +- .../utils/utils.py | 10 +- .../nursery/tsbench/src/cli/utils/config.py | 31 +- .../src/tsbench/config/dataset/datasets.py | 88 +- .../src/tsbench/evaluations/aws/analytics.py | 12 +- .../src/tsbench/evaluations/tracking/_info.py | 40 +- .../tsbench/evaluations/tracking/ensemble.py | 10 +- .../src/tsbench/forecasts/evaluation.py | 20 +- .../tsbench/src/tsbench/recommender/greedy.py | 56 +- .../src/tsbench/surrogate/nonparametric.py | 38 +- .../tsbench/surrogate/transformers/config.py | 56 +- .../surrogate/transformers/performance.py | 11 +- src/gluonts/shell/sagemaker/dyn.py | 46 +- src/gluonts/time_feature/holiday.py | 20 +- src/gluonts/torch/model/estimator.py | 12 +- src/gluonts/torch/model/patch_tst/module.py | 10 +- src/gluonts/torch/model/tft/layers.py | 20 +- src/gluonts/torch/model/wavenet/estimator.py | 102 +- src/gluonts/transform/feature.py | 36 +- src/gluonts/zebras/_period.py | 9 +- src/gluonts/zebras/_time_frame.py | 40 +- src/gluonts/zebras/schema.py | 12 +- test/conftest.py | 12 +- test/dataset/test_data_loader.py | 28 +- test/dataset/test_dataset_mutability.py | 22 +- test/dataset/test_multivariate_grouper.py | 26 +- test/dataset/test_pandas.py | 120 +- test/dataset/test_split.py | 10 +- test/ev/test_aggregations.py | 44 +- ...t_metrics_compared_to_previous_approach.py | 25 +- test/evaluation/test_evaluator.py | 248 +- test/ext/prophet/test_prophet.py | 40 +- .../r_forecast/test_r_multi_seasonality.py | 95 +- test/ext/rotbaum/test_rotbaum_smoke.py | 44 +- test/ext/statsforecast/test_statsforecast.py | 14 +- test/model/npts/test_npts.py | 41 +- test/mx/block/test_scaler.py | 384 +- .../distribution/test_distribution_methods.py | 20 +- .../test_distribution_output_shapes.py | 10 +- .../test_distribution_sampling.py | 26 +- test/mx/distribution/test_nan_mixture.py | 10 +- test/mx/distribution/test_piecewise_linear.py | 12 +- test/mx/kernels/test_periodic_kernel.py | 12 +- test/mx/kernels/test_rbf_kernel.py | 12 +- .../generate_hierarchical_dataset.py | 12 +- .../test_train_prediction_with_hts.py | 10 +- test/mx/model/gp_forecaster/data.py | 10916 ++++++++-------- test/mx/model/renewal/test_predictor.py | 70 +- .../seq2seq/test_forking_sequence_splitter.py | 102 +- .../mx/model/simple_feedforward/test_serde.py | 10 +- test/mx/model/tpp/common.py | 12 +- test/mx/representation/test_bin.py | 384 +- test/mx/representation/test_grb.py | 716 +- test/mx/representation/test_hyb.py | 586 +- test/mx/representation/test_lab.py | 520 +- test/mx/representation/test_mean.py | 60 +- test/mx/representation/test_rep.py | 60 +- test/mx/test_transform_equals.py | 92 +- .../test_precision_recall.py | 38 +- .../test_autogluon_tabular.py | 72 +- test/shell/test_nested_params.py | 12 +- test/time_feature/test_agg_lags.py | 64 +- test/time_feature/test_holiday.py | 178 +- .../test_discrete_distribution.py | 30 +- .../test_torch_piecewise_linear.py | 12 +- test/torch/model/test_mqf2_modules.py | 22 +- test/torch/model/test_tft.py | 24 +- test/torch/test_scaler.py | 162 +- test/transform/test_transform.py | 218 +- 151 files changed, 10202 insertions(+), 9330 deletions(-) diff --git a/Justfile b/Justfile index a90ec87b92..db6f48cedb 100644 --- a/Justfile +++ b/Justfile @@ -34,7 +34,7 @@ release: python setup.py sdist black: - black --check --color --preview src test examples + black --check --color src test examples mypy: python setup.py type_check diff --git a/examples/benchmark_m4.py b/examples/benchmark_m4.py index e9b17717de..5368ee8411 100644 --- a/examples/benchmark_m4.py +++ b/examples/benchmark_m4.py @@ -95,15 +95,17 @@ def evaluate(dataset_name, estimator): df = pd.DataFrame(results) - sub_df = df[[ - "dataset", - "estimator", - "RMSE", - "mean_wQuantileLoss", - "MASE", - "sMAPE", - "OWA", - "MSIS", - ]] + sub_df = df[ + [ + "dataset", + "estimator", + "RMSE", + "mean_wQuantileLoss", + "MASE", + "sMAPE", + "OWA", + "MSIS", + ] + ] print(sub_df.to_string()) diff --git a/src/gluonts/core/component.py b/src/gluonts/core/component.py index 264f46b1eb..c5f18d011a 100644 --- a/src/gluonts/core/component.py +++ b/src/gluonts/core/component.py @@ -355,11 +355,13 @@ def init_wrapper(*args, **kwargs): # __init_args__ is not already set in order to avoid overriding a # value set by a subclass initializer in super().__init__ calls if not getattr(self, "__init_args__", {}): - self.__init_args__ = OrderedDict({ - name: arg - for name, arg in sorted(all_args.items()) - if not skip_encoding(arg) - }) + self.__init_args__ = OrderedDict( + { + name: arg + for name, arg in sorted(all_args.items()) + if not skip_encoding(arg) + } + ) self.__class__.__getnewargs_ex__ = validated_getnewargs_ex self.__class__.__repr__ = validated_repr diff --git a/src/gluonts/dataset/arrow/dec.py b/src/gluonts/dataset/arrow/dec.py index aab36898d0..148d5311c7 100644 --- a/src/gluonts/dataset/arrow/dec.py +++ b/src/gluonts/dataset/arrow/dec.py @@ -23,11 +23,13 @@ class ArrowDecoder: @classmethod def from_schema(cls, schema): - return cls([ - (column.name[: -len("._np_shape")], column.name) - for column in schema - if column.name.endswith("._np_shape") - ]) + return cls( + [ + (column.name[: -len("._np_shape")], column.name) + for column in schema + if column.name.endswith("._np_shape") + ] + ) def decode(self, batch, row_number: int): return next(self.decode_batch(batch.slice(row_number, row_number + 1))) diff --git a/src/gluonts/dataset/arrow/file.py b/src/gluonts/dataset/arrow/file.py index 552dad0f08..7bdb6cf898 100644 --- a/src/gluonts/dataset/arrow/file.py +++ b/src/gluonts/dataset/arrow/file.py @@ -229,10 +229,12 @@ def __post_init__(self): self.decoder = ArrowDecoder.from_schema(self.reader.schema_arrow) if not self._row_group_sizes: - self._row_group_sizes = np.cumsum([ - self.reader.metadata.row_group(row_group).num_rows - for row_group in range(self.reader.metadata.num_row_groups) - ]) + self._row_group_sizes = np.cumsum( + [ + self.reader.metadata.row_group(row_group).num_rows + for row_group in range(self.reader.metadata.num_row_groups) + ] + ) def location_for(self, idx): if idx == 0: diff --git a/src/gluonts/dataset/artificial/recipe.py b/src/gluonts/dataset/artificial/recipe.py index f5ede6a253..36cdbede42 100644 --- a/src/gluonts/dataset/artificial/recipe.py +++ b/src/gluonts/dataset/artificial/recipe.py @@ -714,10 +714,12 @@ def __call__(self, x, field_name, global_state, **kwargs): probs = [self.prob_fun(x, length=c) for c in self.cardinalities] global_state[field_name] = probs probs = global_state[field_name] - cats = np.array([ - np.random.choice(np.arange(len(probs[i])), p=probs[i]) - for i in range(len(probs)) - ]) + cats = np.array( + [ + np.random.choice(np.arange(len(probs[i])), p=probs[i]) + for i in range(len(probs)) + ] + ) return cats diff --git a/src/gluonts/dataset/pandas.py b/src/gluonts/dataset/pandas.py index e8dacf86e2..dcdb1d5456 100644 --- a/src/gluonts/dataset/pandas.py +++ b/src/gluonts/dataset/pandas.py @@ -221,15 +221,17 @@ def __len__(self) -> int: return len(self._data_entries) def __repr__(self) -> str: - info = ", ".join([ - f"size={len(self)}", - f"freq={self.freq}", - f"num_feat_dynamic_real={self.num_feat_dynamic_real}", - f"num_past_feat_dynamic_real={self.num_past_feat_dynamic_real}", - f"num_feat_static_real={self.num_feat_static_real}", - f"num_feat_static_cat={self.num_feat_static_cat}", - f"static_cardinalities={self.static_cardinalities}", - ]) + info = ", ".join( + [ + f"size={len(self)}", + f"freq={self.freq}", + f"num_feat_dynamic_real={self.num_feat_dynamic_real}", + f"num_past_feat_dynamic_real={self.num_past_feat_dynamic_real}", + f"num_feat_static_real={self.num_feat_static_real}", + f"num_feat_static_cat={self.num_feat_static_cat}", + f"static_cardinalities={self.static_cardinalities}", + ] + ) return f"PandasDataset<{info}>" @classmethod diff --git a/src/gluonts/dataset/repository/_lstnet.py b/src/gluonts/dataset/repository/_lstnet.py index 2841bbb92d..e933666c77 100644 --- a/src/gluonts/dataset/repository/_lstnet.py +++ b/src/gluonts/dataset/repository/_lstnet.py @@ -161,12 +161,14 @@ def generate_lstnet_dataset( for cat, ts in enumerate(timeseries): sliced_ts = ts[:training_end] if len(sliced_ts) > 0: - train_ts.append({ - "target": sliced_ts.values, - "start": sliced_ts.index[0], - "feat_static_cat": [cat], - "item_id": cat, - }) + train_ts.append( + { + "target": sliced_ts.values, + "start": sliced_ts.index[0], + "feat_static_cat": [cat], + "item_id": cat, + } + ) assert len(train_ts) == ds_info.num_series @@ -184,12 +186,14 @@ def generate_lstnet_dataset( prediction_start_date + ds_info.prediction_length ) sliced_ts = ts[:prediction_end_date] - test_ts.append({ - "target": sliced_ts.values, - "start": sliced_ts.index[0], - "feat_static_cat": [cat], - "item_id": cat, - }) + test_ts.append( + { + "target": sliced_ts.values, + "start": sliced_ts.index[0], + "feat_static_cat": [cat], + "item_id": cat, + } + ) assert len(test_ts) == ds_info.num_series * ds_info.rolling_evaluations diff --git a/src/gluonts/dataset/repository/_m3.py b/src/gluonts/dataset/repository/_m3.py index 5c0b3ba86c..84206d0f15 100644 --- a/src/gluonts/dataset/repository/_m3.py +++ b/src/gluonts/dataset/repository/_m3.py @@ -163,19 +163,23 @@ def normalize_category(c: str): start = str(pd.Period(time_stamp, freq=subset.freq)) cat = [i, cat_map[category]] - train_data.append({ - "target": target[: -subset.prediction_length], - "start": start, - "feat_static_cat": cat, - "item_id": series, - }) - - test_data.append({ - "target": target, - "start": start, - "feat_static_cat": cat, - "item_id": series, - }) + train_data.append( + { + "target": target[: -subset.prediction_length], + "start": start, + "feat_static_cat": cat, + "item_id": series, + } + ) + + test_data.append( + { + "target": target, + "start": start, + "feat_static_cat": cat, + "item_id": series, + } + ) meta = MetaData( **metadata( diff --git a/src/gluonts/dataset/repository/_tsf_datasets.py b/src/gluonts/dataset/repository/_tsf_datasets.py index d9c27c9ecc..ba073cdf4c 100644 --- a/src/gluonts/dataset/repository/_tsf_datasets.py +++ b/src/gluonts/dataset/repository/_tsf_datasets.py @@ -201,23 +201,27 @@ def convert_data( # timestamps # - `item_id` is added for all datasets ... many datasets provide # the "series_name" - test_data.append({ - "target": data_entry["target"], - "start": str( - data_entry.get("start_timestamp", default_start_timestamp) - ), - "item_id": data_entry.get("series_name", i), - "feat_static_cat": [i], - }) + test_data.append( + { + "target": data_entry["target"], + "start": str( + data_entry.get("start_timestamp", default_start_timestamp) + ), + "item_id": data_entry.get("series_name", i), + "feat_static_cat": [i], + } + ) - train_data.append({ - "target": data_entry["target"][:-train_offset], - "start": str( - data_entry.get("start_timestamp", default_start_timestamp) - ), - "item_id": data_entry.get("series_name", i), - "feat_static_cat": [i], - }) + train_data.append( + { + "target": data_entry["target"][:-train_offset], + "start": str( + data_entry.get("start_timestamp", default_start_timestamp) + ), + "item_id": data_entry.get("series_name", i), + "feat_static_cat": [i], + } + ) return train_data, test_data diff --git a/src/gluonts/dataset/schema/translate.py b/src/gluonts/dataset/schema/translate.py index 09f0e4e602..5ea7c41955 100644 --- a/src/gluonts/dataset/schema/translate.py +++ b/src/gluonts/dataset/schema/translate.py @@ -141,12 +141,14 @@ class TokenStream: @classmethod def from_str(cls, s): - stream = cls([ - Token(name, value, match) - for match in re.finditer(cls.RX, s) - for name, value in valfilter(bool, match.groupdict()).items() - if name != "WHITESPACE" - ]) + stream = cls( + [ + Token(name, value, match) + for match in re.finditer(cls.RX, s) + for name, value in valfilter(bool, match.groupdict()).items() + if name != "WHITESPACE" + ] + ) for token in stream: if token.name == "INVALID": diff --git a/src/gluonts/ev/metrics.py b/src/gluonts/ev/metrics.py index 00f74bb2f6..4d3e55b335 100644 --- a/src/gluonts/ev/metrics.py +++ b/src/gluonts/ev/metrics.py @@ -124,9 +124,12 @@ def update(self, data: Mapping[str, np.ndarray]) -> Self: return self def get(self) -> np.ndarray: - return self.post_process(**{ - name: evaluator.get() for name, evaluator in self.metrics.items() - }) + return self.post_process( + **{ + name: evaluator.get() + for name, evaluator in self.metrics.items() + } + ) @runtime_checkable diff --git a/src/gluonts/evaluation/_base.py b/src/gluonts/evaluation/_base.py index 8280c5a062..b623bf3d75 100644 --- a/src/gluonts/evaluation/_base.py +++ b/src/gluonts/evaluation/_base.py @@ -293,11 +293,13 @@ def __call__( # Thus we set dtype=np.float64 to convert masked values back to NaNs # which are handled correctly by pandas Dataframes during # aggregation. - metrics_per_ts = metrics_per_ts.astype({ - col: np.float64 - for col in metrics_per_ts.columns - if col not in ["item_id", "forecast_start"] - }) + metrics_per_ts = metrics_per_ts.astype( + { + col: np.float64 + for col in metrics_per_ts.columns + if col not in ["item_id", "forecast_start"] + } + ) return self.get_aggregate_metrics(metrics_per_ts) @@ -534,18 +536,26 @@ def get_aggregate_metrics( totals[f"QuantileLoss[{quantile}]"] / totals["abs_target_sum"] ) - totals["mean_absolute_QuantileLoss"] = np.array([ - totals[f"QuantileLoss[{quantile}]"] for quantile in self.quantiles - ]).mean() - - totals["mean_wQuantileLoss"] = np.array([ - totals[f"wQuantileLoss[{quantile}]"] for quantile in self.quantiles - ]).mean() - - totals["MAE_Coverage"] = np.mean([ - np.abs(totals[f"Coverage[{quantile}]"] - np.array([q.value])) - for q in self.quantiles - ]) + totals["mean_absolute_QuantileLoss"] = np.array( + [ + totals[f"QuantileLoss[{quantile}]"] + for quantile in self.quantiles + ] + ).mean() + + totals["mean_wQuantileLoss"] = np.array( + [ + totals[f"wQuantileLoss[{quantile}]"] + for quantile in self.quantiles + ] + ).mean() + + totals["MAE_Coverage"] = np.mean( + [ + np.abs(totals[f"Coverage[{quantile}]"] - np.array([q.value])) + for q in self.quantiles + ] + ) # Compute OWA if required if self.calculate_owa: diff --git a/src/gluonts/ext/hierarchicalforecast.py b/src/gluonts/ext/hierarchicalforecast.py index 0bb4ab4ce5..44ed9d7b54 100644 --- a/src/gluonts/ext/hierarchicalforecast.py +++ b/src/gluonts/ext/hierarchicalforecast.py @@ -96,11 +96,13 @@ def unpivot(df: pd.DataFrame) -> pd.DataFrame: """ n, k = df.shape - return pd.DataFrame({ - "unique_id": np.asarray(df.columns).repeat(n), - "ds": np.tile(np.asarray(df.index), k), - "y": df.to_numpy().ravel("F"), - }) + return pd.DataFrame( + { + "unique_id": np.asarray(df.columns).repeat(n), + "ds": np.tile(np.asarray(df.index), k), + "y": df.to_numpy().ravel("F"), + } + ) def format_reconciled_forecasts( @@ -241,14 +243,16 @@ def __init__( def predict_item(self, entry: DataEntry) -> QuantileForecast: kwargs = {} - if self.config.intervals is not None and all([ - proportion not in _build_fn_name(self.hrec.reconcilers[0]) - for proportion in [ - "forecast_proportions", - "average_proportions", - "proportion_averages", + if self.config.intervals is not None and all( + [ + proportion not in _build_fn_name(self.hrec.reconcilers[0]) + for proportion in [ + "forecast_proportions", + "average_proportions", + "proportion_averages", + ] ] - ]): + ): kwargs["level"] = self.config.intervals Y_df = format_data_entry(entry, self.S) @@ -296,15 +300,17 @@ def predict_item(self, entry: DataEntry) -> QuantileForecast: fcst_col_names = self.config.statsforecast_keys # prepare for QuantileForecast format - forecast_arrays = np.array([ - format_reconciled_forecasts( - df=Y_hat_df_rec, - fcst_col_name=fcst_col_names[e], - prediction_length=self.prediction_length, - S=self.S, - ) - for e, k in enumerate(self.config.statsforecast_keys) - ]) + forecast_arrays = np.array( + [ + format_reconciled_forecasts( + df=Y_hat_df_rec, + fcst_col_name=fcst_col_names[e], + prediction_length=self.prediction_length, + S=self.S, + ) + for e, k in enumerate(self.config.statsforecast_keys) + ] + ) return QuantileForecast( forecast_arrays=forecast_arrays, diff --git a/src/gluonts/ext/rotbaum/_model.py b/src/gluonts/ext/rotbaum/_model.py index fcca9250c3..7233df39c8 100644 --- a/src/gluonts/ext/rotbaum/_model.py +++ b/src/gluonts/ext/rotbaum/_model.py @@ -180,10 +180,12 @@ def fit( if not model_is_already_trained: self.model.fit(x_train, y_train, **kwargs) y_train_pred = self.model.predict(x_train) - df = pd.DataFrame({ - "y_true": y_train, - "y_pred": y_train_pred, - }).reset_index(drop=True) + df = pd.DataFrame( + { + "y_true": y_train, + "y_pred": y_train_pred, + } + ).reset_index(drop=True) self.sorted_train_preds = sorted(df["y_pred"].unique()) cell_values_dict = self.preprocess_df( df, min_bin_size=self.min_bin_size diff --git a/src/gluonts/ext/rotbaum/_predictor.py b/src/gluonts/ext/rotbaum/_predictor.py index ab324fa69c..6631e8dde0 100644 --- a/src/gluonts/ext/rotbaum/_predictor.py +++ b/src/gluonts/ext/rotbaum/_predictor.py @@ -411,15 +411,19 @@ def explain( assert self.model_list is not None - importances = np.array([ + importances = np.array( [ - self.model_list[time_stamp] - .models[quantile] - .booster_.feature_importance(importance_type=importance_type) - for time_stamp in range(self.prediction_length) + [ + self.model_list[time_stamp] + .models[quantile] + .booster_.feature_importance( + importance_type=importance_type + ) + for time_stamp in range(self.prediction_length) + ] + for quantile in self.quantiles ] - for quantile in self.quantiles - ]).transpose((2, 1, 0)) + ).transpose((2, 1, 0)) # The shape is: (features, pred_length, quantiles) importances = importances.mean(axis=2) # Average over quantiles # The shape of importances is: (features, pred_length) @@ -459,15 +463,17 @@ def explain( ) for i in range(num_feat_static_cat): - coordinate_map["feat_static_cat"].append(( - dynamic_length - + num_feat_static_real - + static_cat_features_so_far, - dynamic_length - + num_feat_static_real - + static_cat_features_so_far - + cardinality[i], - )) + coordinate_map["feat_static_cat"].append( + ( + dynamic_length + + num_feat_static_real + + static_cat_features_so_far, + dynamic_length + + num_feat_static_real + + static_cat_features_so_far + + cardinality[i], + ) + ) static_cat_features_so_far += cardinality[i] coordinate_map["past_feat_dynamic_real"] = [ @@ -517,11 +523,13 @@ def explain( ) logger.info(f"shape of importance matrix is: {importances.shape}") assert ( - sum([ - sum([coor[1] - coor[0] for coor in coordinate_map[key]]) - for key in coordinate_map - if key != "target" - ]) + sum( + [ + sum([coor[1] - coor[0] for coor in coordinate_map[key]]) + for key in coordinate_map + if key != "target" + ] + ) + coordinate_map["target"][1] - coordinate_map["target"][0] ) == importances.shape[ diff --git a/src/gluonts/ext/rotbaum/_preprocess.py b/src/gluonts/ext/rotbaum/_preprocess.py index 6ab57d569d..c45835a45e 100644 --- a/src/gluonts/ext/rotbaum/_preprocess.py +++ b/src/gluonts/ext/rotbaum/_preprocess.py @@ -168,13 +168,15 @@ def preprocess_from_single_ts(self, time_series: Dict) -> Tuple: feature_data.append( list(featurized_data) + [forecast_horizon_index] ) - target_data.append([ - time_series["target"][ - starting_index - + self.context_window_size - + forecast_horizon_index + target_data.append( + [ + time_series["target"][ + starting_index + + self.context_window_size + + forecast_horizon_index + ] ] - ]) + ) else: featurized_data = self.make_features( altered_time_series, starting_index @@ -479,37 +481,41 @@ def make_features(self, time_series: Dict, starting_index: int) -> List: past_feat_dynamic_real = ( list( - chain(*[ - prefix + list(ent[0]) + list(ent[1].values()) - for ent in [ - self._pre_transform( - ts if prefix else ts[starting_index:end_index], - self.subtract_mean, - self.count_nans, - ) - for ts in time_series["past_feat_dynamic_real"] + chain( + *[ + prefix + list(ent[0]) + list(ent[1].values()) + for ent in [ + self._pre_transform( + ts if prefix else ts[starting_index:end_index], + self.subtract_mean, + self.count_nans, + ) + for ts in time_series["past_feat_dynamic_real"] + ] ] - ]) + ) ) if self.use_past_feat_dynamic_real else [] ) feat_dynamic_real = ( list( - chain(*[ - list(ent[0]) + list(ent[1].values()) - for ent in [ - self._pre_transform( - ts[ - starting_index : end_index - + self.forecast_horizon - ], - self.subtract_mean, - self.count_nans, - ) - for ts in time_series["feat_dynamic_real"] + chain( + *[ + list(ent[0]) + list(ent[1].values()) + for ent in [ + self._pre_transform( + ts[ + starting_index : end_index + + self.forecast_horizon + ], + self.subtract_mean, + self.count_nans, + ) + for ts in time_series["feat_dynamic_real"] + ] ] - ]) + ) ) if self.use_feat_dynamic_real else [] diff --git a/src/gluonts/model/evaluation.py b/src/gluonts/model/evaluation.py index 597cee9703..473aa2397e 100644 --- a/src/gluonts/model/evaluation.py +++ b/src/gluonts/model/evaluation.py @@ -147,10 +147,12 @@ def evaluate_forecasts_raw( input_batches, label_batches, forecast_batches ): if 0 not in axis: - index_data.extend([ - (forecast.item_id, forecast.start_date) - for forecast in forecast_batch - ]) + index_data.extend( + [ + (forecast.item_id, forecast.start_date) + for forecast in forecast_batch + ] + ) data_batch = _get_data_batch( input_batch, diff --git a/src/gluonts/model/forecast.py b/src/gluonts/model/forecast.py index 73c0a6093f..c69ae385fc 100644 --- a/src/gluonts/model/forecast.py +++ b/src/gluonts/model/forecast.py @@ -532,19 +532,23 @@ def dim(self) -> int: return self._dim def __repr__(self): - return ", ".join([ - f"SampleForecast({self.samples!r})", - f"{self.start_date!r}", - f"item_id={self.item_id!r}", - f"info={self.info!r})", - ]) + return ", ".join( + [ + f"SampleForecast({self.samples!r})", + f"{self.start_date!r}", + f"item_id={self.item_id!r}", + f"info={self.info!r})", + ] + ) def to_quantile_forecast(self, quantiles: List[str]) -> "QuantileForecast": return QuantileForecast( - forecast_arrays=np.array([ - self.quantile(q) if q != "mean" else self.mean - for q in quantiles - ]), + forecast_arrays=np.array( + [ + self.quantile(q) if q != "mean" else self.mean + for q in quantiles + ] + ), start_date=self.start_date, forecast_keys=quantiles, item_id=self.item_id, @@ -686,10 +690,12 @@ def dim(self) -> int: return self._dim def __repr__(self): - return ", ".join([ - f"QuantileForecast({self.forecast_array!r})", - f"start_date={self.start_date!r}", - f"forecast_keys={self.forecast_keys!r}", - f"item_id={self.item_id!r}", - f"info={self.info!r})", - ]) + return ", ".join( + [ + f"QuantileForecast({self.forecast_array!r})", + f"start_date={self.start_date!r}", + f"forecast_keys={self.forecast_keys!r}", + f"item_id={self.item_id!r}", + f"info={self.info!r})", + ] + ) diff --git a/src/gluonts/model/forecast_generator.py b/src/gluonts/model/forecast_generator.py index 6408a72e6c..42caf1fa9e 100644 --- a/src/gluonts/model/forecast_generator.py +++ b/src/gluonts/model/forecast_generator.py @@ -171,10 +171,12 @@ def __call__( outputs = output_transform(batch, outputs) collected_samples.append(outputs) num_collected_samples += outputs[0].shape[0] - outputs = np.stack([ - np.concatenate(s)[:num_samples] - for s in zip(*collected_samples) - ]) + outputs = np.stack( + [ + np.concatenate(s)[:num_samples] + for s in zip(*collected_samples) + ] + ) assert len(outputs[0]) == num_samples i = -1 for i, output in enumerate(outputs): diff --git a/src/gluonts/model/npts/_predictor.py b/src/gluonts/model/npts/_predictor.py index 7bb088a796..baf77fdd10 100644 --- a/src/gluonts/model/npts/_predictor.py +++ b/src/gluonts/model/npts/_predictor.py @@ -205,10 +205,14 @@ def predict( custom_features: Optional[np.ndarray] if "feat_dynamic_real" in data.keys(): - custom_features = np.array([ - dynamic_feature[-train_length - self.prediction_length :] - for dynamic_feature in data["feat_dynamic_real"] - ]) + custom_features = np.array( + [ + dynamic_feature[ + -train_length - self.prediction_length : + ] + for dynamic_feature in data["feat_dynamic_real"] + ] + ) else: custom_features = None diff --git a/src/gluonts/model/trivial/mean.py b/src/gluonts/model/trivial/mean.py index 027c66e5c5..ff9ce17b2c 100644 --- a/src/gluonts/model/trivial/mean.py +++ b/src/gluonts/model/trivial/mean.py @@ -157,9 +157,12 @@ def train( training_data: Dataset, validation_dataset: Optional[Dataset] = None, ) -> ConstantPredictor: - contexts = np.array([ - item["target"][-self.prediction_length :] for item in training_data - ]) + contexts = np.array( + [ + item["target"][-self.prediction_length :] + for item in training_data + ] + ) samples = np.broadcast_to( array=contexts.mean(axis=0), diff --git a/src/gluonts/mx/block/dropout.py b/src/gluonts/mx/block/dropout.py index de192777f1..2522820a2c 100644 --- a/src/gluonts/mx/block/dropout.py +++ b/src/gluonts/mx/block/dropout.py @@ -230,11 +230,13 @@ def mask(p, like): # only for RNN, the first element of states is output. Use the same # mask as output, instead of simply copy output to the first element # in case that the base cell is ResidualCell - new_states = [( - F.where(output_mask, next_states[0], states[0]) - if p_outputs != 0.0 - else next_states[0] - )] + new_states = [ + ( + F.where(output_mask, next_states[0], states[0]) + if p_outputs != 0.0 + else next_states[0] + ) + ] new_states.extend( [ F.where(mask(p_states, new_s), new_s, old_s) diff --git a/src/gluonts/mx/distribution/iresnet.py b/src/gluonts/mx/distribution/iresnet.py index 78f56fca9b..5a80c141ca 100644 --- a/src/gluonts/mx/distribution/iresnet.py +++ b/src/gluonts/mx/distribution/iresnet.py @@ -199,6 +199,9 @@ def iresnet(num_blocks: int, **block_kwargs) -> ComposedBijectionHybridBlock: ------- """ - return ComposedBijectionHybridBlock([ - InvertibleResnetHybridBlock(**block_kwargs) for _ in range(num_blocks) - ]) + return ComposedBijectionHybridBlock( + [ + InvertibleResnetHybridBlock(**block_kwargs) + for _ in range(num_blocks) + ] + ) diff --git a/src/gluonts/mx/model/deep_factor/_estimator.py b/src/gluonts/mx/model/deep_factor/_estimator.py index 222726c9cd..e867f17e8e 100644 --- a/src/gluonts/mx/model/deep_factor/_estimator.py +++ b/src/gluonts/mx/model/deep_factor/_estimator.py @@ -168,18 +168,22 @@ def __init__( ) def create_transformation(self) -> Transformation: - return Chain([ - AsNumpyArray(field=FieldName.TARGET, expected_ndim=1), - AddTimeFeatures( - start_field=FieldName.START, - target_field=FieldName.TARGET, - output_field=FieldName.FEAT_TIME, - time_features=time_features_from_frequency_str(self.freq), - pred_length=self.prediction_length, - ), - SetFieldIfNotPresent(field=FieldName.FEAT_STATIC_CAT, value=[0.0]), - AsNumpyArray(field=FieldName.FEAT_STATIC_CAT, expected_ndim=1), - ]) + return Chain( + [ + AsNumpyArray(field=FieldName.TARGET, expected_ndim=1), + AddTimeFeatures( + start_field=FieldName.START, + target_field=FieldName.TARGET, + output_field=FieldName.FEAT_TIME, + time_features=time_features_from_frequency_str(self.freq), + pred_length=self.prediction_length, + ), + SetFieldIfNotPresent( + field=FieldName.FEAT_STATIC_CAT, value=[0.0] + ), + AsNumpyArray(field=FieldName.FEAT_STATIC_CAT, expected_ndim=1), + ] + ) def _create_instance_splitter(self, mode: str): return transform.InstanceSplitter( diff --git a/src/gluonts/mx/model/deepvar/_estimator.py b/src/gluonts/mx/model/deepvar/_estimator.py index 5e68ec1dfa..9a95ab0c9a 100644 --- a/src/gluonts/mx/model/deepvar/_estimator.py +++ b/src/gluonts/mx/model/deepvar/_estimator.py @@ -331,44 +331,48 @@ def __init__( ) def create_transformation(self) -> Transformation: - return Chain([ - AsNumpyArray( - field=FieldName.TARGET, - expected_ndim=1 + len(self.distr_output.event_shape), - ), - # maps the target to (1, T) - # if the target data is uni dimensional - ExpandDimArray( - field=FieldName.TARGET, - axis=0 if self.distr_output.event_shape[0] == 1 else None, - ), - AddObservedValuesIndicator( - target_field=FieldName.TARGET, - output_field=FieldName.OBSERVED_VALUES, - ), - AddTimeFeatures( - start_field=FieldName.START, - target_field=FieldName.TARGET, - output_field=FieldName.FEAT_TIME, - time_features=self.time_features, - pred_length=self.prediction_length, - ), - VstackFeatures( - output_field=FieldName.FEAT_TIME, - input_fields=[FieldName.FEAT_TIME] - + ( - [FieldName.FEAT_DYNAMIC_REAL] - if self.use_feat_dynamic_real - else [] + return Chain( + [ + AsNumpyArray( + field=FieldName.TARGET, + expected_ndim=1 + len(self.distr_output.event_shape), ), - ), - SetFieldIfNotPresent(field=FieldName.FEAT_STATIC_CAT, value=[0.0]), - TargetDimIndicator( - field_name="target_dimension_indicator", - target_field=FieldName.TARGET, - ), - AsNumpyArray(field=FieldName.FEAT_STATIC_CAT, expected_ndim=1), - ]) + # maps the target to (1, T) + # if the target data is uni dimensional + ExpandDimArray( + field=FieldName.TARGET, + axis=0 if self.distr_output.event_shape[0] == 1 else None, + ), + AddObservedValuesIndicator( + target_field=FieldName.TARGET, + output_field=FieldName.OBSERVED_VALUES, + ), + AddTimeFeatures( + start_field=FieldName.START, + target_field=FieldName.TARGET, + output_field=FieldName.FEAT_TIME, + time_features=self.time_features, + pred_length=self.prediction_length, + ), + VstackFeatures( + output_field=FieldName.FEAT_TIME, + input_fields=[FieldName.FEAT_TIME] + + ( + [FieldName.FEAT_DYNAMIC_REAL] + if self.use_feat_dynamic_real + else [] + ), + ), + SetFieldIfNotPresent( + field=FieldName.FEAT_STATIC_CAT, value=[0.0] + ), + TargetDimIndicator( + field_name="target_dimension_indicator", + target_field=FieldName.TARGET, + ), + AsNumpyArray(field=FieldName.FEAT_STATIC_CAT, expected_ndim=1), + ] + ) def _create_instance_splitter(self, mode: str): assert mode in ["training", "validation", "test"] @@ -399,12 +403,14 @@ def _create_instance_splitter(self, mode: str): target_dim=self.target_dim, ) if self.use_marginal_transformation - else RenameFields({ - f"past_{FieldName.TARGET}": f"past_{FieldName.TARGET}_cdf", - f"future_{FieldName.TARGET}": ( - f"future_{FieldName.TARGET}_cdf" - ), - }) + else RenameFields( + { + f"past_{FieldName.TARGET}": f"past_{FieldName.TARGET}_cdf", + f"future_{FieldName.TARGET}": ( + f"future_{FieldName.TARGET}_cdf" + ), + } + ) ) def create_training_data_loader( diff --git a/src/gluonts/mx/model/gpvar/_estimator.py b/src/gluonts/mx/model/gpvar/_estimator.py index c71f0d1a92..29aea6ac31 100644 --- a/src/gluonts/mx/model/gpvar/_estimator.py +++ b/src/gluonts/mx/model/gpvar/_estimator.py @@ -244,39 +244,43 @@ def __init__( ) def create_transformation(self) -> Transformation: - return Chain([ - AsNumpyArray( - field=FieldName.TARGET, - expected_ndim=1 + len(self.distr_output.event_shape), - ), - # maps the target to (1, T) if the target data is uni - # dimensional - ExpandDimArray( - field=FieldName.TARGET, - axis=0 if self.distr_output.event_shape[0] == 1 else None, - ), - AddObservedValuesIndicator( - target_field=FieldName.TARGET, - output_field=FieldName.OBSERVED_VALUES, - ), - AddTimeFeatures( - start_field=FieldName.START, - target_field=FieldName.TARGET, - output_field=FieldName.FEAT_TIME, - time_features=self.time_features, - pred_length=self.prediction_length, - ), - VstackFeatures( - output_field=FieldName.FEAT_TIME, - input_fields=[FieldName.FEAT_TIME], - ), - SetFieldIfNotPresent(field=FieldName.FEAT_STATIC_CAT, value=[0.0]), - TargetDimIndicator( - field_name=FieldName.TARGET_DIM_INDICATOR, - target_field=FieldName.TARGET, - ), - AsNumpyArray(field=FieldName.FEAT_STATIC_CAT, expected_ndim=1), - ]) + return Chain( + [ + AsNumpyArray( + field=FieldName.TARGET, + expected_ndim=1 + len(self.distr_output.event_shape), + ), + # maps the target to (1, T) if the target data is uni + # dimensional + ExpandDimArray( + field=FieldName.TARGET, + axis=0 if self.distr_output.event_shape[0] == 1 else None, + ), + AddObservedValuesIndicator( + target_field=FieldName.TARGET, + output_field=FieldName.OBSERVED_VALUES, + ), + AddTimeFeatures( + start_field=FieldName.START, + target_field=FieldName.TARGET, + output_field=FieldName.FEAT_TIME, + time_features=self.time_features, + pred_length=self.prediction_length, + ), + VstackFeatures( + output_field=FieldName.FEAT_TIME, + input_fields=[FieldName.FEAT_TIME], + ), + SetFieldIfNotPresent( + field=FieldName.FEAT_STATIC_CAT, value=[0.0] + ), + TargetDimIndicator( + field_name=FieldName.TARGET_DIM_INDICATOR, + target_field=FieldName.TARGET, + ), + AsNumpyArray(field=FieldName.FEAT_STATIC_CAT, expected_ndim=1), + ] + ) def _create_instance_splitter(self, mode: str): assert mode in ["training", "validation", "test"] @@ -309,14 +313,16 @@ def _create_instance_splitter(self, mode: str): target_dim=self.target_dim, ) if self.use_marginal_transformation - else RenameFields({ - f"past_{FieldName.TARGET}": ( - f"past_{FieldName.TARGET}_cdf" - ), - f"future_{FieldName.TARGET}": ( - f"future_{FieldName.TARGET}_cdf" - ), - }) + else RenameFields( + { + f"past_{FieldName.TARGET}": ( + f"past_{FieldName.TARGET}_cdf" + ), + f"future_{FieldName.TARGET}": ( + f"future_{FieldName.TARGET}_cdf" + ), + } + ) ) + SampleTargetDim( field_name=FieldName.TARGET_DIM_INDICATOR, diff --git a/src/gluonts/mx/model/n_beats/_ensemble.py b/src/gluonts/mx/model/n_beats/_ensemble.py index 838a0bc15c..4351163251 100644 --- a/src/gluonts/mx/model/n_beats/_ensemble.py +++ b/src/gluonts/mx/model/n_beats/_ensemble.py @@ -353,10 +353,12 @@ def __init__( self.freq = freq self.prediction_length = prediction_length - assert meta_loss_function is None or all([ - loss_function in VALID_LOSS_FUNCTIONS - for loss_function in meta_loss_function - ]), ( + assert meta_loss_function is None or all( + [ + loss_function in VALID_LOSS_FUNCTIONS + for loss_function in meta_loss_function + ] + ), ( "Each loss function has to be one of the following:" f" {VALID_LOSS_FUNCTIONS}." ) diff --git a/src/gluonts/mx/model/renewal/_estimator.py b/src/gluonts/mx/model/renewal/_estimator.py index 1443d73ffb..c944850f91 100644 --- a/src/gluonts/mx/model/renewal/_estimator.py +++ b/src/gluonts/mx/model/renewal/_estimator.py @@ -181,20 +181,22 @@ def _create_instance_splitter(self, mode: str): @staticmethod def _create_post_split_transform(): - return Chain([ - CountTrailingZeros( - new_field="time_remaining", - target_field="past_target", - as_array=True, - ), - ToIntervalSizeFormat( - target_field="past_target", discard_first=True - ), - RenameFields({"future_target": "sparse_future"}), - AsNumpyArray(field="past_target", expected_ndim=2), - SwapAxes(input_fields=["past_target"], axes=(0, 1)), - AddAxisLength(target_field="past_target", axis=0), - ]) + return Chain( + [ + CountTrailingZeros( + new_field="time_remaining", + target_field="past_target", + as_array=True, + ), + ToIntervalSizeFormat( + target_field="past_target", discard_first=True + ), + RenameFields({"future_target": "sparse_future"}), + AsNumpyArray(field="past_target", expected_ndim=2), + SwapAxes(input_fields=["past_target"], axes=(0, 1)), + AddAxisLength(target_field="past_target", axis=0), + ] + ) def _stack_fn(self) -> Callable: return partial( diff --git a/src/gluonts/mx/model/renewal/_transform.py b/src/gluonts/mx/model/renewal/_transform.py index 986ec20c34..884ec31051 100644 --- a/src/gluonts/mx/model/renewal/_transform.py +++ b/src/gluonts/mx/model/renewal/_transform.py @@ -51,9 +51,13 @@ def __init__( def transform(self, data: DataEntry) -> DataEntry: target = data[self.target_field] - data[self.output_field] = np.array([( - len(target) - if isinstance(target, list) - else target.shape[self.axis] - )]) + data[self.output_field] = np.array( + [ + ( + len(target) + if isinstance(target, list) + else target.shape[self.axis] + ) + ] + ) return data diff --git a/src/gluonts/mx/model/seq2seq/_forking_estimator.py b/src/gluonts/mx/model/seq2seq/_forking_estimator.py index 910758f9bf..7bb38915f2 100644 --- a/src/gluonts/mx/model/seq2seq/_forking_estimator.py +++ b/src/gluonts/mx/model/seq2seq/_forking_estimator.py @@ -287,14 +287,16 @@ def create_transformation(self) -> Transformation: if not self.use_feat_static_cat: remove_field_names.append(FieldName.FEAT_STATIC_CAT) - chain.extend([ - RemoveFields(field_names=remove_field_names), - AddObservedValuesIndicator( - target_field=FieldName.TARGET, - output_field=FieldName.OBSERVED_VALUES, - dtype=self.dtype, - ), - ]) + chain.extend( + [ + RemoveFields(field_names=remove_field_names), + AddObservedValuesIndicator( + target_field=FieldName.TARGET, + output_field=FieldName.OBSERVED_VALUES, + dtype=self.dtype, + ), + ] + ) # --- TRANSFORMATION CHAIN FOR DYNAMIC FEATURES --- diff --git a/src/gluonts/mx/model/seq2seq/_seq2seq_estimator.py b/src/gluonts/mx/model/seq2seq/_seq2seq_estimator.py index a4fc483115..bd73cbc57e 100644 --- a/src/gluonts/mx/model/seq2seq/_seq2seq_estimator.py +++ b/src/gluonts/mx/model/seq2seq/_seq2seq_estimator.py @@ -123,26 +123,30 @@ def __init__( self.num_parallel_samples = num_parallel_samples def create_transformation(self) -> transform.Transformation: - return transform.Chain([ - transform.AsNumpyArray(field=FieldName.TARGET, expected_ndim=1), - transform.AddTimeFeatures( - start_field=FieldName.START, - target_field=FieldName.TARGET, - output_field=FieldName.FEAT_TIME, - time_features=time_features_from_frequency_str(self.freq), - pred_length=self.prediction_length, - ), - transform.VstackFeatures( - output_field=FieldName.FEAT_DYNAMIC_REAL, - input_fields=[FieldName.FEAT_TIME], - ), - transform.SetFieldIfNotPresent( - field=FieldName.FEAT_STATIC_CAT, value=[0.0] - ), - transform.AsNumpyArray( - field=FieldName.FEAT_STATIC_CAT, expected_ndim=1 - ), - ]) + return transform.Chain( + [ + transform.AsNumpyArray( + field=FieldName.TARGET, expected_ndim=1 + ), + transform.AddTimeFeatures( + start_field=FieldName.START, + target_field=FieldName.TARGET, + output_field=FieldName.FEAT_TIME, + time_features=time_features_from_frequency_str(self.freq), + pred_length=self.prediction_length, + ), + transform.VstackFeatures( + output_field=FieldName.FEAT_DYNAMIC_REAL, + input_fields=[FieldName.FEAT_TIME], + ), + transform.SetFieldIfNotPresent( + field=FieldName.FEAT_STATIC_CAT, value=[0.0] + ), + transform.AsNumpyArray( + field=FieldName.FEAT_STATIC_CAT, expected_ndim=1 + ), + ] + ) def _create_instance_splitter(self, mode: str): assert mode in ["training", "validation", "test"] diff --git a/src/gluonts/mx/model/tft/_estimator.py b/src/gluonts/mx/model/tft/_estimator.py index c6fbe6abc6..a4d8335719 100644 --- a/src/gluonts/mx/model/tft/_estimator.py +++ b/src/gluonts/mx/model/tft/_estimator.py @@ -200,10 +200,12 @@ def __init__( def create_transformation(self) -> Transformation: transforms = ( [AsNumpyArray(field=FieldName.TARGET, expected_ndim=1)] - + ([ - AsNumpyArray(field=name, expected_ndim=1) - for name in self.static_cardinalities.keys() - ]) + + ( + [ + AsNumpyArray(field=name, expected_ndim=1) + for name in self.static_cardinalities.keys() + ] + ) + [ AsNumpyArray(field=name, expected_ndim=1) for name in chain( @@ -239,13 +241,17 @@ def create_transformation(self) -> Transformation: ) ) else: - transforms.extend([ - SetField( - output_field=FieldName.FEAT_STATIC_CAT, - value=[0.0], - ), - AsNumpyArray(field=FieldName.FEAT_STATIC_CAT, expected_ndim=1), - ]) + transforms.extend( + [ + SetField( + output_field=FieldName.FEAT_STATIC_CAT, + value=[0.0], + ), + AsNumpyArray( + field=FieldName.FEAT_STATIC_CAT, expected_ndim=1 + ), + ] + ) if self.static_feature_dims: transforms.append( @@ -256,15 +262,17 @@ def create_transformation(self) -> Transformation: ) ) else: - transforms.extend([ - SetField( - output_field=FieldName.FEAT_STATIC_REAL, - value=[0.0], - ), - AsNumpyArray( - field=FieldName.FEAT_STATIC_REAL, expected_ndim=1 - ), - ]) + transforms.extend( + [ + SetField( + output_field=FieldName.FEAT_STATIC_REAL, + value=[0.0], + ), + AsNumpyArray( + field=FieldName.FEAT_STATIC_REAL, expected_ndim=1 + ), + ] + ) if self.dynamic_cardinalities: transforms.append( @@ -274,20 +282,22 @@ def create_transformation(self) -> Transformation: ) ) else: - transforms.extend([ - SetField( - output_field=FieldName.FEAT_DYNAMIC_CAT, - value=[[0.0]], - ), - AsNumpyArray( - field=FieldName.FEAT_DYNAMIC_CAT, - expected_ndim=2, - ), - BroadcastTo( - field=FieldName.FEAT_DYNAMIC_CAT, - ext_length=self.prediction_length, - ), - ]) + transforms.extend( + [ + SetField( + output_field=FieldName.FEAT_DYNAMIC_CAT, + value=[[0.0]], + ), + AsNumpyArray( + field=FieldName.FEAT_DYNAMIC_CAT, + expected_ndim=2, + ), + BroadcastTo( + field=FieldName.FEAT_DYNAMIC_CAT, + ext_length=self.prediction_length, + ), + ] + ) input_fields = [FieldName.FEAT_TIME] if self.dynamic_feature_dims: @@ -307,17 +317,19 @@ def create_transformation(self) -> Transformation: ) ) else: - transforms.extend([ - SetField( - output_field=FieldName.PAST_FEAT_DYNAMIC + "_cat", - value=[[0.0]], - ), - AsNumpyArray( - field=FieldName.PAST_FEAT_DYNAMIC + "_cat", - expected_ndim=2, - ), - BroadcastTo(field=FieldName.PAST_FEAT_DYNAMIC + "_cat"), - ]) + transforms.extend( + [ + SetField( + output_field=FieldName.PAST_FEAT_DYNAMIC + "_cat", + value=[[0.0]], + ), + AsNumpyArray( + field=FieldName.PAST_FEAT_DYNAMIC + "_cat", + expected_ndim=2, + ), + BroadcastTo(field=FieldName.PAST_FEAT_DYNAMIC + "_cat"), + ] + ) if self.past_dynamic_feature_dims: transforms.append( @@ -327,16 +339,18 @@ def create_transformation(self) -> Transformation: ) ) else: - transforms.extend([ - SetField( - output_field=FieldName.PAST_FEAT_DYNAMIC_REAL, - value=[[0.0]], - ), - AsNumpyArray( - field=FieldName.PAST_FEAT_DYNAMIC_REAL, expected_ndim=2 - ), - BroadcastTo(field=FieldName.PAST_FEAT_DYNAMIC_REAL), - ]) + transforms.extend( + [ + SetField( + output_field=FieldName.PAST_FEAT_DYNAMIC_REAL, + value=[[0.0]], + ), + AsNumpyArray( + field=FieldName.PAST_FEAT_DYNAMIC_REAL, expected_ndim=2 + ), + BroadcastTo(field=FieldName.PAST_FEAT_DYNAMIC_REAL), + ] + ) return Chain(transforms) diff --git a/src/gluonts/mx/model/tpp/deeptpp/_estimator.py b/src/gluonts/mx/model/tpp/deeptpp/_estimator.py index 4a16504b3a..e6390e899c 100644 --- a/src/gluonts/mx/model/tpp/deeptpp/_estimator.py +++ b/src/gluonts/mx/model/tpp/deeptpp/_estimator.py @@ -177,17 +177,21 @@ def _create_instance_splitter(self, mode: str): assert isinstance(instance_sampler, ContinuousTimePointSampler) - return Chain([ - ContinuousTimeInstanceSplitter( - past_interval_length=self.context_interval_length, - future_interval_length=self.prediction_interval_length, - instance_sampler=instance_sampler, - ), - RenameFields({ - "past_target": "target", - "past_valid_length": "valid_length", - }), - ]) + return Chain( + [ + ContinuousTimeInstanceSplitter( + past_interval_length=self.context_interval_length, + future_interval_length=self.prediction_interval_length, + instance_sampler=instance_sampler, + ), + RenameFields( + { + "past_target": "target", + "past_valid_length": "valid_length", + } + ), + ] + ) def create_training_data_loader( self, diff --git a/src/gluonts/mx/model/tpp/forecast.py b/src/gluonts/mx/model/tpp/forecast.py index facde8b9dc..4082a12987 100644 --- a/src/gluonts/mx/model/tpp/forecast.py +++ b/src/gluonts/mx/model/tpp/forecast.py @@ -132,15 +132,17 @@ def index(self) -> pd.PeriodIndex: ) def __repr__(self): - return ", ".join([ - f"PointProcessSampleForecast({self.samples!r})", - f"{self.valid_length!r}", - f"{self.start_date!r}", - f"{self.end_date!r}", - f"{self.freq!r}", - f"item_id={self.item_id!r}", - f"info={self.info!r})", - ]) + return ", ".join( + [ + f"PointProcessSampleForecast({self.samples!r})", + f"{self.valid_length!r}", + f"{self.start_date!r}", + f"{self.end_date!r}", + f"{self.freq!r}", + f"item_id={self.item_id!r}", + f"info={self.info!r})", + ] + ) def quantile(self, q: Union[float, str]) -> np.ndarray: raise NotImplementedError( diff --git a/src/gluonts/mx/model/wavenet/_estimator.py b/src/gluonts/mx/model/wavenet/_estimator.py index 20e14e8c49..112ab3fdc8 100644 --- a/src/gluonts/mx/model/wavenet/_estimator.py +++ b/src/gluonts/mx/model/wavenet/_estimator.py @@ -265,31 +265,35 @@ def __init__( ) def create_transformation(self) -> transform.Transformation: - return Chain([ - AsNumpyArray(field=FieldName.TARGET, expected_ndim=1), - AddObservedValuesIndicator( - target_field=FieldName.TARGET, - output_field=FieldName.OBSERVED_VALUES, - ), - AddTimeFeatures( - start_field=FieldName.START, - target_field=FieldName.TARGET, - output_field=FieldName.FEAT_TIME, - time_features=time_features_from_frequency_str(self.freq), - pred_length=self.prediction_length, - ), - AddAgeFeature( - target_field=FieldName.TARGET, - output_field=FieldName.FEAT_AGE, - pred_length=self.prediction_length, - ), - VstackFeatures( - output_field=FieldName.FEAT_TIME, - input_fields=[FieldName.FEAT_TIME, FieldName.FEAT_AGE], - ), - SetFieldIfNotPresent(field=FieldName.FEAT_STATIC_CAT, value=[0.0]), - AsNumpyArray(field=FieldName.FEAT_STATIC_CAT, expected_ndim=1), - ]) + return Chain( + [ + AsNumpyArray(field=FieldName.TARGET, expected_ndim=1), + AddObservedValuesIndicator( + target_field=FieldName.TARGET, + output_field=FieldName.OBSERVED_VALUES, + ), + AddTimeFeatures( + start_field=FieldName.START, + target_field=FieldName.TARGET, + output_field=FieldName.FEAT_TIME, + time_features=time_features_from_frequency_str(self.freq), + pred_length=self.prediction_length, + ), + AddAgeFeature( + target_field=FieldName.TARGET, + output_field=FieldName.FEAT_AGE, + pred_length=self.prediction_length, + ), + VstackFeatures( + output_field=FieldName.FEAT_TIME, + input_fields=[FieldName.FEAT_TIME, FieldName.FEAT_AGE], + ), + SetFieldIfNotPresent( + field=FieldName.FEAT_STATIC_CAT, value=[0.0] + ), + AsNumpyArray(field=FieldName.FEAT_STATIC_CAT, expected_ndim=1), + ] + ) def _create_instance_splitter(self, mode: str): assert mode in ["training", "validation", "test"] diff --git a/src/gluonts/mx/representation/binning_helpers.py b/src/gluonts/mx/representation/binning_helpers.py index b9c555032b..31d37417ae 100644 --- a/src/gluonts/mx/representation/binning_helpers.py +++ b/src/gluonts/mx/representation/binning_helpers.py @@ -31,11 +31,13 @@ def ensure_binning_monotonicity(bin_centers: np.ndarray): def bin_edges_from_bin_centers(bin_centers: np.ndarray): lower_edge = -np.inf upper_edge = np.inf - bin_edges = np.concatenate([ - [lower_edge], - (bin_centers[1:] + bin_centers[:-1]) / 2.0, - [upper_edge], - ]) + bin_edges = np.concatenate( + [ + [lower_edge], + (bin_centers[1:] + bin_centers[:-1]) / 2.0, + [upper_edge], + ] + ) return bin_edges diff --git a/src/gluonts/nursery/SCott/dataset_tools/algo_clustering.py b/src/gluonts/nursery/SCott/dataset_tools/algo_clustering.py index ebcf547400..63009b7cc7 100644 --- a/src/gluonts/nursery/SCott/dataset_tools/algo_clustering.py +++ b/src/gluonts/nursery/SCott/dataset_tools/algo_clustering.py @@ -124,15 +124,19 @@ def KMeans_inside_dataset( 0, len(target) - len_sample, prediction_length ): ts_slice = target[ts_sample_start : ts_sample_start + len_sample] - feature = torch.cat(( - feature, - torch.Tensor([ - ts_slice.mean(), - ts_slice.var(), - index % 7, - index // 90, - ]), - )) + feature = torch.cat( + ( + feature, + torch.Tensor( + [ + ts_slice.mean(), + ts_slice.var(), + index % 7, + index // 90, + ] + ), + ) + ) index += 1 feature = feature.reshape(index, 4) feature = _get_pre_features(feature).contiguous() @@ -150,16 +154,20 @@ def KMeans_inside_dataset( ): ts_slice = target[ts_sample_start : ts_sample_start + len_sample] gid = cl[sample_id] - dataset_group[gid].append({ - "target": ts_slice, - "start": unsplit_start, - "feat_static_cat": train_entry["feat_static_cat"], - }) - whole_data.append({ - "target": ts_slice, - "start": unsplit_start, - "feat_static_cat": train_entry["feat_static_cat"], - }) + dataset_group[gid].append( + { + "target": ts_slice, + "start": unsplit_start, + "feat_static_cat": train_entry["feat_static_cat"], + } + ) + whole_data.append( + { + "target": ts_slice, + "start": unsplit_start, + "feat_static_cat": train_entry["feat_static_cat"], + } + ) unsplit_start += pd.Timedelta(hours=prediction_length) sample_id += 1 print(len(whole_data)) @@ -220,14 +228,18 @@ def KMeans_m5_dataset( # import pdb;pdb.set_trace() gid = cl[sample_id] unsplit_start = pd.Timestamp("1990-01-01") - dataset_group[gid].append({ - "target": ts_slice, - "start": unsplit_start, - }) # , 'feat_static_cat': train_entry['feat_static_cat']} - whole_data.append({ - "target": ts_slice, - "start": unsplit_start, - }) # , 'feat_static_cat': train_entry['feat_static_cat']} + dataset_group[gid].append( + { + "target": ts_slice, + "start": unsplit_start, + } + ) # , 'feat_static_cat': train_entry['feat_static_cat']} + whole_data.append( + { + "target": ts_slice, + "start": unsplit_start, + } + ) # , 'feat_static_cat': train_entry['feat_static_cat']} sample_id += 1 print(len(whole_data)) ret["group_ratio"] = [len(i) / len(whole_data) for i in dataset_group] diff --git a/src/gluonts/nursery/SCott/dataset_tools/electricity.py b/src/gluonts/nursery/SCott/dataset_tools/electricity.py index 782ea49437..6e75d5d95c 100644 --- a/src/gluonts/nursery/SCott/dataset_tools/electricity.py +++ b/src/gluonts/nursery/SCott/dataset_tools/electricity.py @@ -102,11 +102,13 @@ def group_electricity_cv( ts_slice = unsplit_ts[ ts_sample_start : ts_sample_start + len_sample ] - test_full_data.append({ - "target": ts_slice, - "start": unsplit_start, - "feat_static_cat": test_entry["feat_static_cat"], - }) + test_full_data.append( + { + "target": ts_slice, + "start": unsplit_start, + "feat_static_cat": test_entry["feat_static_cat"], + } + ) print("total number of training examples: ", len(train_full_data)) ret["group_ratio"] = [len(i) / len(train_full_data) for i in dataset_group] @@ -167,16 +169,20 @@ def group_electricity_mb( ts_slice = unsplit_ts[ ts_sample_start : ts_sample_start + len_sample ] - train_full_data.append({ - "target": ts_slice, - "start": t, - "feat_static_cat": train_entry["feat_static_cat"], - }) - dataset_group[gid].append({ - "target": ts_slice, - "start": t, - "feat_static_cat": train_entry["feat_static_cat"], - }) + train_full_data.append( + { + "target": ts_slice, + "start": t, + "feat_static_cat": train_entry["feat_static_cat"], + } + ) + dataset_group[gid].append( + { + "target": ts_slice, + "start": t, + "feat_static_cat": train_entry["feat_static_cat"], + } + ) unsplit_start += pd.Timedelta(hours=prediction_length) # get ready the test data @@ -190,11 +196,13 @@ def group_electricity_mb( ts_slice = unsplit_ts[ ts_sample_start : ts_sample_start + len_sample ] - test_full_data.append({ - "target": ts_slice, - "start": unsplit_start, - "feat_static_cat": test_entry["feat_static_cat"], - }) + test_full_data.append( + { + "target": ts_slice, + "start": unsplit_start, + "feat_static_cat": test_entry["feat_static_cat"], + } + ) print("total number of training examples: ", len(train_full_data)) ret["group_ratio"] = [len(i) / len(train_full_data) for i in dataset_group] diff --git a/src/gluonts/nursery/SCott/dataset_tools/exchange_rate.py b/src/gluonts/nursery/SCott/dataset_tools/exchange_rate.py index 76f56e1a04..cc5e609296 100644 --- a/src/gluonts/nursery/SCott/dataset_tools/exchange_rate.py +++ b/src/gluonts/nursery/SCott/dataset_tools/exchange_rate.py @@ -58,16 +58,20 @@ def group_exchangerate_cv( ts_slice = unsplit_ts[ ts_sample_start : ts_sample_start + len_sample ] - train_full_data.append({ - "target": ts_slice, - "start": unsplit_start, - "feat_static_cat": train_entry["feat_static_cat"], - }) - dataset_group[gid].append({ - "target": ts_slice, - "start": unsplit_start, - "feat_static_cat": train_entry["feat_static_cat"], - }) + train_full_data.append( + { + "target": ts_slice, + "start": unsplit_start, + "feat_static_cat": train_entry["feat_static_cat"], + } + ) + dataset_group[gid].append( + { + "target": ts_slice, + "start": unsplit_start, + "feat_static_cat": train_entry["feat_static_cat"], + } + ) unsplit_start += pd.Timedelta("1D") * prediction_length # get ready the test data for i in range(int(num_ts * 0.2)): @@ -80,11 +84,13 @@ def group_exchangerate_cv( ts_slice = unsplit_ts[ ts_sample_start : ts_sample_start + len_sample ] - test_full_data.append({ - "target": ts_slice, - "start": unsplit_start, - "feat_static_cat": test_entry["feat_static_cat"], - }) + test_full_data.append( + { + "target": ts_slice, + "start": unsplit_start, + "feat_static_cat": test_entry["feat_static_cat"], + } + ) print("total number of training examples: ", len(train_full_data)) ret["group_ratio"] = [len(i) / len(train_full_data) for i in dataset_group] print("ratio for each group: ", ret["group_ratio"]) diff --git a/src/gluonts/nursery/SCott/dataset_tools/group_raw_data.py b/src/gluonts/nursery/SCott/dataset_tools/group_raw_data.py index 84bf0402b0..c89a580eac 100644 --- a/src/gluonts/nursery/SCott/dataset_tools/group_raw_data.py +++ b/src/gluonts/nursery/SCott/dataset_tools/group_raw_data.py @@ -71,16 +71,20 @@ def get_m4_by_freq( continue nu = 1 + sum(ts_slice) / len_sample ts_slice = [i / nu for i in ts_slice] - whole_data.append({ - "target": ts_slice, - "start": unsplit_start, - "feat_static_cat": train_entry["feat_static_cat"], - }) - dataset_group[i].append({ - "target": ts_slice, - "start": unsplit_start, - "feat_static_cat": train_entry["feat_static_cat"], - }) + whole_data.append( + { + "target": ts_slice, + "start": unsplit_start, + "feat_static_cat": train_entry["feat_static_cat"], + } + ) + dataset_group[i].append( + { + "target": ts_slice, + "start": unsplit_start, + "feat_static_cat": train_entry["feat_static_cat"], + } + ) # unsplit_start += pd.Timedelta(hours=prediction_length*hours_factor[i]) unsplit_start += pd.Timedelta(hours=prediction_length) # for j in range(len(dataset_group)): @@ -136,14 +140,18 @@ def get_temperature_data( nu = 1 + sum(ts_slice) / len(ts_slice) ts_slice /= nu if torch.sum(torch.isnan(ts_slice)).item() == 0: - dataset_group[gid].append({ - "target": ts_slice, - "start": pd.Timestamp(datetime[index]), - }) - whole_data.append({ - "target": ts_slice, - "start": pd.Timestamp(datetime[index]), - }) + dataset_group[gid].append( + { + "target": ts_slice, + "start": pd.Timestamp(datetime[index]), + } + ) + whole_data.append( + { + "target": ts_slice, + "start": pd.Timestamp(datetime[index]), + } + ) if num_samples == samples_per_ts: break random.shuffle(whole_data) @@ -231,12 +239,14 @@ def get_group_data_by_var(name, num_groups, len_sample=9): ) ) continue - dataset_group[group_id].append({ - "target": unsplit_ts[ - ts_sample_start : ts_sample_start + len_sample - ], - "start": unsplit_start, - }) + dataset_group[group_id].append( + { + "target": unsplit_ts[ + ts_sample_start : ts_sample_start + len_sample + ], + "start": unsplit_start, + } + ) unsplit_start += pd.Timedelta(hours=1) import pdb @@ -295,14 +305,18 @@ def get_group_data_by_duplicate(name, num_duplicates, num_groups): {"target": train_entry["target"], "start": train_entry["start"]} ) for j in range(num_duplicates): - dataset_group[i % num_groups].append({ - "target": train_entry["target"], - "start": train_entry["start"], - }) - whole_data_list.append({ - "target": train_entry["target"], - "start": train_entry["start"], - }) + dataset_group[i % num_groups].append( + { + "target": train_entry["target"], + "start": train_entry["start"], + } + ) + whole_data_list.append( + { + "target": train_entry["target"], + "start": train_entry["start"], + } + ) random.shuffle(whole_data_list) random.shuffle(no_duplicate_whole_data_list) ret.append( @@ -328,10 +342,12 @@ def get_whole_data_by_duplicate(name, num_duplicates): {"target": train_entry["target"], "start": train_entry["start"]} ) for j in range(num_duplicates): - dataset_group.append({ - "target": train_entry["target"], - "start": train_entry["start"], - }) + dataset_group.append( + { + "target": train_entry["target"], + "start": train_entry["start"], + } + ) random.shuffle(dataset_group) random.shuffle(no_duplicate_whole_data_list) ret.append( @@ -350,10 +366,12 @@ def get_group_data(name): train_entry = next(it) dataset_group.append( ListDataset( - [{ - "target": train_entry["target"], - "start": train_entry["start"], - }], + [ + { + "target": train_entry["target"], + "start": train_entry["start"], + } + ], freq=dataset.metadata.freq, ) ) @@ -402,20 +420,24 @@ def get_synthetic_data(model_name=None, num_groups=8, mean_boundary=1): prediction = net.get_distr(ts_slice).sample((5000,)) prediction = sum(prediction) / len(prediction) ts = torch.cat([ts, prediction], dim=1) - whole_data_list.append({ - "target": ts.view( - len(ts[0]), - )[context_length:], - "start": start, - }) + whole_data_list.append( + { + "target": ts.view( + len(ts[0]), + )[context_length:], + "start": start, + } + ) dataset_group.append( ListDataset( - [{ - "target": ts.view( - len(ts[0]), - )[context_length:], - "start": start, - }], + [ + { + "target": ts.view( + len(ts[0]), + )[context_length:], + "start": start, + } + ], freq="1H", ) ) @@ -586,18 +608,22 @@ def get_synthetic_data_linear_simple( ) for j in range(num_duplicates): ts += torch.normal(0, 0.01, size=ts.shape) - whole_data_list.append({ - "target": ts.view( - len(ts[0]), - ), - "start": start, - }) - pattern_group.append({ - "target": ts.view( - len(ts[0]), - ), - "start": start, - }) + whole_data_list.append( + { + "target": ts.view( + len(ts[0]), + ), + "start": start, + } + ) + pattern_group.append( + { + "target": ts.view( + len(ts[0]), + ), + "start": start, + } + ) dataset_group.append(ListDataset(pattern_group, freq="1D")) random.shuffle(whole_data_list) @@ -632,26 +658,32 @@ def get_synthetic_data_sin( 1, num_time_steps ) ts += torch.FloatTensor((gid + 1) * base).view(1, num_time_steps) - no_duplicate_whole_data_list.append({ - "target": ts.view( - len(ts[0]), - ), - "start": start, - }) - for j in range(num_duplicates): - ts += torch.normal(0, 0.1, size=ts.shape) - whole_data_list.append({ - "target": ts.view( - len(ts[0]), - ), - "start": start, - }) - pattern_group.append({ + no_duplicate_whole_data_list.append( + { "target": ts.view( len(ts[0]), ), "start": start, - }) + } + ) + for j in range(num_duplicates): + ts += torch.normal(0, 0.1, size=ts.shape) + whole_data_list.append( + { + "target": ts.view( + len(ts[0]), + ), + "start": start, + } + ) + pattern_group.append( + { + "target": ts.view( + len(ts[0]), + ), + "start": start, + } + ) dataset_group.append(ListDataset(pattern_group, freq="1D")) random.shuffle(whole_data_list) diff --git a/src/gluonts/nursery/SCott/dataset_tools/synthetic.py b/src/gluonts/nursery/SCott/dataset_tools/synthetic.py index b176ef3658..8e2dc383b6 100644 --- a/src/gluonts/nursery/SCott/dataset_tools/synthetic.py +++ b/src/gluonts/nursery/SCott/dataset_tools/synthetic.py @@ -64,20 +64,24 @@ def get_mixed_pattern(unit_length=16, num_duplicates=1000): pattern[(gid + i) % pattern_number], ) ) - ts_sample = torch.cat([ - context, - _get_mixed_pattern( - torch.arange(prediction_length, dtype=torch.float), - pattern[gid], - ), - ]) + ts_sample = torch.cat( + [ + context, + _get_mixed_pattern( + torch.arange(prediction_length, dtype=torch.float), + pattern[gid], + ), + ] + ) whole_data.append({"target": ts_sample, "start": start}) if j % 5 == 0: - val_data.append({ - "target": ts_sample - + torch.normal(0, 1, ts_sample.shape), - "start": start, - }) + val_data.append( + { + "target": ts_sample + + torch.normal(0, 1, ts_sample.shape), + "start": start, + } + ) dataset_group[m * 4 + gid].append( {"target": ts_sample, "start": start} ) diff --git a/src/gluonts/nursery/SCott/dataset_tools/traffic.py b/src/gluonts/nursery/SCott/dataset_tools/traffic.py index 4ec45b14fd..8582ba87b2 100644 --- a/src/gluonts/nursery/SCott/dataset_tools/traffic.py +++ b/src/gluonts/nursery/SCott/dataset_tools/traffic.py @@ -70,16 +70,20 @@ def group_traffic_cv( ts_slice = unsplit_ts[ ts_sample_start : ts_sample_start + len_sample ] - train_full_data.append({ - "target": ts_slice, - "start": t, - "feat_static_cat": train_entry["feat_static_cat"], - }) - dataset_group[gid].append({ - "target": ts_slice, - "start": t, - "feat_static_cat": train_entry["feat_static_cat"], - }) + train_full_data.append( + { + "target": ts_slice, + "start": t, + "feat_static_cat": train_entry["feat_static_cat"], + } + ) + dataset_group[gid].append( + { + "target": ts_slice, + "start": t, + "feat_static_cat": train_entry["feat_static_cat"], + } + ) unsplit_start += pd.Timedelta(hours=prediction_length) # get ready the test data @@ -93,11 +97,13 @@ def group_traffic_cv( ts_slice = unsplit_ts[ ts_sample_start : ts_sample_start + len_sample ] - test_full_data.append({ - "target": ts_slice, - "start": unsplit_start, - "feat_static_cat": test_entry["feat_static_cat"], - }) + test_full_data.append( + { + "target": ts_slice, + "start": unsplit_start, + "feat_static_cat": test_entry["feat_static_cat"], + } + ) print("total number of training examples: ", len(train_full_data)) ret["group_ratio"] = [len(i) / len(train_full_data) for i in dataset_group] @@ -158,16 +164,20 @@ def group_traffic_mb( ts_slice = unsplit_ts[ ts_sample_start : ts_sample_start + len_sample ] - train_full_data.append({ - "target": ts_slice, - "start": t, - "feat_static_cat": train_entry["feat_static_cat"], - }) - dataset_group[gid].append({ - "target": ts_slice, - "start": t, - "feat_static_cat": train_entry["feat_static_cat"], - }) + train_full_data.append( + { + "target": ts_slice, + "start": t, + "feat_static_cat": train_entry["feat_static_cat"], + } + ) + dataset_group[gid].append( + { + "target": ts_slice, + "start": t, + "feat_static_cat": train_entry["feat_static_cat"], + } + ) unsplit_start += pd.Timedelta(hours=prediction_length) # get ready the test data @@ -181,11 +191,13 @@ def group_traffic_mb( ts_slice = unsplit_ts[ ts_sample_start : ts_sample_start + len_sample ] - test_full_data.append({ - "target": ts_slice, - "start": unsplit_start, - "feat_static_cat": test_entry["feat_static_cat"], - }) + test_full_data.append( + { + "target": ts_slice, + "start": unsplit_start, + "feat_static_cat": test_entry["feat_static_cat"], + } + ) print("total number of training examples: ", len(train_full_data)) ret["group_ratio"] = [len(i) / len(train_full_data) for i in dataset_group] diff --git a/src/gluonts/nursery/SCott/model/ar/ar_estimator.py b/src/gluonts/nursery/SCott/model/ar/ar_estimator.py index 116e2f26f6..62d78b0aae 100644 --- a/src/gluonts/nursery/SCott/model/ar/ar_estimator.py +++ b/src/gluonts/nursery/SCott/model/ar/ar_estimator.py @@ -63,20 +63,22 @@ def __init__( # transformation that includes time features, age feature, observed values # indicator, etc. def create_transformation(self, is_full_batch=False) -> Transformation: - return Chain([ - InstanceSplitter( - target_field=FieldName.TARGET, - is_pad_field=FieldName.IS_PAD, - start_field=FieldName.START, - forecast_start_field=FieldName.FORECAST_START, - # train_sampler=ExpectedNumInstanceSampler(num_instances=1), - train_sampler=CustomUniformSampler(), - past_length=self.context_length, - future_length=self.prediction_length, - is_full_batch=is_full_batch, - time_series_fields=[], # [FieldName.FEAT_DYNAMIC_REAL] - ) - ]) + return Chain( + [ + InstanceSplitter( + target_field=FieldName.TARGET, + is_pad_field=FieldName.IS_PAD, + start_field=FieldName.START, + forecast_start_field=FieldName.FORECAST_START, + # train_sampler=ExpectedNumInstanceSampler(num_instances=1), + train_sampler=CustomUniformSampler(), + past_length=self.context_length, + future_length=self.prediction_length, + is_full_batch=is_full_batch, + time_series_fields=[], # [FieldName.FEAT_DYNAMIC_REAL] + ) + ] + ) # defines the network, we get to see one batch to initialize it. # the network should return at least one tensor that is used as a loss to minimize in the training loop. diff --git a/src/gluonts/nursery/SCott/model/lstm/lstm_estimator.py b/src/gluonts/nursery/SCott/model/lstm/lstm_estimator.py index 3519b8c995..2b55e12da0 100644 --- a/src/gluonts/nursery/SCott/model/lstm/lstm_estimator.py +++ b/src/gluonts/nursery/SCott/model/lstm/lstm_estimator.py @@ -66,20 +66,22 @@ def __init__( # transformation that includes time features, age feature, observed values # indicator, etc. def create_transformation(self, is_full_batch=False) -> Transformation: - return Chain([ - InstanceSplitter( - target_field=FieldName.TARGET, - is_pad_field=FieldName.IS_PAD, - start_field=FieldName.START, - forecast_start_field=FieldName.FORECAST_START, - # train_sampler=ExpectedNumInstanceSampler(num_instances=1), - train_sampler=CustomUniformSampler(), - past_length=self.context_length, - future_length=self.prediction_length, - is_full_batch=is_full_batch, - time_series_fields=[], # [FieldName.FEAT_DYNAMIC_REAL] - ) - ]) + return Chain( + [ + InstanceSplitter( + target_field=FieldName.TARGET, + is_pad_field=FieldName.IS_PAD, + start_field=FieldName.START, + forecast_start_field=FieldName.FORECAST_START, + # train_sampler=ExpectedNumInstanceSampler(num_instances=1), + train_sampler=CustomUniformSampler(), + past_length=self.context_length, + future_length=self.prediction_length, + is_full_batch=is_full_batch, + time_series_fields=[], # [FieldName.FEAT_DYNAMIC_REAL] + ) + ] + ) # defines the network, we get to see one batch to initialize it. # the network should return at least one tensor that is used as a loss to minimize in the training loop. diff --git a/src/gluonts/nursery/SCott/preprocess_data.py b/src/gluonts/nursery/SCott/preprocess_data.py index 2b6981edea..40b181820f 100644 --- a/src/gluonts/nursery/SCott/preprocess_data.py +++ b/src/gluonts/nursery/SCott/preprocess_data.py @@ -64,20 +64,24 @@ def get_mixed_pattern(unit_length=16, num_duplicates=1000): pattern[(gid + i) % pattern_number], ) ) - ts_sample = torch.cat([ - context, - _get_mixed_pattern( - torch.arange(prediction_length, dtype=torch.float), - pattern[gid], - ), - ]) + ts_sample = torch.cat( + [ + context, + _get_mixed_pattern( + torch.arange(prediction_length, dtype=torch.float), + pattern[gid], + ), + ] + ) whole_data.append({"target": ts_sample, "start": start}) if j % 5 == 0: - val_data.append({ - "target": ts_sample - + torch.normal(0, 1, ts_sample.shape), - "start": start, - }) + val_data.append( + { + "target": ts_sample + + torch.normal(0, 1, ts_sample.shape), + "start": start, + } + ) dataset_group[m * 4 + gid].append( {"target": ts_sample, "start": start} ) @@ -153,16 +157,20 @@ def group_electricity_cv( ts_slice = unsplit_ts[ ts_sample_start : ts_sample_start + len_sample ] - train_full_data.append({ - "target": ts_slice, - "start": t, - "feat_static_cat": np.array([gid]), - }) - dataset_group[gid].append({ - "target": ts_slice, - "start": t, - "feat_static_cat": np.array([gid]), - }) + train_full_data.append( + { + "target": ts_slice, + "start": t, + "feat_static_cat": np.array([gid]), + } + ) + dataset_group[gid].append( + { + "target": ts_slice, + "start": t, + "feat_static_cat": np.array([gid]), + } + ) unsplit_start += pd.Timedelta(hours=prediction_length) # get ready the test data @@ -176,11 +184,13 @@ def group_electricity_cv( ts_slice = unsplit_ts[ ts_sample_start : ts_sample_start + len_sample ] - test_full_data.append({ - "target": ts_slice, - "start": unsplit_start, - "feat_static_cat": test_entry["feat_static_cat"], - }) + test_full_data.append( + { + "target": ts_slice, + "start": unsplit_start, + "feat_static_cat": test_entry["feat_static_cat"], + } + ) print( "Generating the electricity training data, the total number of" @@ -238,16 +248,20 @@ def group_electricity_cv( ts_slice = unsplit_ts[ ts_sample_start : ts_sample_start + len_sample ] - train_full_data.append({ - "target": ts_slice, - "start": t, - "feat_static_cat": train_entry["feat_static_cat"], - }) - dataset_group[gid].append({ - "target": ts_slice, - "start": t, - "feat_static_cat": train_entry["feat_static_cat"], - }) + train_full_data.append( + { + "target": ts_slice, + "start": t, + "feat_static_cat": train_entry["feat_static_cat"], + } + ) + dataset_group[gid].append( + { + "target": ts_slice, + "start": t, + "feat_static_cat": train_entry["feat_static_cat"], + } + ) unsplit_start += pd.Timedelta(hours=prediction_length) # get ready the test data @@ -261,11 +275,13 @@ def group_electricity_cv( ts_slice = unsplit_ts[ ts_sample_start : ts_sample_start + len_sample ] - test_full_data.append({ - "target": ts_slice, - "start": unsplit_start, - "feat_static_cat": test_entry["feat_static_cat"], - }) + test_full_data.append( + { + "target": ts_slice, + "start": unsplit_start, + "feat_static_cat": test_entry["feat_static_cat"], + } + ) print("total number of training examples: ", len(train_full_data)) ret["group_ratio"] = [len(i) / len(train_full_data) for i in dataset_group] @@ -322,16 +338,20 @@ def group_exchangerate_cv( ts_slice = unsplit_ts[ ts_sample_start : ts_sample_start + len_sample ] - train_full_data.append({ - "target": ts_slice, - "start": unsplit_start, - "feat_static_cat": train_entry["feat_static_cat"], - }) - dataset_group[gid].append({ - "target": ts_slice, - "start": unsplit_start, - "feat_static_cat": train_entry["feat_static_cat"], - }) + train_full_data.append( + { + "target": ts_slice, + "start": unsplit_start, + "feat_static_cat": train_entry["feat_static_cat"], + } + ) + dataset_group[gid].append( + { + "target": ts_slice, + "start": unsplit_start, + "feat_static_cat": train_entry["feat_static_cat"], + } + ) unsplit_start += pd.Timedelta("1D") * prediction_length # get ready the test data for i in range(int(num_ts * 0.2)): @@ -344,11 +364,13 @@ def group_exchangerate_cv( ts_slice = unsplit_ts[ ts_sample_start : ts_sample_start + len_sample ] - test_full_data.append({ - "target": ts_slice, - "start": unsplit_start, - "feat_static_cat": test_entry["feat_static_cat"], - }) + test_full_data.append( + { + "target": ts_slice, + "start": unsplit_start, + "feat_static_cat": test_entry["feat_static_cat"], + } + ) print( "Generating the exchange rate training data, the total number of" " training examples:", @@ -419,16 +441,20 @@ def group_traffic_cv( ts_slice = unsplit_ts[ ts_sample_start : ts_sample_start + len_sample ] - train_full_data.append({ - "target": ts_slice, - "start": t, - "feat_static_cat": train_entry["feat_static_cat"], - }) - dataset_group[gid].append({ - "target": ts_slice, - "start": t, - "feat_static_cat": train_entry["feat_static_cat"], - }) + train_full_data.append( + { + "target": ts_slice, + "start": t, + "feat_static_cat": train_entry["feat_static_cat"], + } + ) + dataset_group[gid].append( + { + "target": ts_slice, + "start": t, + "feat_static_cat": train_entry["feat_static_cat"], + } + ) unsplit_start += pd.Timedelta(hours=prediction_length) # get ready the test data @@ -442,11 +468,13 @@ def group_traffic_cv( ts_slice = unsplit_ts[ ts_sample_start : ts_sample_start + len_sample ] - test_full_data.append({ - "target": ts_slice, - "start": unsplit_start, - "feat_static_cat": test_entry["feat_static_cat"], - }) + test_full_data.append( + { + "target": ts_slice, + "start": unsplit_start, + "feat_static_cat": test_entry["feat_static_cat"], + } + ) print( "Generating the traffic training data, the total number of training" diff --git a/src/gluonts/nursery/daf/engine/parallel.py b/src/gluonts/nursery/daf/engine/parallel.py index d95bce51a7..ae3682efa6 100644 --- a/src/gluonts/nursery/daf/engine/parallel.py +++ b/src/gluonts/nursery/daf/engine/parallel.py @@ -300,20 +300,22 @@ def __init__( **kwargs, ) del self.optimizer - self.optimizer = optimizer([ - { - "params": chain( - self.model.src.generative_parameters(), - self.model.tgt.generative_parameters(), - ) - }, - { - "params": chain( - self.model.src.discriminative_parameters(), - self.model.tgt.discriminative_parameters(), - ) - }, - ]) + self.optimizer = optimizer( + [ + { + "params": chain( + self.model.src.generative_parameters(), + self.model.tgt.generative_parameters(), + ) + }, + { + "params": chain( + self.model.src.discriminative_parameters(), + self.model.tgt.discriminative_parameters(), + ) + }, + ] + ) def _train(self, *data): self.model.generative() diff --git a/src/gluonts/nursery/daf/estimator/modules.py b/src/gluonts/nursery/daf/estimator/modules.py index 3f92ac7192..5541b36a3e 100644 --- a/src/gluonts/nursery/daf/estimator/modules.py +++ b/src/gluonts/nursery/daf/estimator/modules.py @@ -103,10 +103,12 @@ def n_layer(self) -> int: @property def tie_layers(self) -> bool: return (self.n_layer == 1) or ( - all([ - (a.encoder is b.encoder) and (a.decoder is b.decoder) - for a, b in product(self.blocks[:1], self.blocks[1:]) - ]) + all( + [ + (a.encoder is b.encoder) and (a.decoder is b.decoder) + for a, b in product(self.blocks[:1], self.blocks[1:]) + ] + ) ) def register_loss_func(self, func: LossFunction) -> None: diff --git a/src/gluonts/nursery/daf/tslib/dataset/timeseries.py b/src/gluonts/nursery/daf/tslib/dataset/timeseries.py index 11270eb9fb..4bd63964b8 100644 --- a/src/gluonts/nursery/daf/tslib/dataset/timeseries.py +++ b/src/gluonts/nursery/daf/tslib/dataset/timeseries.py @@ -108,23 +108,25 @@ def d_data(self) -> int: return self.target.shape[1] def __eq__(self, other) -> bool: - return all([ - isinstance(other, TimeSeriesInstant), - np.array_equal(self.target, other.target), - np.array_equal(self.timestamp, other.timestamp), - self.series_name == other.series_name, - np.array_equal(self.target_names, other.target_names), - _dict_equal( - self.categorical_features, - other.categorical_features, - np.array_equal, - ), - _dict_equal( - self.numerical_features, - other.numerical_features, - np.array_equal, - ), - ]) + return all( + [ + isinstance(other, TimeSeriesInstant), + np.array_equal(self.target, other.target), + np.array_equal(self.timestamp, other.timestamp), + self.series_name == other.series_name, + np.array_equal(self.target_names, other.target_names), + _dict_equal( + self.categorical_features, + other.categorical_features, + np.array_equal, + ), + _dict_equal( + self.numerical_features, + other.numerical_features, + np.array_equal, + ), + ] + ) def __repr__(self): string = f"time = {self.timestamp:%Y-%m-%d %H:%M:%S}\n" @@ -289,43 +291,45 @@ def dynamic_numerical_features(self) -> Dict: } def __eq__(self, other): - return all([ - isinstance(other, TimeSeries), - np.array_equal(self.target, other.target), - np.array_equal(self.time_index, other.time_index), - self.series_name == other.series_name, - np.array_equal(self.target_names, other.target_names), - _dict_equal( - self.static_categorical_features, - other.static_categorical_features, - np.array_equal, - ), - _dict_equal( - self.static_numerical_features, - other.static_numerical_features, - np.array_equal, - ), - _dict_equal( - self.revealed_categorical_features, - other.revealed_categorical_features, - np.array_equal, - ), - _dict_equal( - self.revealed_numerical_features, - other.revealed_numerical_features, - np.array_equal, - ), - _dict_equal( - self.observed_categorical_features, - other.observed_categorical_features, - np.array_equal, - ), - _dict_equal( - self.observed_numerical_features, - other.observed_numerical_features, - np.array_equal, - ), - ]) + return all( + [ + isinstance(other, TimeSeries), + np.array_equal(self.target, other.target), + np.array_equal(self.time_index, other.time_index), + self.series_name == other.series_name, + np.array_equal(self.target_names, other.target_names), + _dict_equal( + self.static_categorical_features, + other.static_categorical_features, + np.array_equal, + ), + _dict_equal( + self.static_numerical_features, + other.static_numerical_features, + np.array_equal, + ), + _dict_equal( + self.revealed_categorical_features, + other.revealed_categorical_features, + np.array_equal, + ), + _dict_equal( + self.revealed_numerical_features, + other.revealed_numerical_features, + np.array_equal, + ), + _dict_equal( + self.observed_categorical_features, + other.observed_categorical_features, + np.array_equal, + ), + _dict_equal( + self.observed_numerical_features, + other.observed_numerical_features, + np.array_equal, + ), + ] + ) def __len__(self): return len(self.target) @@ -558,14 +562,16 @@ def _check_consistency( self, instances: List[TimeSeries] ) -> List[TimeSeries]: def _consistent(ts1: TimeSeries, ts2: TimeSeries) -> bool: - return all([ - np.array_equal(ts1.target_names, ts2.target_names), - ts1._static_features == ts2._static_features, - ts1._revealed_features == ts2._revealed_features, - ts1._observed_features == ts2._observed_features, - ts1._categorical_features == ts2._categorical_features, - ts1._numerical_features == ts2._numerical_features, - ]) + return all( + [ + np.array_equal(ts1.target_names, ts2.target_names), + ts1._static_features == ts2._static_features, + ts1._revealed_features == ts2._revealed_features, + ts1._observed_features == ts2._observed_features, + ts1._categorical_features == ts2._categorical_features, + ts1._numerical_features == ts2._numerical_features, + ] + ) cats = defaultdict(list) nums = defaultdict(list) diff --git a/src/gluonts/nursery/daf/tslib/metrics/dict.py b/src/gluonts/nursery/daf/tslib/metrics/dict.py index 53d0e3ea5b..8ac9fa8c3f 100644 --- a/src/gluonts/nursery/daf/tslib/metrics/dict.py +++ b/src/gluonts/nursery/daf/tslib/metrics/dict.py @@ -193,10 +193,12 @@ def _add_spaces(str_, n_spaces=4): main_str = "\n".join( [f"{name}: {repr(meter)}" for name, meter in self._meters.items()] ) - child_str = "\n".join([ - f"{name}:\n{_add_spaces(repr(meterdict))}" - for name, meterdict in self._meterdicts.items() - ]) + child_str = "\n".join( + [ + f"{name}:\n{_add_spaces(repr(meterdict))}" + for name, meterdict in self._meterdicts.items() + ] + ) if child_str: main_str = "\n".join([main_str, child_str]) return main_str diff --git a/src/gluonts/nursery/daf/tslib/nn/attention/posemb.py b/src/gluonts/nursery/daf/tslib/nn/attention/posemb.py index a002e9d2f5..ae560cfb7d 100644 --- a/src/gluonts/nursery/daf/tslib/nn/attention/posemb.py +++ b/src/gluonts/nursery/daf/tslib/nn/attention/posemb.py @@ -63,10 +63,12 @@ def __init__( self.max_len = max_len self.sub_shape = sub_shape - self._weights = nn.ParameterList([ - nn.Parameter(Tensor(size, dim)) - for size, dim in zip(self.sub_shape, self.d_sub_embeds) - ]) + self._weights = nn.ParameterList( + [ + nn.Parameter(Tensor(size, dim)) + for size, dim in zip(self.sub_shape, self.d_sub_embeds) + ] + ) self._reset_parameters() def _reset_parameters(self): diff --git a/src/gluonts/nursery/few_shot_prediction/src/meta/datasets/artificial.py b/src/gluonts/nursery/few_shot_prediction/src/meta/datasets/artificial.py index abce0e6f77..5a2dab200b 100644 --- a/src/gluonts/nursery/few_shot_prediction/src/meta/datasets/artificial.py +++ b/src/gluonts/nursery/few_shot_prediction/src/meta/datasets/artificial.py @@ -209,17 +209,21 @@ def generate(self) -> None: def generate_split( self, split: Literal["train", "val", "test"], n_samples: int ) -> None: - queries, support_sets = zip(*[ - generate_artificial_tuplets( - dataset_name=self.dataset_name, - context_length=self.context_length, - support_length=self.support_length, - prediction_length=self.prediction_length, - support_set_size=self.support_set_size, - item_id=i, - ) - for i in tqdm(range(n_samples), desc="generating artificial data") - ]) + queries, support_sets = zip( + *[ + generate_artificial_tuplets( + dataset_name=self.dataset_name, + context_length=self.context_length, + support_length=self.support_length, + prediction_length=self.prediction_length, + support_set_size=self.support_set_size, + item_id=i, + ) + for i in tqdm( + range(n_samples), desc="generating artificial data" + ) + ] + ) _write_data_to_file(self.root / split / "data.json", queries) _write_data_to_file( self.root / split / ".support_set.json", support_sets diff --git a/src/gluonts/nursery/few_shot_prediction/src/meta/datasets/cheat.py b/src/gluonts/nursery/few_shot_prediction/src/meta/datasets/cheat.py index 964c0fc65a..8f53989c7f 100644 --- a/src/gluonts/nursery/few_shot_prediction/src/meta/datasets/cheat.py +++ b/src/gluonts/nursery/few_shot_prediction/src/meta/datasets/cheat.py @@ -234,35 +234,39 @@ def train_dataloader(self) -> DataLoader[TripletBatch]: def val_dataloader(self) -> DataLoader[TripletBatch]: splits = [self.splits.val(d_name) for d_name in self.dataset_names_val] - return list([ - DataLoader( - TripletDataset( - queries=split.data(), support_sets=split.support_set() - ), - collate_fn=TripletBatch.collate, - batch_size=self.batch_size_val_test, - num_workers=self.num_workers, - pin_memory=True, - ) - for split in splits - ]) + return list( + [ + DataLoader( + TripletDataset( + queries=split.data(), support_sets=split.support_set() + ), + collate_fn=TripletBatch.collate, + batch_size=self.batch_size_val_test, + num_workers=self.num_workers, + pin_memory=True, + ) + for split in splits + ] + ) def test_dataloader(self) -> DataLoader[TripletBatch]: splits = [ self.splits.test(d_name) for d_name in self.dataset_names_test ] - return list([ - DataLoader( - TripletDataset( - queries=split.data(), support_sets=split.support_set() - ), - collate_fn=TripletBatch.collate, - batch_size=self.batch_size_val_test, - num_workers=self.num_workers, - pin_memory=True, - ) - for split in splits - ]) + return list( + [ + DataLoader( + TripletDataset( + queries=split.data(), support_sets=split.support_set() + ), + collate_fn=TripletBatch.collate, + batch_size=self.batch_size_val_test, + num_workers=self.num_workers, + pin_memory=True, + ) + for split in splits + ] + ) def generate(self) -> None: if self.root.exists(): @@ -297,14 +301,18 @@ def generate_split( always_cheat=True, query_length_scale=5.0, ) -> None: - queries, support_sets = zip(*[ - self.generate_artificial_tuplets( - item_id=i, - always_cheat=always_cheat, - query_length_scale=query_length_scale, - ) - for i in tqdm(range(n_samples), desc="generating artificial data") - ]) + queries, support_sets = zip( + *[ + self.generate_artificial_tuplets( + item_id=i, + always_cheat=always_cheat, + query_length_scale=query_length_scale, + ) + for i in tqdm( + range(n_samples), desc="generating artificial data" + ) + ] + ) _write_data_to_file(self.root / split / "data.json", queries) _write_data_to_file( self.root / split / ".support_set.json", support_sets @@ -577,12 +585,14 @@ def train_dataloader(self) -> DataLoader[TripletBatch]: for d_name in self.dataset_names_train ] - datasets = ConcatDataset([ - TripletDataset( - queries=split.data(), support_sets=split.support_set() - ) - for split in splits - ]) + datasets = ConcatDataset( + [ + TripletDataset( + queries=split.data(), support_sets=split.support_set() + ) + for split in splits + ] + ) return DataLoader( datasets, @@ -628,15 +638,19 @@ def generate_split( query_length_scale: float = 5.0, counterfactual_size: int = 0, ) -> None: - queries, support_sets = zip(*[ - self.generate_artificial_tuplets( - item_id=i, - always_cheat=always_cheat, - query_length_scale=query_length_scale, - counterfactual_size=counterfactual_size, - ) - for i in tqdm(range(n_samples), desc="generating artificial data") - ]) + queries, support_sets = zip( + *[ + self.generate_artificial_tuplets( + item_id=i, + always_cheat=always_cheat, + query_length_scale=query_length_scale, + counterfactual_size=counterfactual_size, + ) + for i in tqdm( + range(n_samples), desc="generating artificial data" + ) + ] + ) _write_data_to_file(self.root / split / "data.json", queries) _write_data_to_file( self.root / split / ".support_set.json", support_sets diff --git a/src/gluonts/nursery/few_shot_prediction/src/meta/datasets/datasets.py b/src/gluonts/nursery/few_shot_prediction/src/meta/datasets/datasets.py index c2152b9cc9..ddfec8fbf8 100644 --- a/src/gluonts/nursery/few_shot_prediction/src/meta/datasets/datasets.py +++ b/src/gluonts/nursery/few_shot_prediction/src/meta/datasets/datasets.py @@ -132,9 +132,11 @@ def sample_datasets( random_state=random_state.randint(low=0, high=10000), ) folds.append((train_split, val_split, test_split)) - assert not any([ - set(train_split) & (set(val_split)), - set(train_split) & set(test_split), - set(val_split) & set(test_split), - ]), "Splits should not intersect!" + assert not any( + [ + set(train_split) & (set(val_split)), + set(train_split) & set(test_split), + set(val_split) & set(test_split), + ] + ), "Splits should not intersect!" return folds diff --git a/src/gluonts/nursery/few_shot_prediction/src/meta/datasets/super.py b/src/gluonts/nursery/few_shot_prediction/src/meta/datasets/super.py index 10574993a5..2676713e3b 100644 --- a/src/gluonts/nursery/few_shot_prediction/src/meta/datasets/super.py +++ b/src/gluonts/nursery/few_shot_prediction/src/meta/datasets/super.py @@ -195,16 +195,20 @@ def test_dataloader(self) -> DataLoader[TripletBatch]: def get_log_batches(self, n_logging_samples: int) -> Tuple[TripletBatch]: its_train = cycle( - list([ - iter(dm.sampling_triplet_dataset("train")) - for dm in self.data_modules_train - ]) + list( + [ + iter(dm.sampling_triplet_dataset("train")) + for dm in self.data_modules_train + ] + ) ) its_val = cycle( - list([ - iter(dm.sequential_triplet_dataset("val")) - for dm in self.data_modules_val - ]) + list( + [ + iter(dm.sequential_triplet_dataset("val")) + for dm in self.data_modules_val + ] + ) ) def get_log_batch(its): diff --git a/src/gluonts/nursery/few_shot_prediction/src/meta/models/module.py b/src/gluonts/nursery/few_shot_prediction/src/meta/models/module.py index 799317d38d..c5f7b3ad7f 100644 --- a/src/gluonts/nursery/few_shot_prediction/src/meta/models/module.py +++ b/src/gluonts/nursery/few_shot_prediction/src/meta/models/module.py @@ -63,9 +63,12 @@ def __init__( self.val_quantile_width = nn.ModuleList( [quantile_width.clone() for _ in range(len(val_dataset_names))] ) - self.test_quantile_width = nn.ModuleList([ - quantile_width.clone() for _ in range(len(test_dataset_names)) - ]) + self.test_quantile_width = nn.ModuleList( + [ + quantile_width.clone() + for _ in range(len(test_dataset_names)) + ] + ) self.val_dataset_names = val_dataset_names ( diff --git a/src/gluonts/nursery/few_shot_prediction/src/scripts/data.py b/src/gluonts/nursery/few_shot_prediction/src/scripts/data.py index 4e5028cc54..0f0faf6589 100644 --- a/src/gluonts/nursery/few_shot_prediction/src/scripts/data.py +++ b/src/gluonts/nursery/few_shot_prediction/src/scripts/data.py @@ -139,13 +139,15 @@ def statistics(): ) n_total_normalized = n_total / sum(n_total) - datasets, lengths_normalized, n_total_normalized = zip(*( - sorted( - list(zip(datasets, lengths_normalized, n_total_normalized)), - key=lambda x: x[1], - reverse=True, + datasets, lengths_normalized, n_total_normalized = zip( + *( + sorted( + list(zip(datasets, lengths_normalized, n_total_normalized)), + key=lambda x: x[1], + reverse=True, + ) ) - )) + ) x = np.arange(len(datasets)) diff --git a/src/gluonts/nursery/few_shot_prediction/src/scripts/train.py b/src/gluonts/nursery/few_shot_prediction/src/scripts/train.py index 556c6a0e65..b250d44577 100644 --- a/src/gluonts/nursery/few_shot_prediction/src/scripts/train.py +++ b/src/gluonts/nursery/few_shot_prediction/src/scripts/train.py @@ -297,35 +297,39 @@ def main( # add callback that only works with attention attention_models = ["iwata", "cnn_iwata", "tcn"] if model_name in attention_models: - callbacks.extend([ - ForecastSupportSetAttentionPlotLoggerCallback( - log_batch_train, - quantiles=quantiles, - split="train", - every_n_epochs=log_plot_every_n_epochs, - ), - ForecastSupportSetAttentionPlotLoggerCallback( - log_batch_val, - quantiles=quantiles, - split="val", - every_n_epochs=log_plot_every_n_epochs, - ), - ]) + callbacks.extend( + [ + ForecastSupportSetAttentionPlotLoggerCallback( + log_batch_train, + quantiles=quantiles, + split="train", + every_n_epochs=log_plot_every_n_epochs, + ), + ForecastSupportSetAttentionPlotLoggerCallback( + log_batch_val, + quantiles=quantiles, + split="val", + every_n_epochs=log_plot_every_n_epochs, + ), + ] + ) else: - callbacks.extend([ - ForecastPlotLoggerCallback( - log_batch_val, - quantiles=quantiles, - split="val", - every_n_epochs=log_plot_every_n_epochs, - ), - ForecastPlotLoggerCallback( - log_batch_train, - quantiles=quantiles, - split="train", - every_n_epochs=log_plot_every_n_epochs, - ), - ]) + callbacks.extend( + [ + ForecastPlotLoggerCallback( + log_batch_val, + quantiles=quantiles, + split="val", + every_n_epochs=log_plot_every_n_epochs, + ), + ForecastPlotLoggerCallback( + log_batch_train, + quantiles=quantiles, + split="train", + every_n_epochs=log_plot_every_n_epochs, + ), + ] + ) # -------------------- train model --------------------------------------------------- trainer = pl.Trainer( diff --git a/src/gluonts/nursery/robust-mts-attack/multivariate/datasets/dataset.py b/src/gluonts/nursery/robust-mts-attack/multivariate/datasets/dataset.py index 92f615ef67..768fe06070 100644 --- a/src/gluonts/nursery/robust-mts-attack/multivariate/datasets/dataset.py +++ b/src/gluonts/nursery/robust-mts-attack/multivariate/datasets/dataset.py @@ -42,11 +42,13 @@ def extract_dataset(dataset_name: str): def pivot_dataset(dataset): ds_list = list(dataset) - return [{ - "item": "0", - "start": ds_list[0]["start"], - "target": np.vstack([d["target"] for d in ds_list]), - }] + return [ + { + "item": "0", + "start": ds_list[0]["start"], + "target": np.vstack([d["target"] for d in ds_list]), + } + ] class MultivariateDatasetInfo(NamedTuple): @@ -238,14 +240,16 @@ def taxi_30min(max_target_dim: int = None): ) -datasets = OrderedDict([ - ("solar", solar), - ("exchange_rate", exchange_rate), - ("electricity", electricity), - ("traffic", traffic), - ("wikipedia", wiki), - ("taxi_30min", taxi_30min), -]) +datasets = OrderedDict( + [ + ("solar", solar), + ("exchange_rate", exchange_rate), + ("electricity", electricity), + ("traffic", traffic), + ("wikipedia", wiki), + ("taxi_30min", taxi_30min), + ] +) if __name__ == "__main__": extract_dataset("electricity_nips") diff --git a/src/gluonts/nursery/robust-mts-attack/pts/dataset/repository/_m5.py b/src/gluonts/nursery/robust-mts-attack/pts/dataset/repository/_m5.py index a94bec02c2..74d2aa7b09 100644 --- a/src/gluonts/nursery/robust-mts-attack/pts/dataset/repository/_m5.py +++ b/src/gluonts/nursery/robust-mts-attack/pts/dataset/repository/_m5.py @@ -233,13 +233,15 @@ def get_sell_price(item_id, store_id): meta_file = dataset_path / "metadata.json" with open(meta_file, "w") as f: f.write( - json.dumps({ - "freq": pandas_freq, - "prediction_length": prediction_length, - "feat_static_cat": feat_static_cat, - "feat_dynamic_real": feat_dynamic_real, - "cardinality": len(train_ds), - }) + json.dumps( + { + "freq": pandas_freq, + "prediction_length": prediction_length, + "feat_static_cat": feat_static_cat, + "feat_dynamic_real": feat_dynamic_real, + "cardinality": len(train_ds), + } + ) ) # Build testing set diff --git a/src/gluonts/nursery/robust-mts-attack/pts/feature/holiday.py b/src/gluonts/nursery/robust-mts-attack/pts/feature/holiday.py index 253eb7db5b..aef6823c36 100644 --- a/src/gluonts/nursery/robust-mts-attack/pts/feature/holiday.py +++ b/src/gluonts/nursery/robust-mts-attack/pts/feature/holiday.py @@ -88,13 +88,17 @@ def __call__(self, dates): dates Pandas series with Datetimeindex timestamps. """ - return np.vstack([ - np.hstack([ - self.kernel_function((index - ref_date).days) - for index in dates - ]) - for ref_date in self.reference_dates - ]).sum(0, keepdims=True) + return np.vstack( + [ + np.hstack( + [ + self.kernel_function((index - ref_date).days) + for index in dates + ] + ) + for ref_date in self.reference_dates + ] + ).sum(0, keepdims=True) class CustomHolidayFeatureSet: @@ -166,12 +170,16 @@ def __call__(self, dates): dates Pandas series with Datetimeindex timestamps. """ - return np.vstack([ - np.hstack([ - self.kernel_function( - distance_to_holiday(custom_holiday)(index) + return np.vstack( + [ + np.hstack( + [ + self.kernel_function( + distance_to_holiday(custom_holiday)(index) + ) + for index in dates + ] ) - for index in dates - ]) - for custom_holiday in self.custom_holidays - ]) + for custom_holiday in self.custom_holidays + ] + ) diff --git a/src/gluonts/nursery/robust-mts-attack/pts/model/causal_deepar/causal_deepar_network.py b/src/gluonts/nursery/robust-mts-attack/pts/model/causal_deepar/causal_deepar_network.py index 7abe78f816..12fa525867 100644 --- a/src/gluonts/nursery/robust-mts-attack/pts/model/causal_deepar/causal_deepar_network.py +++ b/src/gluonts/nursery/robust-mts-attack/pts/model/causal_deepar/causal_deepar_network.py @@ -241,17 +241,21 @@ def unroll_encoder( # from (batch_size, sub_seq_len, *target_shape, num_lags) # to (batch_size, sub_seq_len, prod(target_shape) * num_lags) - input_lags = lags_scaled.reshape(( - -1, - subsequences_length, - len(self.lags_seq) * prod(self.target_shape), - )) - - input_control_lags = control_lags_scaled.reshape(( - -1, - subsequences_length, - len(self.lags_seq) * prod(self.target_shape), - )) + input_lags = lags_scaled.reshape( + ( + -1, + subsequences_length, + len(self.lags_seq) * prod(self.target_shape), + ) + ) + + input_control_lags = control_lags_scaled.reshape( + ( + -1, + subsequences_length, + len(self.lags_seq) * prod(self.target_shape), + ) + ) # (batch_size, sub_seq_len, input_dim) inputs = torch.cat( @@ -558,11 +562,13 @@ def sampling_decoder( samples = torch.cat(future_samples, dim=1) # (batch_size, num_samples, prediction_length, *target_shape) - return samples.reshape(( - (-1, self.num_parallel_samples) - + (self.prediction_length,) - + self.target_shape - )) + return samples.reshape( + ( + (-1, self.num_parallel_samples) + + (self.prediction_length,) + + self.target_shape + ) + ) # noinspection PyMethodOverriding,PyPep8Naming def forward( diff --git a/src/gluonts/nursery/robust-mts-attack/pts/model/deepar/deepar_network.py b/src/gluonts/nursery/robust-mts-attack/pts/model/deepar/deepar_network.py index 2f0314856c..8d66c91ec0 100644 --- a/src/gluonts/nursery/robust-mts-attack/pts/model/deepar/deepar_network.py +++ b/src/gluonts/nursery/robust-mts-attack/pts/model/deepar/deepar_network.py @@ -209,11 +209,13 @@ def unroll_encoder( # from (batch_size, sub_seq_len, *target_shape, num_lags) # to (batch_size, sub_seq_len, prod(target_shape) * num_lags) - input_lags = lags_scaled.reshape(( - -1, - subsequences_length, - len(self.lags_seq) * prod(self.target_shape), - )) + input_lags = lags_scaled.reshape( + ( + -1, + subsequences_length, + len(self.lags_seq) * prod(self.target_shape), + ) + ) # (batch_size, sub_seq_len, input_dim) inputs = torch.cat( @@ -438,11 +440,13 @@ def sampling_decoder( samples = torch.cat(future_samples, dim=1) # (batch_size, num_samples, prediction_length, *target_shape) - return samples.reshape(( - (-1, self.num_parallel_samples) - + (self.prediction_length,) - + self.target_shape - )) + return samples.reshape( + ( + (-1, self.num_parallel_samples) + + (self.prediction_length,) + + self.target_shape + ) + ) # noinspection PyMethodOverriding,PyPep8Naming def forward( diff --git a/src/gluonts/nursery/robust-mts-attack/pts/model/deepvar/deepvar_estimator.py b/src/gluonts/nursery/robust-mts-attack/pts/model/deepvar/deepvar_estimator.py index 5adb3c9364..4bea3ddf23 100644 --- a/src/gluonts/nursery/robust-mts-attack/pts/model/deepvar/deepvar_estimator.py +++ b/src/gluonts/nursery/robust-mts-attack/pts/model/deepvar/deepvar_estimator.py @@ -260,10 +260,12 @@ def create_instance_splitter(self, mode: str): target_dim=self.target_dim, ) if self.use_marginal_transformation - else RenameFields({ - f"past_{FieldName.TARGET}": f"past_{FieldName.TARGET}_cdf", - f"future_{FieldName.TARGET}": f"future_{FieldName.TARGET}_cdf", - }) + else RenameFields( + { + f"past_{FieldName.TARGET}": f"past_{FieldName.TARGET}_cdf", + f"future_{FieldName.TARGET}": f"future_{FieldName.TARGET}_cdf", + } + ) ) def create_training_network( diff --git a/src/gluonts/nursery/robust-mts-attack/pts/model/deepvar/deepvar_network.py b/src/gluonts/nursery/robust-mts-attack/pts/model/deepvar/deepvar_network.py index e035ccef03..64353b4adf 100644 --- a/src/gluonts/nursery/robust-mts-attack/pts/model/deepvar/deepvar_network.py +++ b/src/gluonts/nursery/robust-mts-attack/pts/model/deepvar/deepvar_network.py @@ -596,12 +596,14 @@ def repeat(tensor, dim=0): samples = torch.cat(future_samples, dim=1) locs = torch.cat(loc, dim=1) # (batch_size, num_samples, prediction_length, target_dim) - return samples.reshape(( - -1, - self.num_parallel_samples, - self.prediction_length, - self.target_dim, - )) # , locs.reshape( + return samples.reshape( + ( + -1, + self.num_parallel_samples, + self.prediction_length, + self.target_dim, + ) + ) # , locs.reshape( # -1, # self.num_parallel_samples, # self.prediction_length, diff --git a/src/gluonts/nursery/robust-mts-attack/pts/model/n_beats/n_beats_ensemble.py b/src/gluonts/nursery/robust-mts-attack/pts/model/n_beats/n_beats_ensemble.py index ab21ca12aa..ff9234273d 100644 --- a/src/gluonts/nursery/robust-mts-attack/pts/model/n_beats/n_beats_ensemble.py +++ b/src/gluonts/nursery/robust-mts-attack/pts/model/n_beats/n_beats_ensemble.py @@ -214,10 +214,12 @@ def __init__( self.freq = freq self.prediction_length = prediction_length - assert meta_loss_function is None or all([ - loss_function in VALID_LOSS_FUNCTIONS - for loss_function in meta_loss_function - ]), f"Each loss function has to be one of the following: {VALID_LOSS_FUNCTIONS}." + assert meta_loss_function is None or all( + [ + loss_function in VALID_LOSS_FUNCTIONS + for loss_function in meta_loss_function + ] + ), f"Each loss function has to be one of the following: {VALID_LOSS_FUNCTIONS}." assert meta_context_length is None or all( [context_length > 0 for context_length in meta_context_length] ), "The value of each `context_length` should be > 0" diff --git a/src/gluonts/nursery/robust-mts-attack/pts/model/n_beats/n_beats_estimator.py b/src/gluonts/nursery/robust-mts-attack/pts/model/n_beats/n_beats_estimator.py index 75d9b2adbd..c4ef409bc7 100644 --- a/src/gluonts/nursery/robust-mts-attack/pts/model/n_beats/n_beats_estimator.py +++ b/src/gluonts/nursery/robust-mts-attack/pts/model/n_beats/n_beats_estimator.py @@ -158,20 +158,22 @@ def _validate_nbeats_argument( # that can be digested by our model by only splitting the target in two, a # conditioning part and a to-predict part, for each training example. def create_transformation(self) -> Transformation: - return Chain([ - RemoveFields( - field_names=[ - FieldName.FEAT_STATIC_REAL, - FieldName.FEAT_DYNAMIC_REAL, - FieldName.FEAT_DYNAMIC_CAT, - ] - ), - AddObservedValuesIndicator( - target_field=FieldName.TARGET, - output_field=FieldName.OBSERVED_VALUES, - dtype=self.dtype, - ), - ]) + return Chain( + [ + RemoveFields( + field_names=[ + FieldName.FEAT_STATIC_REAL, + FieldName.FEAT_DYNAMIC_REAL, + FieldName.FEAT_DYNAMIC_CAT, + ] + ), + AddObservedValuesIndicator( + target_field=FieldName.TARGET, + output_field=FieldName.OBSERVED_VALUES, + dtype=self.dtype, + ), + ] + ) def create_instance_splitter(self, mode: str): assert mode in ["training", "validation", "test"] diff --git a/src/gluonts/nursery/robust-mts-attack/pts/model/tempflow/tempflow_estimator.py b/src/gluonts/nursery/robust-mts-attack/pts/model/tempflow/tempflow_estimator.py index 4049a36c8e..48c230e04c 100644 --- a/src/gluonts/nursery/robust-mts-attack/pts/model/tempflow/tempflow_estimator.py +++ b/src/gluonts/nursery/robust-mts-attack/pts/model/tempflow/tempflow_estimator.py @@ -135,39 +135,43 @@ def __init__( ) def create_transformation(self) -> Transformation: - return Chain([ - AsNumpyArray( - field=FieldName.TARGET, - expected_ndim=2, - ), - # maps the target to (1, T) - # if the target data is uni dimensional - ExpandDimArray( - field=FieldName.TARGET, - axis=None, - ), - AddObservedValuesIndicator( - target_field=FieldName.TARGET, - output_field=FieldName.OBSERVED_VALUES, - ), - AddTimeFeatures( - start_field=FieldName.START, - target_field=FieldName.TARGET, - output_field=FieldName.FEAT_TIME, - time_features=self.time_features, - pred_length=self.prediction_length, - ), - VstackFeatures( - output_field=FieldName.FEAT_TIME, - input_fields=[FieldName.FEAT_TIME], - ), - SetFieldIfNotPresent(field=FieldName.FEAT_STATIC_CAT, value=[0]), - TargetDimIndicator( - field_name="target_dimension_indicator", - target_field=FieldName.TARGET, - ), - AsNumpyArray(field=FieldName.FEAT_STATIC_CAT, expected_ndim=1), - ]) + return Chain( + [ + AsNumpyArray( + field=FieldName.TARGET, + expected_ndim=2, + ), + # maps the target to (1, T) + # if the target data is uni dimensional + ExpandDimArray( + field=FieldName.TARGET, + axis=None, + ), + AddObservedValuesIndicator( + target_field=FieldName.TARGET, + output_field=FieldName.OBSERVED_VALUES, + ), + AddTimeFeatures( + start_field=FieldName.START, + target_field=FieldName.TARGET, + output_field=FieldName.FEAT_TIME, + time_features=self.time_features, + pred_length=self.prediction_length, + ), + VstackFeatures( + output_field=FieldName.FEAT_TIME, + input_fields=[FieldName.FEAT_TIME], + ), + SetFieldIfNotPresent( + field=FieldName.FEAT_STATIC_CAT, value=[0] + ), + TargetDimIndicator( + field_name="target_dimension_indicator", + target_field=FieldName.TARGET, + ), + AsNumpyArray(field=FieldName.FEAT_STATIC_CAT, expected_ndim=1), + ] + ) def create_instance_splitter(self, mode: str): assert mode in ["training", "validation", "test"] @@ -191,10 +195,12 @@ def create_instance_splitter(self, mode: str): FieldName.OBSERVED_VALUES, ], ) + ( - RenameFields({ - f"past_{FieldName.TARGET}": f"past_{FieldName.TARGET}_cdf", - f"future_{FieldName.TARGET}": f"future_{FieldName.TARGET}_cdf", - }) + RenameFields( + { + f"past_{FieldName.TARGET}": f"past_{FieldName.TARGET}_cdf", + f"future_{FieldName.TARGET}": f"future_{FieldName.TARGET}_cdf", + } + ) ) def create_training_network( diff --git a/src/gluonts/nursery/robust-mts-attack/pts/model/tempflow/tempflow_network.py b/src/gluonts/nursery/robust-mts-attack/pts/model/tempflow/tempflow_network.py index b070f3f8f7..6fc6ae1fa2 100644 --- a/src/gluonts/nursery/robust-mts-attack/pts/model/tempflow/tempflow_network.py +++ b/src/gluonts/nursery/robust-mts-attack/pts/model/tempflow/tempflow_network.py @@ -550,12 +550,14 @@ def repeat(tensor, dim=0): samples = torch.cat(future_samples, dim=1) # (batch_size, num_samples, prediction_length, target_dim) - return samples.reshape(( - -1, - self.num_parallel_samples, - self.prediction_length, - self.target_dim, - )) + return samples.reshape( + ( + -1, + self.num_parallel_samples, + self.prediction_length, + self.target_dim, + ) + ) def forward( self, diff --git a/src/gluonts/nursery/robust-mts-attack/pts/model/tft/tft_estimator.py b/src/gluonts/nursery/robust-mts-attack/pts/model/tft/tft_estimator.py index d894ac7277..fdb47f7d87 100644 --- a/src/gluonts/nursery/robust-mts-attack/pts/model/tft/tft_estimator.py +++ b/src/gluonts/nursery/robust-mts-attack/pts/model/tft/tft_estimator.py @@ -129,10 +129,12 @@ def __init__( def create_transformation(self) -> Transformation: transforms = ( [AsNumpyArray(field=FieldName.TARGET, expected_ndim=1)] - + ([ - AsNumpyArray(field=name, expected_ndim=1) - for name in self.static_cardinalities.keys() - ]) + + ( + [ + AsNumpyArray(field=name, expected_ndim=1) + for name in self.static_cardinalities.keys() + ] + ) + [ AsNumpyArray(field=name, expected_ndim=1) for name in chain( @@ -166,30 +168,34 @@ def create_transformation(self) -> Transformation: ) if self.static_cardinalities: - transforms.extend([ - VstackFeatures( - output_field=FieldName.FEAT_STATIC_CAT, - input_fields=list(self.static_cardinalities.keys()), - h_stack=True, - ), - AsNumpyArray( - field=FieldName.FEAT_STATIC_CAT, - expected_ndim=1, - dtype=np.long, - ), - ]) + transforms.extend( + [ + VstackFeatures( + output_field=FieldName.FEAT_STATIC_CAT, + input_fields=list(self.static_cardinalities.keys()), + h_stack=True, + ), + AsNumpyArray( + field=FieldName.FEAT_STATIC_CAT, + expected_ndim=1, + dtype=np.long, + ), + ] + ) else: - transforms.extend([ - SetField( - output_field=FieldName.FEAT_STATIC_CAT, - value=[0], - ), - AsNumpyArray( - field=FieldName.FEAT_STATIC_CAT, - expected_ndim=1, - dtype=np.long, - ), - ]) + transforms.extend( + [ + SetField( + output_field=FieldName.FEAT_STATIC_CAT, + value=[0], + ), + AsNumpyArray( + field=FieldName.FEAT_STATIC_CAT, + expected_ndim=1, + dtype=np.long, + ), + ] + ) if self.static_feature_dims: transforms.append( @@ -200,44 +206,50 @@ def create_transformation(self) -> Transformation: ) ) else: - transforms.extend([ - SetField( - output_field=FieldName.FEAT_STATIC_REAL, - value=[0.0], - ), - AsNumpyArray( - field=FieldName.FEAT_STATIC_REAL, expected_ndim=1 - ), - ]) + transforms.extend( + [ + SetField( + output_field=FieldName.FEAT_STATIC_REAL, + value=[0.0], + ), + AsNumpyArray( + field=FieldName.FEAT_STATIC_REAL, expected_ndim=1 + ), + ] + ) if self.dynamic_cardinalities: - transforms.extend([ - VstackFeatures( - output_field=FieldName.FEAT_DYNAMIC_CAT, - input_fields=list(self.dynamic_cardinalities.keys()), - ), - AsNumpyArray( - field=FieldName.FEAT_DYNAMIC_CAT, - expected_ndim=2, - dtype=np.long, - ), - ]) + transforms.extend( + [ + VstackFeatures( + output_field=FieldName.FEAT_DYNAMIC_CAT, + input_fields=list(self.dynamic_cardinalities.keys()), + ), + AsNumpyArray( + field=FieldName.FEAT_DYNAMIC_CAT, + expected_ndim=2, + dtype=np.long, + ), + ] + ) else: - transforms.extend([ - SetField( - output_field=FieldName.FEAT_DYNAMIC_CAT, - value=[[0]], - ), - AsNumpyArray( - field=FieldName.FEAT_DYNAMIC_CAT, - expected_ndim=2, - dtype=np.long, - ), - BroadcastTo( - field=FieldName.FEAT_DYNAMIC_CAT, - ext_length=self.prediction_length, - ), - ]) + transforms.extend( + [ + SetField( + output_field=FieldName.FEAT_DYNAMIC_CAT, + value=[[0]], + ), + AsNumpyArray( + field=FieldName.FEAT_DYNAMIC_CAT, + expected_ndim=2, + dtype=np.long, + ), + BroadcastTo( + field=FieldName.FEAT_DYNAMIC_CAT, + ext_length=self.prediction_length, + ), + ] + ) input_fields = [FieldName.FEAT_TIME, FieldName.FEAT_AGE] if self.dynamic_feature_dims: @@ -250,30 +262,36 @@ def create_transformation(self) -> Transformation: ) if self.past_dynamic_cardinalities: - transforms.extend([ - VstackFeatures( - output_field=FieldName.PAST_FEAT_DYNAMIC + "_cat", - input_fields=list(self.past_dynamic_cardinalities.keys()), - ), - AsNumpyArray( - field=FieldName.PAST_FEAT_DYNAMIC + "_cat", - expected_ndim=2, - dtype=np.long, - ), - ]) + transforms.extend( + [ + VstackFeatures( + output_field=FieldName.PAST_FEAT_DYNAMIC + "_cat", + input_fields=list( + self.past_dynamic_cardinalities.keys() + ), + ), + AsNumpyArray( + field=FieldName.PAST_FEAT_DYNAMIC + "_cat", + expected_ndim=2, + dtype=np.long, + ), + ] + ) else: - transforms.extend([ - SetField( - output_field=FieldName.PAST_FEAT_DYNAMIC + "_cat", - value=[[0]], - ), - AsNumpyArray( - field=FieldName.PAST_FEAT_DYNAMIC + "_cat", - expected_ndim=2, - dtype=np.long, - ), - BroadcastTo(field=FieldName.PAST_FEAT_DYNAMIC + "_cat"), - ]) + transforms.extend( + [ + SetField( + output_field=FieldName.PAST_FEAT_DYNAMIC + "_cat", + value=[[0]], + ), + AsNumpyArray( + field=FieldName.PAST_FEAT_DYNAMIC + "_cat", + expected_ndim=2, + dtype=np.long, + ), + BroadcastTo(field=FieldName.PAST_FEAT_DYNAMIC + "_cat"), + ] + ) if self.past_dynamic_feature_dims: transforms.append( @@ -283,16 +301,18 @@ def create_transformation(self) -> Transformation: ) ) else: - transforms.extend([ - SetField( - output_field=FieldName.PAST_FEAT_DYNAMIC_REAL, - value=[[0.0]], - ), - AsNumpyArray( - field=FieldName.PAST_FEAT_DYNAMIC_REAL, expected_ndim=2 - ), - BroadcastTo(field=FieldName.PAST_FEAT_DYNAMIC_REAL), - ]) + transforms.extend( + [ + SetField( + output_field=FieldName.PAST_FEAT_DYNAMIC_REAL, + value=[[0.0]], + ), + AsNumpyArray( + field=FieldName.PAST_FEAT_DYNAMIC_REAL, expected_ndim=2 + ), + BroadcastTo(field=FieldName.PAST_FEAT_DYNAMIC_REAL), + ] + ) return Chain(transforms) diff --git a/src/gluonts/nursery/robust-mts-attack/pts/model/tft/tft_modules.py b/src/gluonts/nursery/robust-mts-attack/pts/model/tft/tft_modules.py index 437f2c6ca3..ef6698f505 100644 --- a/src/gluonts/nursery/robust-mts-attack/pts/model/tft/tft_modules.py +++ b/src/gluonts/nursery/robust-mts-attack/pts/model/tft/tft_modules.py @@ -38,12 +38,14 @@ def __init__( self.feature_slices = feature_dims self.feature_dims = feature_dims - self._projector = nn.ModuleList([ - nn.Linear(in_features=in_feature, out_features=out_features) - for in_feature, out_features in zip( - self.feature_dims, embedding_dims - ) - ]) + self._projector = nn.ModuleList( + [ + nn.Linear(in_features=in_feature, out_features=out_features) + for in_feature, out_features in zip( + self.feature_dims, embedding_dims + ) + ] + ) def forward(self, features: torch.Tensor) -> List[torch.Tensor]: if self.__num_features > 1: @@ -158,10 +160,12 @@ def __init__( dropout=dropout, ) - self.variable_network = nn.ModuleList([ - GatedResidualNetwork(d_hidden=d_hidden, dropout=dropout) - for _ in range(n_vars) - ]) + self.variable_network = nn.ModuleList( + [ + GatedResidualNetwork(d_hidden=d_hidden, dropout=dropout) + for _ in range(n_vars) + ] + ) def forward( self, diff --git a/src/gluonts/nursery/robust-mts-attack/pts/model/time_grad/epsilon_theta.py b/src/gluonts/nursery/robust-mts-attack/pts/model/time_grad/epsilon_theta.py index bc7462641f..a66b5b003c 100644 --- a/src/gluonts/nursery/robust-mts-attack/pts/model/time_grad/epsilon_theta.py +++ b/src/gluonts/nursery/robust-mts-attack/pts/model/time_grad/epsilon_theta.py @@ -120,14 +120,16 @@ def __init__( self.cond_upsampler = CondUpsampler( target_dim=target_dim, cond_length=cond_length ) - self.residual_layers = nn.ModuleList([ - ResidualBlock( - residual_channels=residual_channels, - dilation=2 ** (i % dilation_cycle_length), - hidden_size=residual_hidden, - ) - for i in range(residual_layers) - ]) + self.residual_layers = nn.ModuleList( + [ + ResidualBlock( + residual_channels=residual_channels, + dilation=2 ** (i % dilation_cycle_length), + hidden_size=residual_hidden, + ) + for i in range(residual_layers) + ] + ) self.skip_projection = nn.Conv1d( residual_channels, residual_channels, 3 ) diff --git a/src/gluonts/nursery/robust-mts-attack/pts/model/time_grad/time_grad_estimator.py b/src/gluonts/nursery/robust-mts-attack/pts/model/time_grad/time_grad_estimator.py index cb60cd4e39..8ccb09d9c9 100644 --- a/src/gluonts/nursery/robust-mts-attack/pts/model/time_grad/time_grad_estimator.py +++ b/src/gluonts/nursery/robust-mts-attack/pts/model/time_grad/time_grad_estimator.py @@ -136,39 +136,43 @@ def __init__( ) def create_transformation(self) -> Transformation: - return Chain([ - AsNumpyArray( - field=FieldName.TARGET, - expected_ndim=2, - ), - # maps the target to (1, T) - # if the target data is uni dimensional - ExpandDimArray( - field=FieldName.TARGET, - axis=None, - ), - AddObservedValuesIndicator( - target_field=FieldName.TARGET, - output_field=FieldName.OBSERVED_VALUES, - ), - AddTimeFeatures( - start_field=FieldName.START, - target_field=FieldName.TARGET, - output_field=FieldName.FEAT_TIME, - time_features=self.time_features, - pred_length=self.prediction_length, - ), - VstackFeatures( - output_field=FieldName.FEAT_TIME, - input_fields=[FieldName.FEAT_TIME], - ), - SetFieldIfNotPresent(field=FieldName.FEAT_STATIC_CAT, value=[0]), - TargetDimIndicator( - field_name="target_dimension_indicator", - target_field=FieldName.TARGET, - ), - AsNumpyArray(field=FieldName.FEAT_STATIC_CAT, expected_ndim=1), - ]) + return Chain( + [ + AsNumpyArray( + field=FieldName.TARGET, + expected_ndim=2, + ), + # maps the target to (1, T) + # if the target data is uni dimensional + ExpandDimArray( + field=FieldName.TARGET, + axis=None, + ), + AddObservedValuesIndicator( + target_field=FieldName.TARGET, + output_field=FieldName.OBSERVED_VALUES, + ), + AddTimeFeatures( + start_field=FieldName.START, + target_field=FieldName.TARGET, + output_field=FieldName.FEAT_TIME, + time_features=self.time_features, + pred_length=self.prediction_length, + ), + VstackFeatures( + output_field=FieldName.FEAT_TIME, + input_fields=[FieldName.FEAT_TIME], + ), + SetFieldIfNotPresent( + field=FieldName.FEAT_STATIC_CAT, value=[0] + ), + TargetDimIndicator( + field_name="target_dimension_indicator", + target_field=FieldName.TARGET, + ), + AsNumpyArray(field=FieldName.FEAT_STATIC_CAT, expected_ndim=1), + ] + ) def create_instance_splitter(self, mode: str): assert mode in ["training", "validation", "test"] @@ -192,10 +196,12 @@ def create_instance_splitter(self, mode: str): FieldName.OBSERVED_VALUES, ], ) + ( - RenameFields({ - f"past_{FieldName.TARGET}": f"past_{FieldName.TARGET}_cdf", - f"future_{FieldName.TARGET}": f"future_{FieldName.TARGET}_cdf", - }) + RenameFields( + { + f"past_{FieldName.TARGET}": f"past_{FieldName.TARGET}_cdf", + f"future_{FieldName.TARGET}": f"future_{FieldName.TARGET}_cdf", + } + ) ) def create_training_network( diff --git a/src/gluonts/nursery/robust-mts-attack/pts/model/time_grad/time_grad_network.py b/src/gluonts/nursery/robust-mts-attack/pts/model/time_grad/time_grad_network.py index 50c6d68973..ecd657bf2f 100644 --- a/src/gluonts/nursery/robust-mts-attack/pts/model/time_grad/time_grad_network.py +++ b/src/gluonts/nursery/robust-mts-attack/pts/model/time_grad/time_grad_network.py @@ -565,12 +565,14 @@ def repeat(tensor, dim=0): samples = torch.cat(future_samples, dim=1) # (batch_size, num_samples, prediction_length, target_dim) - return samples.reshape(( - -1, - self.num_parallel_samples, - self.prediction_length, - self.target_dim, - )) + return samples.reshape( + ( + -1, + self.num_parallel_samples, + self.prediction_length, + self.target_dim, + ) + ) def forward( self, diff --git a/src/gluonts/nursery/robust-mts-attack/pts/model/transformer/transformer_network.py b/src/gluonts/nursery/robust-mts-attack/pts/model/transformer/transformer_network.py index 85326df0e2..3ac17fef6e 100644 --- a/src/gluonts/nursery/robust-mts-attack/pts/model/transformer/transformer_network.py +++ b/src/gluonts/nursery/robust-mts-attack/pts/model/transformer/transformer_network.py @@ -224,11 +224,13 @@ def create_network_input( # from (batch_size, sub_seq_len, *target_shape, num_lags) # to (batch_size, sub_seq_len, prod(target_shape) * num_lags) - input_lags = lags_scaled.reshape(( - -1, - subsequences_length, - len(self.lags_seq) * prod(self.target_shape), - )) + input_lags = lags_scaled.reshape( + ( + -1, + subsequences_length, + len(self.lags_seq) * prod(self.target_shape), + ) + ) # (batch_size, sub_seq_len, input_dim) inputs = torch.cat( @@ -423,11 +425,13 @@ def sampling_decoder( samples = torch.cat(future_samples, dim=1) # (batch_size, num_samples, *target_shape, prediction_length) - return samples.reshape(( - (-1, self.num_parallel_samples) - + self.target_shape - + (self.prediction_length,) - )) + return samples.reshape( + ( + (-1, self.num_parallel_samples) + + self.target_shape + + (self.prediction_length,) + ) + ) def forward( self, diff --git a/src/gluonts/nursery/robust-mts-attack/pts/model/transformer_tempflow/transformer_tempflow_estimator.py b/src/gluonts/nursery/robust-mts-attack/pts/model/transformer_tempflow/transformer_tempflow_estimator.py index b875df3c58..26664034b2 100644 --- a/src/gluonts/nursery/robust-mts-attack/pts/model/transformer_tempflow/transformer_tempflow_estimator.py +++ b/src/gluonts/nursery/robust-mts-attack/pts/model/transformer_tempflow/transformer_tempflow_estimator.py @@ -146,45 +146,49 @@ def create_transformation(self) -> Transformation: if not self.use_feat_dynamic_real: remove_field_names.append(FieldName.FEAT_DYNAMIC_REAL) - return Chain([ - RemoveFields(field_names=remove_field_names), - AsNumpyArray( - field=FieldName.TARGET, - expected_ndim=2, - ), - # maps the target to (1, T) - # if the target data is uni dimensional - ExpandDimArray( - field=FieldName.TARGET, - axis=None, - ), - AddObservedValuesIndicator( - target_field=FieldName.TARGET, - output_field=FieldName.OBSERVED_VALUES, - ), - AddTimeFeatures( - start_field=FieldName.START, - target_field=FieldName.TARGET, - output_field=FieldName.FEAT_TIME, - time_features=self.time_features, - pred_length=self.prediction_length, - ), - VstackFeatures( - output_field=FieldName.FEAT_TIME, - input_fields=[FieldName.FEAT_TIME] - + ( - [FieldName.FEAT_DYNAMIC_REAL] - if self.use_feat_dynamic_real - else [] + return Chain( + [ + RemoveFields(field_names=remove_field_names), + AsNumpyArray( + field=FieldName.TARGET, + expected_ndim=2, ), - ), - SetFieldIfNotPresent(field=FieldName.FEAT_STATIC_CAT, value=[0]), - TargetDimIndicator( - field_name="target_dimension_indicator", - target_field=FieldName.TARGET, - ), - AsNumpyArray(field=FieldName.FEAT_STATIC_CAT, expected_ndim=1), - ]) + # maps the target to (1, T) + # if the target data is uni dimensional + ExpandDimArray( + field=FieldName.TARGET, + axis=None, + ), + AddObservedValuesIndicator( + target_field=FieldName.TARGET, + output_field=FieldName.OBSERVED_VALUES, + ), + AddTimeFeatures( + start_field=FieldName.START, + target_field=FieldName.TARGET, + output_field=FieldName.FEAT_TIME, + time_features=self.time_features, + pred_length=self.prediction_length, + ), + VstackFeatures( + output_field=FieldName.FEAT_TIME, + input_fields=[FieldName.FEAT_TIME] + + ( + [FieldName.FEAT_DYNAMIC_REAL] + if self.use_feat_dynamic_real + else [] + ), + ), + SetFieldIfNotPresent( + field=FieldName.FEAT_STATIC_CAT, value=[0] + ), + TargetDimIndicator( + field_name="target_dimension_indicator", + target_field=FieldName.TARGET, + ), + AsNumpyArray(field=FieldName.FEAT_STATIC_CAT, expected_ndim=1), + ] + ) def create_instance_splitter(self, mode: str): assert mode in ["training", "validation", "test"] @@ -208,10 +212,12 @@ def create_instance_splitter(self, mode: str): FieldName.OBSERVED_VALUES, ], ) + ( - RenameFields({ - f"past_{FieldName.TARGET}": f"past_{FieldName.TARGET}_cdf", - f"future_{FieldName.TARGET}": f"future_{FieldName.TARGET}_cdf", - }) + RenameFields( + { + f"past_{FieldName.TARGET}": f"past_{FieldName.TARGET}_cdf", + f"future_{FieldName.TARGET}": f"future_{FieldName.TARGET}_cdf", + } + ) ) def create_training_network( diff --git a/src/gluonts/nursery/robust-mts-attack/pts/model/transformer_tempflow/transformer_tempflow_network.py b/src/gluonts/nursery/robust-mts-attack/pts/model/transformer_tempflow/transformer_tempflow_network.py index ffd78d9a77..55d60d1537 100644 --- a/src/gluonts/nursery/robust-mts-attack/pts/model/transformer_tempflow/transformer_tempflow_network.py +++ b/src/gluonts/nursery/robust-mts-attack/pts/model/transformer_tempflow/transformer_tempflow_network.py @@ -552,12 +552,14 @@ def repeat(tensor, dim=0): samples = torch.cat(future_samples, dim=1) # (batch_size, num_samples, prediction_length, target_dim) - return samples.reshape(( - -1, - self.num_parallel_samples, - self.prediction_length, - self.target_dim, - )) + return samples.reshape( + ( + -1, + self.num_parallel_samples, + self.prediction_length, + self.target_dim, + ) + ) def forward( self, diff --git a/src/gluonts/nursery/robust-mts-attack/pts/modules/feature.py b/src/gluonts/nursery/robust-mts-attack/pts/modules/feature.py index 98c046dcc8..e8f61a25cc 100644 --- a/src/gluonts/nursery/robust-mts-attack/pts/modules/feature.py +++ b/src/gluonts/nursery/robust-mts-attack/pts/modules/feature.py @@ -31,10 +31,12 @@ def create_embedding(c: int, d: int) -> nn.Embedding: embedding = nn.Embedding(c, d) return embedding - self.__embedders = nn.ModuleList([ - create_embedding(c, d) - for c, d in zip(cardinalities, embedding_dims) - ]) + self.__embedders = nn.ModuleList( + [ + create_embedding(c, d) + for c, d in zip(cardinalities, embedding_dims) + ] + ) def forward(self, features: torch.Tensor) -> torch.Tensor: if self.__num_features > 1: diff --git a/src/gluonts/nursery/robust-mts-attack/read_pickle.py b/src/gluonts/nursery/robust-mts-attack/read_pickle.py index 038ddbd3a0..e43284f8b3 100644 --- a/src/gluonts/nursery/robust-mts-attack/read_pickle.py +++ b/src/gluonts/nursery/robust-mts-attack/read_pickle.py @@ -53,10 +53,12 @@ def create_table(path): "+-", np.asarray(result.mse[key]).std() * c, ) - mse.append(( - np.asarray(result.mse[key]).mean(), - np.asarray(result.mse[key]).std() * c, - )) + mse.append( + ( + np.asarray(result.mse[key]).mean(), + np.asarray(result.mse[key]).std() * c, + ) + ) print("mape loss:") for key in result.mape.keys(): @@ -66,10 +68,12 @@ def create_table(path): "+-", np.asarray(result.mape[key]).std() * c, ) - mape.append(( - np.asarray(result.mape[key]).mean(), - np.asarray(result.mape[key]).std() * c, - )) + mape.append( + ( + np.asarray(result.mape[key]).mean(), + np.asarray(result.mape[key]).std() * c, + ) + ) print("wQL:") for key in result.ql.keys(): @@ -79,10 +83,12 @@ def create_table(path): "+-", np.asarray(result.ql[key]).std() * c, ) - wql.append(( - np.asarray(result.ql[key]).mean(), - np.asarray(result.ql[key]).std() * c, - )) + wql.append( + ( + np.asarray(result.ql[key]).mean(), + np.asarray(result.ql[key]).std() * c, + ) + ) with open("table_" + types + ".txt", "w") as f: for i in range(len(mse)): diff --git a/src/gluonts/nursery/robust-mts-attack/utils.py b/src/gluonts/nursery/robust-mts-attack/utils.py index c40e2fd6a2..75a36ec169 100644 --- a/src/gluonts/nursery/robust-mts-attack/utils.py +++ b/src/gluonts/nursery/robust-mts-attack/utils.py @@ -235,10 +235,12 @@ def calc_loss( target_items, quantiles=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], ): - testset_size = sum([ - attack_data[i].true_future_target.shape[0] - for i in range(len(attack_data)) - ]) + testset_size = sum( + [ + attack_data[i].true_future_target.shape[0] + for i in range(len(attack_data)) + ] + ) mse = { key: np.zeros((testset_size, len(attack_idx), len(target_items))) for key in forecasts.keys() diff --git a/src/gluonts/nursery/san/_estimator.py b/src/gluonts/nursery/san/_estimator.py index f72edabb97..3f50433d54 100644 --- a/src/gluonts/nursery/san/_estimator.py +++ b/src/gluonts/nursery/san/_estimator.py @@ -129,19 +129,21 @@ def create_transformation(self) -> Transformation: ) ) else: - transforms.extend([ - SetField( - output_field=FieldName.FEAT_DYNAMIC_REAL, - value=[[]] - * (self.context_length + self.prediction_length), - ), - AsNumpyArray( - field=FieldName.FEAT_DYNAMIC_REAL, - expected_ndim=2, - ), - # SwapAxes(input_fields= - # [FieldName.FEAT_DYNAMIC_REAL], axes=(0,1)), - ]) + transforms.extend( + [ + SetField( + output_field=FieldName.FEAT_DYNAMIC_REAL, + value=[[]] + * (self.context_length + self.prediction_length), + ), + AsNumpyArray( + field=FieldName.FEAT_DYNAMIC_REAL, + expected_ndim=2, + ), + # SwapAxes(input_fields= + # [FieldName.FEAT_DYNAMIC_REAL], axes=(0,1)), + ] + ) if self.use_feat_dynamic_cat: transforms.append( AsNumpyArray( @@ -153,24 +155,26 @@ def create_transformation(self) -> Transformation: # Manually set dummy dynamic categorical features and split by time # Unknown issue in dataloader if leave splitting to # InstanceSplitter - transforms.extend([ - SetField( - output_field="past_" + FieldName.FEAT_DYNAMIC_CAT, - value=[[]] * self.context_length, - ), - AsNumpyArray( - field="past_" + FieldName.FEAT_DYNAMIC_CAT, - expected_ndim=2, - ), - SetField( - output_field="future_" + FieldName.FEAT_DYNAMIC_CAT, - value=[[]] * self.prediction_length, - ), - AsNumpyArray( - field="future_" + FieldName.FEAT_DYNAMIC_CAT, - expected_ndim=2, - ), - ]) + transforms.extend( + [ + SetField( + output_field="past_" + FieldName.FEAT_DYNAMIC_CAT, + value=[[]] * self.context_length, + ), + AsNumpyArray( + field="past_" + FieldName.FEAT_DYNAMIC_CAT, + expected_ndim=2, + ), + SetField( + output_field="future_" + FieldName.FEAT_DYNAMIC_CAT, + value=[[]] * self.prediction_length, + ), + AsNumpyArray( + field="future_" + FieldName.FEAT_DYNAMIC_CAT, + expected_ndim=2, + ), + ] + ) if self.use_feat_static_real: transforms.append( AsNumpyArray( @@ -179,16 +183,18 @@ def create_transformation(self) -> Transformation: ) ) else: - transforms.extend([ - SetField( - output_field=FieldName.FEAT_STATIC_REAL, - value=[], - ), - AsNumpyArray( - field=FieldName.FEAT_STATIC_REAL, - expected_ndim=1, - ), - ]) + transforms.extend( + [ + SetField( + output_field=FieldName.FEAT_STATIC_REAL, + value=[], + ), + AsNumpyArray( + field=FieldName.FEAT_STATIC_REAL, + expected_ndim=1, + ), + ] + ) if self.use_feat_static_cat: transforms.append( AsNumpyArray( @@ -197,35 +203,37 @@ def create_transformation(self) -> Transformation: ) ) - transforms.extend([ - AsNumpyArray(field=FieldName.TARGET, expected_ndim=1), - AddObservedValuesIndicator( - target_field=FieldName.TARGET, - output_field=FieldName.OBSERVED_VALUES, - ), - AddTimeFeatures( - start_field=FieldName.START, - target_field=FieldName.TARGET, - output_field=FieldName.FEAT_TIME, - time_features=self.time_features, - pred_length=self.prediction_length, - ), - AddAgeFeature( - target_field=FieldName.TARGET, - output_field=FieldName.FEAT_AGE, - pred_length=self.prediction_length, - log_scale=True, - ), - VstackFeatures( - output_field=FieldName.FEAT_DYNAMIC_REAL, - input_fields=[FieldName.FEAT_TIME, FieldName.FEAT_AGE] - + ( - [FieldName.FEAT_DYNAMIC_REAL] - if self.use_feat_dynamic_real - else [] + transforms.extend( + [ + AsNumpyArray(field=FieldName.TARGET, expected_ndim=1), + AddObservedValuesIndicator( + target_field=FieldName.TARGET, + output_field=FieldName.OBSERVED_VALUES, ), - ), - ]) + AddTimeFeatures( + start_field=FieldName.START, + target_field=FieldName.TARGET, + output_field=FieldName.FEAT_TIME, + time_features=self.time_features, + pred_length=self.prediction_length, + ), + AddAgeFeature( + target_field=FieldName.TARGET, + output_field=FieldName.FEAT_AGE, + pred_length=self.prediction_length, + log_scale=True, + ), + VstackFeatures( + output_field=FieldName.FEAT_DYNAMIC_REAL, + input_fields=[FieldName.FEAT_TIME, FieldName.FEAT_AGE] + + ( + [FieldName.FEAT_DYNAMIC_REAL] + if self.use_feat_dynamic_real + else [] + ), + ), + ] + ) return Chain(transforms) def _create_instance_splitter(self, mode: str): diff --git a/src/gluonts/nursery/temporal_hierarchical_forecasting/model/cop_deepar/_estimator.py b/src/gluonts/nursery/temporal_hierarchical_forecasting/model/cop_deepar/_estimator.py index 48769c0f71..8c8b297fa6 100644 --- a/src/gluonts/nursery/temporal_hierarchical_forecasting/model/cop_deepar/_estimator.py +++ b/src/gluonts/nursery/temporal_hierarchical_forecasting/model/cop_deepar/_estimator.py @@ -110,10 +110,12 @@ def _get_time_features_agg_level( ) # shape: (T, num_features) - full_time_feat = np.array([ - feat_map(full_date_range) - for feat_map in time_features_from_frequency_str(freq) - ]).T + full_time_feat = np.array( + [ + feat_map(full_date_range) + for feat_map in time_features_from_frequency_str(freq) + ] + ).T age_feature = np.log10( 2.0 + np.arange(num_periods, dtype=agg_estimator.dtype) diff --git a/src/gluonts/nursery/temporal_hierarchical_forecasting/model/cop_deepar/_network.py b/src/gluonts/nursery/temporal_hierarchical_forecasting/model/cop_deepar/_network.py index 8acbc37f7e..72b8329e25 100644 --- a/src/gluonts/nursery/temporal_hierarchical_forecasting/model/cop_deepar/_network.py +++ b/src/gluonts/nursery/temporal_hierarchical_forecasting/model/cop_deepar/_network.py @@ -226,10 +226,12 @@ def get_target_related_feat_at_agg_level( mx.nd.zeros_like(future_observed_values_agg), ) - target_related_feat_agg.update({ - "future_target": future_target_agg, - "future_observed_values": future_observed_values_agg, - }) + target_related_feat_agg.update( + { + "future_target": future_target_agg, + "future_observed_values": future_observed_values_agg, + } + ) return target_related_feat_agg @@ -421,12 +423,14 @@ def hybrid_forward( ) // window_size embeddings_at_all_levels_ls.append( - rnn_outputs.reshape(( - rnn_outputs.shape[0], - num_windows, - -1, - rnn_outputs.shape[-1], - )) + rnn_outputs.reshape( + ( + rnn_outputs.shape[0], + num_windows, + -1, + rnn_outputs.shape[-1], + ) + ) ) target_at_all_levels_ls.append( @@ -831,11 +835,13 @@ def hybrid_forward( ) reconciled_samples_at_bottom_level = ( - reconciled_samples_at_bottom_level.reshape(( - reconciled_samples_at_bottom_level.shape[0], - reconciled_samples_at_bottom_level.shape[1], - -1, - )) + reconciled_samples_at_bottom_level.reshape( + ( + reconciled_samples_at_bottom_level.shape[0], + reconciled_samples_at_bottom_level.shape[1], + -1, + ) + ) ) return reconciled_samples_at_bottom_level diff --git a/src/gluonts/nursery/temporal_hierarchical_forecasting/utils/utils.py b/src/gluonts/nursery/temporal_hierarchical_forecasting/utils/utils.py index 4a8e02229b..289f590e30 100644 --- a/src/gluonts/nursery/temporal_hierarchical_forecasting/utils/utils.py +++ b/src/gluonts/nursery/temporal_hierarchical_forecasting/utils/utils.py @@ -355,10 +355,12 @@ def mapping_matrix_at_level(level: int): M[:, start_ix:end_ix] = M[:, start_ix:end_ix] / row_sum[None, :] return M - mapping_matrices = np.array([ - mapping_matrix_at_level(level=level) - for level in range(len(cum_num_nodes_per_level)) - ]) + mapping_matrices = np.array( + [ + mapping_matrix_at_level(level=level) + for level in range(len(cum_num_nodes_per_level)) + ] + ) mean_mapping_matrix = np.mean(mapping_matrices, axis=0) reconciliation_mat = np.matmul(S, mean_mapping_matrix) diff --git a/src/gluonts/nursery/tsbench/src/cli/utils/config.py b/src/gluonts/nursery/tsbench/src/cli/utils/config.py index 4b5e9b6400..41425e22e7 100644 --- a/src/gluonts/nursery/tsbench/src/cli/utils/config.py +++ b/src/gluonts/nursery/tsbench/src/cli/utils/config.py @@ -74,10 +74,12 @@ def explode_key_values( """ all_combinations = { primary: ( - itertools.product(*[ - [(option["key"], value) for value in option["values"]] - for option in choices - ]) + itertools.product( + *[ + [(option["key"], value) for value in option["values"]] + for option in choices + ] + ) if choices else [] ) @@ -93,9 +95,12 @@ def explode_key_values( primary_config = {primary_key: primary} for key, value in item: if isinstance(key, (list, tuple)): - primary_config.update({ - process_key(primary, k): v for k, v in zip(key, value) - }) + primary_config.update( + { + process_key(primary, k): v + for k, v in zip(key, value) + } + ) else: primary_config[process_key(primary, key)] = value configs.append(primary_config) @@ -132,10 +137,12 @@ def process_key(model: str, key: str) -> str: for seed in seeds: for dataset in datasets: for model_config in configs: - all_configurations.append({ - "seed": seed, - "dataset": dataset, - **model_config, - }) + all_configurations.append( + { + "seed": seed, + "dataset": dataset, + **model_config, + } + ) return all_configurations diff --git a/src/gluonts/nursery/tsbench/src/tsbench/config/dataset/datasets.py b/src/gluonts/nursery/tsbench/src/tsbench/config/dataset/datasets.py index 328cdb42c1..9283e7c600 100644 --- a/src/gluonts/nursery/tsbench/src/tsbench/config/dataset/datasets.py +++ b/src/gluonts/nursery/tsbench/src/tsbench/config/dataset/datasets.py @@ -1026,14 +1026,16 @@ def _extract_data( series = [] for i, store_data in data.groupby("Store"): sorted_data = store_data.sort_values("Date") - series.append({ - "item_id": int(i) - 1, - "start": sorted_data.Date.min(), - "target": sorted_data.Sales.to_list(), - "feat_static_cat": [ - int(i) - 1, - ], - }) + series.append( + { + "item_id": int(i) - 1, + "start": sorted_data.Date.min(), + "target": sorted_data.Sales.to_list(), + "feat_static_cat": [ + int(i) - 1, + ], + } + ) return metadata, series @@ -1092,18 +1094,20 @@ def _extract_data( sorted_data.unit_sales.to_numpy(), index=pd.DatetimeIndex(sorted_data.date), ) - series.append({ - "item_id": i, - "start": sorted_data.date.min(), - "target": sales.resample("D") - .first() - .fillna(value=0) - .to_list(), - "feat_static_cat": [ - int(store_id) - 1, - int(item_id), - ], - }) + series.append( + { + "item_id": i, + "start": sorted_data.date.min(), + "target": sales.resample("D") + .first() + .fillna(value=0) + .to_list(), + "feat_static_cat": [ + int(store_id) - 1, + int(item_id), + ], + } + ) return metadata, series @@ -1159,15 +1163,17 @@ def _extract_data( ): department_id = np.where(department_ids == department)[0][0] sorted_data = group_data.sort_values("Date") - series.append({ - "item_id": i, - "start": sorted_data.Date.min(), - "target": sorted_data.Weekly_Sales.to_list(), - "feat_static_cat": [ - int(store_id) - 1, - int(department_id), - ], - }) + series.append( + { + "item_id": i, + "start": sorted_data.Date.min(), + "target": sorted_data.Weekly_Sales.to_list(), + "feat_static_cat": [ + int(store_id) - 1, + int(department_id), + ], + } + ) return metadata, series @@ -1221,16 +1227,18 @@ def _extract_data( sorted_data.visitors.to_numpy(), index=pd.DatetimeIndex(sorted_data.visit_date), ) - series.append({ - "item_id": i, - "start": sorted_data.visit_date.min(), - "target": visitors.resample("D") - .first() - .fillna(value=0) - .to_list(), - "feat_static_cat": [ - int(store_id), - ], - }) + series.append( + { + "item_id": i, + "start": sorted_data.visit_date.min(), + "target": visitors.resample("D") + .first() + .fillna(value=0) + .to_list(), + "feat_static_cat": [ + int(store_id), + ], + } + ) return metadata, series diff --git a/src/gluonts/nursery/tsbench/src/tsbench/evaluations/aws/analytics.py b/src/gluonts/nursery/tsbench/src/tsbench/evaluations/aws/analytics.py index 2a5203a099..f65674ae55 100644 --- a/src/gluonts/nursery/tsbench/src/tsbench/evaluations/aws/analytics.py +++ b/src/gluonts/nursery/tsbench/src/tsbench/evaluations/aws/analytics.py @@ -373,11 +373,13 @@ def _fetch_training_jobs( "MaxResults": 100, "Resource": "TrainingJob", "SearchExpression": { - "Filters": [{ - "Name": "Tags.Experiment", - "Operator": "Equals", - "Value": experiment, - }], + "Filters": [ + { + "Name": "Tags.Experiment", + "Operator": "Equals", + "Value": experiment, + } + ], }, } diff --git a/src/gluonts/nursery/tsbench/src/tsbench/evaluations/tracking/_info.py b/src/gluonts/nursery/tsbench/src/tsbench/evaluations/tracking/_info.py index 31099cb938..662817ea4d 100644 --- a/src/gluonts/nursery/tsbench/src/tsbench/evaluations/tracking/_info.py +++ b/src/gluonts/nursery/tsbench/src/tsbench/evaluations/tracking/_info.py @@ -142,24 +142,34 @@ def extract_job_infos( ] # And average the performance - averaged_performance = Performance(**{ - metric: Metric( - np.mean([getattr(p, metric).mean for p in performances]), - np.std([getattr(p, metric).mean for p in performances]), - ) - for metric in Performance.metrics() - }) + averaged_performance = Performance( + **{ + metric: Metric( + np.mean( + [getattr(p, metric).mean for p in performances] + ), + np.std( + [getattr(p, metric).mean for p in performances] + ), + ) + for metric in Performance.metrics() + } + ) # Get validation scores if available try: - val_ncrps = np.mean([ - job.metrics[c]["evaluation"]["val_ncrps"] - for (job, c) in zip(jobs, choices) - ]) - val_loss = np.mean([ - job.metrics[c]["evaluation"]["val_loss"] - for (job, c) in zip(jobs, choices) - ]).item() + val_ncrps = np.mean( + [ + job.metrics[c]["evaluation"]["val_ncrps"] + for (job, c) in zip(jobs, choices) + ] + ) + val_loss = np.mean( + [ + job.metrics[c]["evaluation"]["val_loss"] + for (job, c) in zip(jobs, choices) + ] + ).item() val_scores = ValidationScores(val_ncrps, val_loss) except KeyError: val_scores = None diff --git a/src/gluonts/nursery/tsbench/src/tsbench/evaluations/tracking/ensemble.py b/src/gluonts/nursery/tsbench/src/tsbench/evaluations/tracking/ensemble.py index 208d11a587..d6dfe5c460 100644 --- a/src/gluonts/nursery/tsbench/src/tsbench/evaluations/tracking/ensemble.py +++ b/src/gluonts/nursery/tsbench/src/tsbench/evaluations/tracking/ensemble.py @@ -39,10 +39,12 @@ def __init__(self, directory: Path): continue with Path(file).open("rb") as f: data = pickle.load(f) - configurations.extend([ - Config(frozenset(x["configurations"]), x["dataset"]) - for x in data - ]) + configurations.extend( + [ + Config(frozenset(x["configurations"]), x["dataset"]) + for x in data + ] + ) performances.extend([x["performance"] for x in data]) self.performance_map: Dict[Config[EnsembleConfig], Performance] = dict( diff --git a/src/gluonts/nursery/tsbench/src/tsbench/forecasts/evaluation.py b/src/gluonts/nursery/tsbench/src/tsbench/forecasts/evaluation.py index cfb8b52150..7c24adc08e 100644 --- a/src/gluonts/nursery/tsbench/src/tsbench/forecasts/evaluation.py +++ b/src/gluonts/nursery/tsbench/src/tsbench/forecasts/evaluation.py @@ -101,14 +101,18 @@ def performance(cls, evaluations: list[Evaluation]) -> Performance: Metric(0, 0) if m == "num_model_parameters" else Metric( - np.mean([ - metric[m] if m in metric else np.nan - for metric in metrics - ]), - np.std([ - metric[m] if m in metric else np.nan - for metric in metrics - ]), + np.mean( + [ + metric[m] if m in metric else np.nan + for metric in metrics + ] + ), + np.std( + [ + metric[m] if m in metric else np.nan + for metric in metrics + ] + ), ) ) for m in Performance.metrics() diff --git a/src/gluonts/nursery/tsbench/src/tsbench/recommender/greedy.py b/src/gluonts/nursery/tsbench/src/tsbench/recommender/greedy.py index 30377c7990..effac277ca 100644 --- a/src/gluonts/nursery/tsbench/src/tsbench/recommender/greedy.py +++ b/src/gluonts/nursery/tsbench/src/tsbench/recommender/greedy.py @@ -88,16 +88,20 @@ def fit( transformer = QuantileTransformer( n_quantiles=min(1000, self.metrics.shape[0]) ) - self.metrics = np.stack([ - transformer.fit_transform(dataset_metrics) - for dataset_metrics in self.metrics - ]) + self.metrics = np.stack( + [ + transformer.fit_transform(dataset_metrics) + for dataset_metrics in self.metrics + ] + ) else: transformer = StandardScaler() - self.metrics = np.stack([ - transformer.fit_transform(dataset_metrics) - for dataset_metrics in self.metrics - ]) + self.metrics = np.stack( + [ + transformer.fit_transform(dataset_metrics) + for dataset_metrics in self.metrics + ] + ) def recommend( self, @@ -119,10 +123,12 @@ def recommend( # true Pareto front. if not self.enforce_single_objective and len(self.objectives) > 1: reference = np.ones(len(self.objectives)) - hypervolumes = np.array([ - pygmo.hypervolume(dataset_metrics).compute(reference) # type: ignore - for dataset_metrics in self.metrics - ]) + hypervolumes = np.array( + [ + pygmo.hypervolume(dataset_metrics).compute(reference) # type: ignore + for dataset_metrics in self.metrics + ] + ) available_choices = list(range(len(model_configs))) result = [] @@ -146,12 +152,14 @@ def recommend( else: # Otherwise, we need to compute the hypervolumes for all datasets reference = np.ones(len(self.objectives)) - config_hypervolumes = np.array([ - pygmo.hypervolume( # type: ignore - dataset_performances, - ).compute(reference) - for dataset_performances in all_performances - ]) + config_hypervolumes = np.array( + [ + pygmo.hypervolume( # type: ignore + dataset_performances, + ).compute(reference) + for dataset_performances in all_performances + ] + ) # And then compute the cumulative hypervolume error error = (hypervolumes - config_hypervolumes).sum() # type: ignore @@ -169,8 +177,10 @@ def recommend( def _dummy_performance() -> Performance: - return Performance.from_dict({ - mm: np.nan - for m in Performance.metrics() - for mm in [f"{m}_mean", f"{m}_std"] - }) + return Performance.from_dict( + { + mm: np.nan + for m in Performance.metrics() + for mm in [f"{m}_mean", f"{m}_std"] + } + ) diff --git a/src/gluonts/nursery/tsbench/src/tsbench/surrogate/nonparametric.py b/src/gluonts/nursery/tsbench/src/tsbench/surrogate/nonparametric.py index 9ec0f19aeb..5172fe66d2 100644 --- a/src/gluonts/nursery/tsbench/src/tsbench/surrogate/nonparametric.py +++ b/src/gluonts/nursery/tsbench/src/tsbench/surrogate/nonparametric.py @@ -67,11 +67,13 @@ def __init__( tracker, predict, output_normalization, impute_simulatable ) - self.use_dataset_features = any([ - use_simple_dataset_features, - use_seasonal_naive_performance, - use_catch22_features, - ]) + self.use_dataset_features = any( + [ + use_simple_dataset_features, + use_seasonal_naive_performance, + use_catch22_features, + ] + ) if self.use_dataset_features: self.config_transformer = ConfigTransformer( add_model_features=False, @@ -95,22 +97,26 @@ def _fit( # Then, we assign the model performances and dataset features self.model_performances_ = { - model: np.stack([ - p["performance"] - for p in sorted( - data, - key=lambda x: x["dataset"].name(), # type: ignore - ) - ]) + model: np.stack( + [ + p["performance"] + for p in sorted( + data, + key=lambda x: x["dataset"].name(), # type: ignore + ) + ] + ) for model, data in performances.items() } # We use the seasonal naive model config here since it is ignored anyway if self.use_dataset_features: - self.dataset_features_ = self.config_transformer.fit_transform([ - Config(SeasonalNaiveModelConfig(), d) - for d in sorted(datasets, key=lambda x: x.name()) # type: ignore - ]) + self.dataset_features_ = self.config_transformer.fit_transform( + [ + Config(SeasonalNaiveModelConfig(), d) + for d in sorted(datasets, key=lambda x: x.name()) # type: ignore + ] + ) def _predict( self, X: List[Config[ModelConfig]] diff --git a/src/gluonts/nursery/tsbench/src/tsbench/surrogate/transformers/config.py b/src/gluonts/nursery/tsbench/src/tsbench/surrogate/transformers/config.py index c7e2bb9cf9..ce5aa18f02 100644 --- a/src/gluonts/nursery/tsbench/src/tsbench/surrogate/transformers/config.py +++ b/src/gluonts/nursery/tsbench/src/tsbench/surrogate/transformers/config.py @@ -61,12 +61,14 @@ def __init__( add_catch_22_features: Whether a dataset's catch22 features ought to be added. tracker: An optional tracker to obtain the performance of Seasonal Naïve. """ - assert any([ - add_model_features, - add_dataset_statistics, - add_seasonal_naive_performance, - add_catch22_features, - ]), "ConfigTransformer must be given at least some group of features." + assert any( + [ + add_model_features, + add_dataset_statistics, + add_seasonal_naive_performance, + add_catch22_features, + ] + ), "ConfigTransformer must be given at least some group of features." assert ( not add_seasonal_naive_performance or tracker is not None ), "Tracker must be set if seasonal naive performance is used." @@ -335,10 +337,12 @@ def transform( if self.transform_full_config: return cast( npt.NDArray[np.float32], - self.pipeline.transform([ - x.model.asdict() - for x in cast(List[Config[ModelConfig]], X) - ]), + self.pipeline.transform( + [ + x.model.asdict() + for x in cast(List[Config[ModelConfig]], X) + ] + ), ) return cast( npt.NDArray[np.float32], @@ -379,14 +383,16 @@ def feature_names_(self) -> list[str]: def fit( self, X: list[Config[ModelConfig]], _y: Any = None ) -> DatasetStatisticsEncoder: - self.pipeline.fit([ - { - **x.dataset.stats(), - "frequency": x.dataset.meta.freq, - "prediction_length": x.dataset.meta.prediction_length, - } - for x in X - ]) + self.pipeline.fit( + [ + { + **x.dataset.stats(), + "frequency": x.dataset.meta.freq, + "prediction_length": x.dataset.meta.prediction_length, + } + for x in X + ] + ) return self def transform( @@ -424,12 +430,14 @@ def transform( def _get_performance_array( self, X: list[Config[ModelConfig]] ) -> npt.NDArray[np.float32]: - return np.array([ - self.tracker.get_performance( - Config(SeasonalNaiveModelConfig(), x.dataset) - ).ncrps.mean - for x in X - ])[:, None] + return np.array( + [ + self.tracker.get_performance( + Config(SeasonalNaiveModelConfig(), x.dataset) + ).ncrps.mean + for x in X + ] + )[:, None] class DatasetCatch22Encoder(Encoder): diff --git a/src/gluonts/nursery/tsbench/src/tsbench/surrogate/transformers/performance.py b/src/gluonts/nursery/tsbench/src/tsbench/surrogate/transformers/performance.py index 9ec903f2d8..28f32ce34e 100644 --- a/src/gluonts/nursery/tsbench/src/tsbench/surrogate/transformers/performance.py +++ b/src/gluonts/nursery/tsbench/src/tsbench/surrogate/transformers/performance.py @@ -112,10 +112,13 @@ def transform( def inverse_transform( self, X: npt.NDArray[np.float32], _y: Any = None ) -> list[Performance]: - df = pd.DataFrame(X, columns=self.feature_names_).assign(**{ - col: np.nan - for col in set(self.all_feature_names_) - set(self.feature_names_) - }) + df = pd.DataFrame(X, columns=self.feature_names_).assign( + **{ + col: np.nan + for col in set(self.all_feature_names_) + - set(self.feature_names_) + } + ) return [ Performance.from_dict(row.to_dict()) for _, row in df.iterrows() ] diff --git a/src/gluonts/shell/sagemaker/dyn.py b/src/gluonts/shell/sagemaker/dyn.py index fae3e3137f..bb1bc4d57b 100644 --- a/src/gluonts/shell/sagemaker/dyn.py +++ b/src/gluonts/shell/sagemaker/dyn.py @@ -49,29 +49,33 @@ def copy_install(self, path: Path): shutil.copytree(path, self.packages / path.name) def pip_install(self, path: Path): - subprocess.check_call([ - sys.executable, - "-m", - "pip", - "install", - "--upgrade", - "--target", - str(self.packages), - str(path), - ]) + subprocess.check_call( + [ + sys.executable, + "-m", + "pip", + "install", + "--upgrade", + "--target", + str(self.packages), + str(path), + ] + ) def install_requirement(self, path: Path): - subprocess.check_call([ - sys.executable, - "-m", - "pip", - "install", - "--upgrade", - "--target", - str(self.packages), - "--requirement", - str(path), - ]) + subprocess.check_call( + [ + sys.executable, + "-m", + "pip", + "install", + "--upgrade", + "--target", + str(self.packages), + "--requirement", + str(path), + ] + ) def install(self, path): if path.is_file(): diff --git a/src/gluonts/time_feature/holiday.py b/src/gluonts/time_feature/holiday.py index bb6422627c..9d395a0bb0 100644 --- a/src/gluonts/time_feature/holiday.py +++ b/src/gluonts/time_feature/holiday.py @@ -215,10 +215,16 @@ def __call__(self, dates): dates Pandas series with Datetimeindex timestamps. """ - return np.vstack([ - np.hstack([ - self.kernel_function(SPECIAL_DATE_FEATURES[feat_name](index)) - for index in dates - ]) - for feat_name in self.feature_names - ]) + return np.vstack( + [ + np.hstack( + [ + self.kernel_function( + SPECIAL_DATE_FEATURES[feat_name](index) + ) + for index in dates + ] + ) + for feat_name in self.feature_names + ] + ) diff --git a/src/gluonts/torch/model/estimator.py b/src/gluonts/torch/model/estimator.py index 41757efa70..b8a1147d44 100644 --- a/src/gluonts/torch/model/estimator.py +++ b/src/gluonts/torch/model/estimator.py @@ -198,11 +198,13 @@ def train_model( ) custom_callbacks = self.trainer_kwargs.pop("callbacks", []) - trainer = pl.Trainer(**{ - "accelerator": "auto", - "callbacks": [checkpoint] + custom_callbacks, - **self.trainer_kwargs, - }) + trainer = pl.Trainer( + **{ + "accelerator": "auto", + "callbacks": [checkpoint] + custom_callbacks, + **self.trainer_kwargs, + } + ) trainer.fit( model=training_network, diff --git a/src/gluonts/torch/model/patch_tst/module.py b/src/gluonts/torch/model/patch_tst/module.py index c4a7149dcb..4e829e2ea1 100644 --- a/src/gluonts/torch/model/patch_tst/module.py +++ b/src/gluonts/torch/model/patch_tst/module.py @@ -38,10 +38,12 @@ def _init_weight(out: torch.Tensor) -> torch.Tensor: Features are not interleaved. The cos features are in the 2nd half of the vector. [dim // 2:] """ n_pos, dim = out.shape - position_enc = np.array([ - [pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] - for pos in range(n_pos) - ]) + position_enc = np.array( + [ + [pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] + for pos in range(n_pos) + ] + ) # set early to avoid an error in pytorch-1.8+ out.requires_grad = False diff --git a/src/gluonts/torch/model/tft/layers.py b/src/gluonts/torch/model/tft/layers.py index 5a12260b70..d19ed7f059 100644 --- a/src/gluonts/torch/model/tft/layers.py +++ b/src/gluonts/torch/model/tft/layers.py @@ -56,10 +56,12 @@ def __init__( self.feature_dims = feature_dims self._num_features = len(feature_dims) - self._projectors = nn.ModuleList([ - nn.Linear(out_features=d, in_features=c) - for c, d in zip(feature_dims, embedding_dims) - ]) + self._projectors = nn.ModuleList( + [ + nn.Linear(out_features=d, in_features=c) + for c, d in zip(feature_dims, embedding_dims) + ] + ) def forward(self, features: torch.Tensor) -> List[torch.Tensor]: """ @@ -187,10 +189,12 @@ def __init__( d_static=self.d_hidden if add_static else None, dropout=dropout, ) - self.variable_networks = nn.ModuleList([ - GatedResidualNetwork(d_hidden=d_hidden, dropout=dropout) - for _ in range(num_vars) - ]) + self.variable_networks = nn.ModuleList( + [ + GatedResidualNetwork(d_hidden=d_hidden, dropout=dropout) + for _ in range(num_vars) + ] + ) def forward( self, diff --git a/src/gluonts/torch/model/wavenet/estimator.py b/src/gluonts/torch/model/wavenet/estimator.py index 6aacdc3a9a..4a5b2ff0b1 100644 --- a/src/gluonts/torch/model/wavenet/estimator.py +++ b/src/gluonts/torch/model/wavenet/estimator.py @@ -264,56 +264,60 @@ def create_transformation(self) -> Transformation: remove_field_names.append(FieldName.FEAT_STATIC_REAL) if self.num_feat_dynamic_real == 0: remove_field_names.append(FieldName.FEAT_DYNAMIC_REAL) - return Chain([ - RemoveFields(field_names=remove_field_names), - ( - SetField(output_field=FieldName.FEAT_STATIC_CAT, value=[0]) - if self.num_feat_static_cat == 0 - else Identity() - ), - ( - SetField(output_field=FieldName.FEAT_STATIC_REAL, value=[0.0]) - if self.num_feat_static_real == 0 - else Identity() - ), - AsNumpyArray( - field=FieldName.FEAT_STATIC_CAT, expected_ndim=1, dtype=int - ), - AsNumpyArray( - field=FieldName.FEAT_STATIC_REAL, - expected_ndim=1, - dtype=np.float32, - ), - AsNumpyArray(field=FieldName.TARGET, expected_ndim=1), - AddObservedValuesIndicator( - target_field=FieldName.TARGET, - output_field=FieldName.OBSERVED_VALUES, - ), - AddTimeFeatures( - start_field=FieldName.START, - target_field=FieldName.TARGET, - output_field=FieldName.FEAT_TIME, - time_features=self.time_features, - pred_length=self.prediction_length, - ), - AddAgeFeature( - target_field=FieldName.TARGET, - output_field=FieldName.FEAT_AGE, - pred_length=self.prediction_length, - ), - VstackFeatures( - output_field=FieldName.FEAT_TIME, - input_fields=[FieldName.FEAT_TIME, FieldName.FEAT_AGE] - + ( - [FieldName.FEAT_DYNAMIC_REAL] - if self.num_feat_dynamic_real > 0 - else [] + return Chain( + [ + RemoveFields(field_names=remove_field_names), + ( + SetField(output_field=FieldName.FEAT_STATIC_CAT, value=[0]) + if self.num_feat_static_cat == 0 + else Identity() ), - ), - AsNumpyArray( - FieldName.FEAT_TIME, expected_ndim=2, dtype=np.float32 - ), - ]) + ( + SetField( + output_field=FieldName.FEAT_STATIC_REAL, value=[0.0] + ) + if self.num_feat_static_real == 0 + else Identity() + ), + AsNumpyArray( + field=FieldName.FEAT_STATIC_CAT, expected_ndim=1, dtype=int + ), + AsNumpyArray( + field=FieldName.FEAT_STATIC_REAL, + expected_ndim=1, + dtype=np.float32, + ), + AsNumpyArray(field=FieldName.TARGET, expected_ndim=1), + AddObservedValuesIndicator( + target_field=FieldName.TARGET, + output_field=FieldName.OBSERVED_VALUES, + ), + AddTimeFeatures( + start_field=FieldName.START, + target_field=FieldName.TARGET, + output_field=FieldName.FEAT_TIME, + time_features=self.time_features, + pred_length=self.prediction_length, + ), + AddAgeFeature( + target_field=FieldName.TARGET, + output_field=FieldName.FEAT_AGE, + pred_length=self.prediction_length, + ), + VstackFeatures( + output_field=FieldName.FEAT_TIME, + input_fields=[FieldName.FEAT_TIME, FieldName.FEAT_AGE] + + ( + [FieldName.FEAT_DYNAMIC_REAL] + if self.num_feat_dynamic_real > 0 + else [] + ), + ), + AsNumpyArray( + FieldName.FEAT_TIME, expected_ndim=2, dtype=np.float32 + ), + ] + ) def _create_instance_splitter(self, mode: str): assert mode in ["training", "validation", "test"] diff --git a/src/gluonts/transform/feature.py b/src/gluonts/transform/feature.py index 52648469c9..f5f519d5a7 100644 --- a/src/gluonts/transform/feature.py +++ b/src/gluonts/transform/feature.py @@ -192,10 +192,12 @@ def __call__(self, values: np.ndarray) -> np.ndarray: last_value_imputation = LastValueImputation() value_no_nans = last_value_imputation(values) - adjusted_values_to_causality = np.concatenate(( - np.repeat(value_no_nans[0], self.window_size + 1), - value_no_nans[:-1], - )) + adjusted_values_to_causality = np.concatenate( + ( + np.repeat(value_no_nans[0], self.window_size + 1), + value_no_nans[:-1], + ) + ) cumsum = np.cumsum(adjusted_values_to_causality) @@ -517,23 +519,25 @@ def map_transform(self, data: DataEntry, is_train: bool) -> DataEntry: # compute the aggregate lags for each time point of the time series agg_vals = np.concatenate( [ - np.zeros(( - max(self.valid_lags) * self.ratio + self.half_window + 1, - )), + np.zeros( + (max(self.valid_lags) * self.ratio + self.half_window + 1,) + ), t_agg.values, ], axis=0, ) - lags = np.vstack([ - agg_vals[ - -(l * self.ratio - self.half_window + len(t)) : ( - -(l * self.ratio - self.half_window) - if -(l * self.ratio - self.half_window) != 0 - else None - ) + lags = np.vstack( + [ + agg_vals[ + -(l * self.ratio - self.half_window + len(t)) : ( + -(l * self.ratio - self.half_window) + if -(l * self.ratio - self.half_window) != 0 + else None + ) + ] + for l in self.valid_lags ] - for l in self.valid_lags - ]) + ) # update the data entry data[self.feature_name] = np.nan_to_num(lags) diff --git a/src/gluonts/zebras/_period.py b/src/gluonts/zebras/_period.py index ed9f52af00..4cda3178a9 100644 --- a/src/gluonts/zebras/_period.py +++ b/src/gluonts/zebras/_period.py @@ -114,9 +114,12 @@ def dayofyear(self) -> np.ndarray: def week(self) -> np.ndarray: # Note: In Python 3.9 `isocalendar()` returns a named tuple, but we # need to support 3.7 and 3.8, so we use index one for the week. - return np.array([ - cal.isocalendar()[1] for cal in self.data.astype(datetime.datetime) - ]) + return np.array( + [ + cal.isocalendar()[1] + for cal in self.data.astype(datetime.datetime) + ] + ) def __add__(self, other): if _is_number(other): diff --git a/src/gluonts/zebras/_time_frame.py b/src/gluonts/zebras/_time_frame.py index 82246bcf12..dec7e591e8 100644 --- a/src/gluonts/zebras/_time_frame.py +++ b/src/gluonts/zebras/_time_frame.py @@ -204,19 +204,23 @@ def move_axis(data, name): head = self.head(5) tail = self.tail(5) - columns.update({ - col: [ - *(move_axis(head[col], col)), - f"[ ... {len(self) - 10} ... ]", - *(move_axis(tail[col], col)), - ] - for col in self.columns - }) + columns.update( + { + col: [ + *(move_axis(head[col], col)), + f"[ ... {len(self) - 10} ... ]", + *(move_axis(tail[col], col)), + ] + for col in self.columns + } + ) else: - columns.update({ - name: move_axis(values, name) - for name, values in self.columns.items() - }) + columns.update( + { + name: move_axis(values, name) + for name, values in self.columns.items() + } + ) return columns @@ -229,10 +233,14 @@ def _repr_html_(self): ] if self.static: - html.extend([ - "

Static Data

", - html_table({name: [val] for name, val in self.static.items()}), - ]) + html.extend( + [ + "

Static Data

", + html_table( + {name: [val] for name, val in self.static.items()} + ), + ] + ) return "\n".join(html) diff --git a/src/gluonts/zebras/schema.py b/src/gluonts/zebras/schema.py index 4ab348eb29..5bfa777fb2 100644 --- a/src/gluonts/zebras/schema.py +++ b/src/gluonts/zebras/schema.py @@ -247,11 +247,13 @@ def load_timeframe( columns = {self.time_series_ref: ref} - columns.update({ - name: field.load_from(data, name, length=length) - for name, field in self.columns.items() - if name != self.time_series_ref - }) + columns.update( + { + name: field.load_from(data, name, length=length) + for name, field in self.columns.items() + if name != self.time_series_ref + } + ) else: columns = {} diff --git a/test/conftest.py b/test/conftest.py index 69222f462a..b624259a15 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -103,11 +103,13 @@ def _sine7( ) train_dataset = ListDataset( - [{ - "start": index[0], - "item_id": "all_items", - "target": Y[:, :-prediction_length], - }], + [ + { + "start": index[0], + "item_id": "all_items", + "target": Y[:, :-prediction_length], + } + ], freq=index.freqstr, one_dim_target=False, ) diff --git a/test/dataset/test_data_loader.py b/test/dataset/test_data_loader.py index 3f937628d6..8dd5684178 100644 --- a/test/dataset/test_data_loader.py +++ b/test/dataset/test_data_loader.py @@ -231,9 +231,12 @@ def test_as_stacked_batches(): def test_as_stacked_batches_iter(): step = 10 - data = iter([ - {"x": np.arange(start, start + step)} for start in range(0, 100, step) - ]) + data = iter( + [ + {"x": np.arange(start, start + step)} + for start in range(0, 100, step) + ] + ) stream = as_stacked_batches(data, batch_size=2) @@ -252,9 +255,12 @@ def test_as_stacked_batches_iter(): def test_as_stacked_batches_iter_num_batches(): step = 10 - data = iter([ - {"x": np.arange(start, start + step)} for start in range(0, 100, step) - ]) + data = iter( + [ + {"x": np.arange(start, start + step)} + for start in range(0, 100, step) + ] + ) stream = as_stacked_batches(data, batch_size=2, num_batches_per_epoch=3) @@ -275,10 +281,12 @@ def test_as_stacked_batches_iter_num_batches(): def test_as_stacked_batches_num_batches_iter_cycle(): step = 10 data = iter( - Cyclic([ - {"x": np.arange(start, start + step)} - for start in range(0, 100, step) - ]) + Cyclic( + [ + {"x": np.arange(start, start + step)} + for start in range(0, 100, step) + ] + ) ) stream = as_stacked_batches(data, batch_size=2, num_batches_per_epoch=3) diff --git a/test/dataset/test_dataset_mutability.py b/test/dataset/test_dataset_mutability.py index d3eba63901..b5c0bbd8fb 100644 --- a/test/dataset/test_dataset_mutability.py +++ b/test/dataset/test_dataset_mutability.py @@ -25,15 +25,21 @@ AddObservedValuesIndicator, ) -ds1 = [{ - "start": pd.Period("2020/01/01", freq="1D"), - "target": np.array([1, 2, 3, np.nan, 5, np.nan, 7, np.nan, np.nan, 10]), -}] +ds1 = [ + { + "start": pd.Period("2020/01/01", freq="1D"), + "target": np.array( + [1, 2, 3, np.nan, 5, np.nan, 7, np.nan, np.nan, 10] + ), + } +] ds2 = ListDataset( - [{ - "start": "2020/01/01", - "target": [1, 2, 3, np.nan, 5, np.nan, 7, np.nan, np.nan, 10], - }], + [ + { + "start": "2020/01/01", + "target": [1, 2, 3, np.nan, 5, np.nan, 7, np.nan, np.nan, 10], + } + ], freq="1D", ) diff --git a/test/dataset/test_multivariate_grouper.py b/test/dataset/test_multivariate_grouper.py index cba2a3f7bc..1f1ac34e75 100644 --- a/test/dataset/test_multivariate_grouper.py +++ b/test/dataset/test_multivariate_grouper.py @@ -43,18 +43,22 @@ MULTIVARIATE_TS = [ [{"start": "2014-09-07", "target": [[1, 2, 3, 4], [5, 6, 7, 8]]}], - [{ - "start": "2014-09-07", - "target": [[1, 2, 3, 4, 2.5], [6.5, 5, 6, 7, 8]], - }], + [ + { + "start": "2014-09-07", + "target": [[1, 2, 3, 4, 2.5], [6.5, 5, 6, 7, 8]], + } + ], [{"start": "2014-09-07", "target": [[1, 2, 3, 4], [0, 0, 0, 0]]}], - [{ - "start": "2014-09-01", - "target": [ - [2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 1, 2, 3, 4], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - ], - }], + [ + { + "start": "2014-09-01", + "target": [ + [2.5, 2.5, 2.5, 2.5, 2.5, 2.5, 1, 2, 3, 4], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + } + ], [{"start": "2014-09-07", "target": [[1, 2, 3, 4, 0], [0, 5, 6, 7, 8]]}], ] diff --git a/test/dataset/test_pandas.py b/test/dataset/test_pandas.py index 569b9de143..b3f9834532 100644 --- a/test/dataset/test_pandas.py +++ b/test/dataset/test_pandas.py @@ -241,36 +241,42 @@ def _testcase_dataframes_without_index( dtype=np.float32, ): dataframes = [ - pd.DataFrame.from_dict({ - "timestamp": pd.period_range( - "2021-01-01 00:00:00", periods=10, freq=freq - ) - .map(str) - .to_list(), - "A": 1 + np.arange(10, dtype=dtype), - "B": 2 + np.arange(10, dtype=dtype), - "C": 3 + np.arange(10, dtype=dtype), - }), - pd.DataFrame.from_dict({ - "timestamp": pd.period_range( - "2021-01-02 00:00:00", periods=20, freq=freq - ) - .map(str) - .to_list(), - "A": 1 + np.arange(20, dtype=dtype), - "B": 2 + np.arange(20, dtype=dtype), - "C": 3 + np.arange(20, dtype=dtype), - }), - pd.DataFrame.from_dict({ - "timestamp": pd.period_range( - "2021-01-03 00:00:00", periods=30, freq=freq - ) - .map(str) - .to_list(), - "A": 1 + np.arange(30, dtype=dtype), - "B": 2 + np.arange(30, dtype=dtype), - "C": 3 + np.arange(30, dtype=dtype), - }), + pd.DataFrame.from_dict( + { + "timestamp": pd.period_range( + "2021-01-01 00:00:00", periods=10, freq=freq + ) + .map(str) + .to_list(), + "A": 1 + np.arange(10, dtype=dtype), + "B": 2 + np.arange(10, dtype=dtype), + "C": 3 + np.arange(10, dtype=dtype), + } + ), + pd.DataFrame.from_dict( + { + "timestamp": pd.period_range( + "2021-01-02 00:00:00", periods=20, freq=freq + ) + .map(str) + .to_list(), + "A": 1 + np.arange(20, dtype=dtype), + "B": 2 + np.arange(20, dtype=dtype), + "C": 3 + np.arange(20, dtype=dtype), + } + ), + pd.DataFrame.from_dict( + { + "timestamp": pd.period_range( + "2021-01-03 00:00:00", periods=30, freq=freq + ) + .map(str) + .to_list(), + "A": 1 + np.arange(30, dtype=dtype), + "B": 2 + np.arange(30, dtype=dtype), + "C": 3 + np.arange(30, dtype=dtype), + } + ), ] dataset = pandas.PandasDataset( @@ -301,30 +307,36 @@ def _testcase_dataframes_with_index( dtype=np.float32, ): dataframes = [ - pd.DataFrame.from_dict({ - "timestamp": index_type( - "2021-01-01 00:00:00", periods=10, freq=freq - ), - "A": 1 + np.arange(10, dtype=dtype), - "B": 2 + np.arange(10, dtype=dtype), - "C": 3 + np.arange(10, dtype=dtype), - }).set_index("timestamp"), - pd.DataFrame.from_dict({ - "timestamp": index_type( - "2021-01-02 00:00:00", periods=20, freq=freq - ), - "A": 1 + np.arange(20, dtype=dtype), - "B": 2 + np.arange(20, dtype=dtype), - "C": 3 + np.arange(20, dtype=dtype), - }).set_index("timestamp"), - pd.DataFrame.from_dict({ - "timestamp": index_type( - "2021-01-03 00:00:00", periods=30, freq=freq - ), - "A": 1 + np.arange(30, dtype=dtype), - "B": 2 + np.arange(30, dtype=dtype), - "C": 3 + np.arange(30, dtype=dtype), - }).set_index("timestamp"), + pd.DataFrame.from_dict( + { + "timestamp": index_type( + "2021-01-01 00:00:00", periods=10, freq=freq + ), + "A": 1 + np.arange(10, dtype=dtype), + "B": 2 + np.arange(10, dtype=dtype), + "C": 3 + np.arange(10, dtype=dtype), + } + ).set_index("timestamp"), + pd.DataFrame.from_dict( + { + "timestamp": index_type( + "2021-01-02 00:00:00", periods=20, freq=freq + ), + "A": 1 + np.arange(20, dtype=dtype), + "B": 2 + np.arange(20, dtype=dtype), + "C": 3 + np.arange(20, dtype=dtype), + } + ).set_index("timestamp"), + pd.DataFrame.from_dict( + { + "timestamp": index_type( + "2021-01-03 00:00:00", periods=30, freq=freq + ), + "A": 1 + np.arange(30, dtype=dtype), + "B": 2 + np.arange(30, dtype=dtype), + "C": 3 + np.arange(30, dtype=dtype), + } + ).set_index("timestamp"), ] print(type(dataframes[0].index)) diff --git a/test/dataset/test_split.py b/test/dataset/test_split.py index 606af31a81..33d2cbe98e 100644 --- a/test/dataset/test_split.py +++ b/test/dataset/test_split.py @@ -414,10 +414,12 @@ def test_split_date( @pytest.mark.parametrize( "dataset", [ - [{ - "start": pd.Period("2021-03-01", freq="D"), - "target": np.ones(shape=(28,)), - }], + [ + { + "start": pd.Period("2021-03-01", freq="D"), + "target": np.ones(shape=(28,)), + } + ], ], ) @pytest.mark.parametrize( diff --git a/test/ev/test_aggregations.py b/test/ev/test_aggregations.py index 48a2033fcd..974f8ad199 100644 --- a/test/ev/test_aggregations.py +++ b/test/ev/test_aggregations.py @@ -33,20 +33,24 @@ np.zeros(9), ), ( - np.ma.masked_invalid([ - np.full((3, 5), np.nan), - np.full((3, 5), np.nan), - np.full((3, 5), np.nan), - ]), + np.ma.masked_invalid( + [ + np.full((3, 5), np.nan), + np.full((3, 5), np.nan), + np.full((3, 5), np.nan), + ] + ), 0, np.zeros(5), np.zeros(9), ), ( - np.ma.masked_invalid([ - np.array([[0, np.nan], [0, 0]]), - np.array([[0, 5], [-5, np.nan]]), - ]), + np.ma.masked_invalid( + [ + np.array([[0, np.nan], [0, 0]]), + np.array([[0, 5], [-5, np.nan]]), + ] + ), 0, np.array([-5, 5]), np.array([0, 0, 5, -5]), @@ -87,20 +91,24 @@ def test_Sum(value_stream, res_axis_none, res_axis_0, res_axis_1): np.zeros(9), ), ( - np.ma.masked_invalid([ - np.full((3, 5), np.nan), - np.full((3, 5), np.nan), - np.full((3, 5), np.nan), - ]), + np.ma.masked_invalid( + [ + np.full((3, 5), np.nan), + np.full((3, 5), np.nan), + np.full((3, 5), np.nan), + ] + ), np.nan, np.full(5, np.nan), np.full(9, np.nan), ), ( - np.ma.masked_invalid([ - np.array([[0, np.nan], [0, 0]]), - np.array([[0, 5], [-5, np.nan]]), - ]), + np.ma.masked_invalid( + [ + np.array([[0, np.nan], [0, 0]]), + np.array([[0, 5], [-5, np.nan]]), + ] + ), 0, np.array([-1.25, 2.5]), np.array([0, 0, 2.5, -5]), diff --git a/test/ev/test_metrics_compared_to_previous_approach.py b/test/ev/test_metrics_compared_to_previous_approach.py index 086504744a..5adc2e5d2b 100644 --- a/test/ev/test_metrics_compared_to_previous_approach.py +++ b/test/ev/test_metrics_compared_to_previous_approach.py @@ -127,13 +127,15 @@ def get_data_batches(predictor, test_data): "seasonal_error": np.array( [seasonal_error(input_["target"], seasonality=seasonality)] ), - "naive_2": np.array([ - naive_2( - input_["target"], - len(label["target"]), - season_length=seasonality, - ) - ]), + "naive_2": np.array( + [ + naive_2( + input_["target"], + len(label["target"]), + season_length=seasonality, + ) + ] + ), } yield ChainMap(other_data, forecast_batch) @@ -166,9 +168,12 @@ def get_new_metrics(test_data, predictor, quantile_levels): + MeanSumQuantileLoss([quantile.value for quantile in quantiles]) + MeanWeightedSumQuantileLoss( [quantile.value for quantile in quantiles] - ).add(*[ - WeightedSumQuantileLoss(q=quantile.value) for quantile in quantiles - ]) + ).add( + *[ + WeightedSumQuantileLoss(q=quantile.value) + for quantile in quantiles + ] + ) ) # mask invalid values diff --git a/test/evaluation/test_evaluator.py b/test/evaluation/test_evaluator.py index da57ee389f..ca21c56aa8 100644 --- a/test/evaluation/test_evaluator.py +++ b/test/evaluation/test_evaluator.py @@ -119,130 +119,134 @@ def calculate_metrics( TIMESERIES_M4 = [ - np.array([ + np.array( [ - 2.943_013, - 2.822_251, - 4.196_222, - 1.328_664, - 4.947_390, - 3.333_131, - 1.479_800, - 2.265_094, - 3.413_493, - 3.497_607, - ], - [ - -0.126_781_2, - 3.057_412_2, - 1.901_594_4, - 2.772_549_5, - 3.312_853_1, - 4.411_818_0, - 3.709_025_2, - 4.322_028, - 2.565_359, - 3.074_308, - ], - [ - 2.542_998, - 2.336_757, - 1.417_916, - 1.335_139, - 2.523_035, - 3.645_589, - 3.382_819, - 2.075_960, - 2.643_869, - 2.772_456, - ], - [ - 0.315_685_6, - 1.892_312_1, - 2.476_861_2, - 3.511_628_6, - 4.384_346_5, - 2.960_685_6, - 4.897_572_5, - 3.280_125, - 4.768_556, - 4.958_616, - ], - [ - 2.205_877_3, - 0.782_759_4, - 2.401_420_8, - 2.385_643_4, - 4.845_818_2, - 3.102_322_9, - 3.567_723_7, - 4.878_143, - 3.735_245, - 2.218_113, - ], - ]), - np.array([ - [ - 13.11301, - 13.16225, - 14.70622, - 12.00866, - 15.79739, - 14.35313, - 12.66980, - 13.62509, - 14.94349, - 15.19761, - ], - [ - 10.04322, - 13.39741, - 12.41159, - 13.45255, - 14.16285, - 15.43182, - 14.89903, - 15.68203, - 14.09536, - 14.77431, - ], - [ - 12.71300, - 12.67676, - 11.92792, - 12.01514, - 13.37303, - 14.66559, - 14.57282, - 13.43596, - 14.17387, - 14.47246, - ], - [ - 10.48569, - 12.23231, - 12.98686, - 14.19163, - 15.23435, - 13.98069, - 16.08757, - 14.64012, - 16.29856, - 16.65862, - ], + [ + 2.943_013, + 2.822_251, + 4.196_222, + 1.328_664, + 4.947_390, + 3.333_131, + 1.479_800, + 2.265_094, + 3.413_493, + 3.497_607, + ], + [ + -0.126_781_2, + 3.057_412_2, + 1.901_594_4, + 2.772_549_5, + 3.312_853_1, + 4.411_818_0, + 3.709_025_2, + 4.322_028, + 2.565_359, + 3.074_308, + ], + [ + 2.542_998, + 2.336_757, + 1.417_916, + 1.335_139, + 2.523_035, + 3.645_589, + 3.382_819, + 2.075_960, + 2.643_869, + 2.772_456, + ], + [ + 0.315_685_6, + 1.892_312_1, + 2.476_861_2, + 3.511_628_6, + 4.384_346_5, + 2.960_685_6, + 4.897_572_5, + 3.280_125, + 4.768_556, + 4.958_616, + ], + [ + 2.205_877_3, + 0.782_759_4, + 2.401_420_8, + 2.385_643_4, + 4.845_818_2, + 3.102_322_9, + 3.567_723_7, + 4.878_143, + 3.735_245, + 2.218_113, + ], + ] + ), + np.array( [ - 12.37588, - 11.12276, - 12.91142, - 13.06564, - 15.69582, - 14.12232, - 14.75772, - 16.23814, - 15.26524, - 13.91811, - ], - ]), + [ + 13.11301, + 13.16225, + 14.70622, + 12.00866, + 15.79739, + 14.35313, + 12.66980, + 13.62509, + 14.94349, + 15.19761, + ], + [ + 10.04322, + 13.39741, + 12.41159, + 13.45255, + 14.16285, + 15.43182, + 14.89903, + 15.68203, + 14.09536, + 14.77431, + ], + [ + 12.71300, + 12.67676, + 11.92792, + 12.01514, + 13.37303, + 14.66559, + 14.57282, + 13.43596, + 14.17387, + 14.47246, + ], + [ + 10.48569, + 12.23231, + 12.98686, + 14.19163, + 15.23435, + 13.98069, + 16.08757, + 14.64012, + 16.29856, + 16.65862, + ], + [ + 12.37588, + 11.12276, + 12.91142, + 13.06564, + 15.69582, + 14.12232, + 14.75772, + 16.23814, + 15.26524, + 13.91811, + ], + ] + ), ] RES_M4 = [ diff --git a/test/ext/prophet/test_prophet.py b/test/ext/prophet/test_prophet.py index 865deb9113..5d1335f3cc 100644 --- a/test/ext/prophet/test_prophet.py +++ b/test/ext/prophet/test_prophet.py @@ -32,14 +32,18 @@ def test_feat_dynamic_real_success(freq: str): params = dict(prediction_length=3, prophet_params=dict(n_changepoints=20)) dataset = ListDataset( - data_iter=[{ - "start": "2017-01-01", - "target": np.array([1.0, 2.0, 3.0, 4.0]), - "feat_dynamic_real": np.array([ - [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0], - [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0], - ]), - }], + data_iter=[ + { + "start": "2017-01-01", + "target": np.array([1.0, 2.0, 3.0, 4.0]), + "feat_dynamic_real": np.array( + [ + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0], + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0], + ] + ), + } + ], freq=freq, ) @@ -57,14 +61,18 @@ def test_feat_dynamic_real_bad_size(): params = dict(prediction_length=3, prophet_params={}) dataset = ListDataset( - data_iter=[{ - "start": "2017-01-01", - "target": np.array([1.0, 2.0, 3.0, 4.0]), - "feat_dynamic_real": np.array([ - [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], - [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], - ]), - }], + data_iter=[ + { + "start": "2017-01-01", + "target": np.array([1.0, 2.0, 3.0, 4.0]), + "feat_dynamic_real": np.array( + [ + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], + ] + ), + } + ], freq="1D", ) diff --git a/test/ext/r_forecast/test_r_multi_seasonality.py b/test/ext/r_forecast/test_r_multi_seasonality.py index ca75283f9d..c130f03f74 100644 --- a/test/ext/r_forecast/test_r_multi_seasonality.py +++ b/test/ext/r_forecast/test_r_multi_seasonality.py @@ -34,24 +34,32 @@ period = 24 ## two weeks of data -dataset = [{ - "start": pd.Period("1990-01-01 00", freq=freq), - "target": np.array([ - item - for i in range(70) - for item in np.sin(2 * np.pi / period * np.arange(1, period + 1, 1)) - ]) - + np.random.normal(0, 0.5, period * 70) - + np.array([ - item - for i in range(10) - for item in [0 for i in range(5 * 24)] - + [8 for i in range(4)] - + [0 for i in range(20)] - + [8 for i in range(4)] - + [0 for i in range(20)] - ]), -}] +dataset = [ + { + "start": pd.Period("1990-01-01 00", freq=freq), + "target": np.array( + [ + item + for i in range(70) + for item in np.sin( + 2 * np.pi / period * np.arange(1, period + 1, 1) + ) + ] + ) + + np.random.normal(0, 0.5, period * 70) + + np.array( + [ + item + for i in range(10) + for item in [0 for i in range(5 * 24)] + + [8 for i in range(4)] + + [0 for i in range(20)] + + [8 for i in range(4)] + + [0 for i in range(20)] + ] + ), + } +] def no_quantile_crossing( @@ -141,25 +149,38 @@ def test_compare_arimas(): ## Below shows improvement in metric when proper x_regressors are included # -dataset_xreg = [{ - "start": pd.Period("1990-01-01 00", freq=freq), - "target": np.array([ - item - for i in range(21) - for item in np.sin(2 * np.pi / period * np.arange(1, period + 1, 1)) - ]) - + np.random.normal(0, 0.5, period * 21) - + np.array([ - item - for i in range(3) - for item in [0 for i in range(167)] + [8 for i in range(0, 1)] - ]), - "feat_dynamic_real": np.array([[ - item - for i in range(3) - for item in [0 for i in range(167)] + [1 for i in range(0, 1)] - ]]), -}] +dataset_xreg = [ + { + "start": pd.Period("1990-01-01 00", freq=freq), + "target": np.array( + [ + item + for i in range(21) + for item in np.sin( + 2 * np.pi / period * np.arange(1, period + 1, 1) + ) + ] + ) + + np.random.normal(0, 0.5, period * 21) + + np.array( + [ + item + for i in range(3) + for item in [0 for i in range(167)] + [8 for i in range(0, 1)] + ] + ), + "feat_dynamic_real": np.array( + [ + [ + item + for i in range(3) + for item in [0 for i in range(167)] + + [1 for i in range(0, 1)] + ] + ] + ), + } +] def test_compare_arimas_xreg(): diff --git a/test/ext/rotbaum/test_rotbaum_smoke.py b/test/ext/rotbaum/test_rotbaum_smoke.py index 04c07e2398..93d1e96dd5 100644 --- a/test/ext/rotbaum/test_rotbaum_smoke.py +++ b/test/ext/rotbaum/test_rotbaum_smoke.py @@ -72,27 +72,31 @@ def test_short_history_item_pred(): { "start": "2017-10-11", "item_id": "item_1", - "target": np.array([ - 1.0, - 9.0, - 2.0, - 0.0, - 0.0, - 1.0, - 5.0, - 3.0, - 4.0, - 2.0, - 0.0, - 0.0, - 1.0, - 6.0, - ]), + "target": np.array( + [ + 1.0, + 9.0, + 2.0, + 0.0, + 0.0, + 1.0, + 5.0, + 3.0, + 4.0, + 2.0, + 0.0, + 0.0, + 1.0, + 6.0, + ] + ), "feat_static_cat": np.array([0.0, 0.0], dtype=float), - "past_feat_dynamic_real": np.array([ - [1.0222e06 for i in range(14)], - [750.0 for i in range(14)], - ]), + "past_feat_dynamic_real": np.array( + [ + [1.0222e06 for i in range(14)], + [750.0 for i in range(14)], + ] + ), }, { "start": "2017-10-11", diff --git a/test/ext/statsforecast/test_statsforecast.py b/test/ext/statsforecast/test_statsforecast.py index 2bb90db539..3c09327b0a 100644 --- a/test/ext/statsforecast/test_statsforecast.py +++ b/test/ext/statsforecast/test_statsforecast.py @@ -123,12 +123,14 @@ def test_model_config( ) @pytest.mark.parametrize( "dataset", - [[ - dict( - start=pd.Period("2021-02-03 00", freq="H"), - target=np.random.normal(loc=10, scale=0.5, size=(100,)), - ) - ]], + [ + [ + dict( + start=pd.Period("2021-02-03 00", freq="H"), + target=np.random.normal(loc=10, scale=0.5, size=(100,)), + ) + ] + ], ) def test_predictor_working( predictor: StatsForecastPredictor, dataset: Dataset diff --git a/test/model/npts/test_npts.py b/test/model/npts/test_npts.py index 97faf4e0b7..a9b7183b58 100644 --- a/test/model/npts/test_npts.py +++ b/test/model/npts/test_npts.py @@ -102,10 +102,12 @@ def test_climatological_forecaster( kernel_type=KernelType.uniform, ) - dataset = [{ - "start": pd.Period(train_ts.index[0], freq=freq), - "target": train_ts.values, - }] + dataset = [ + { + "start": pd.Period(train_ts.index[0], freq=freq), + "target": train_ts.values, + } + ] # validate that the predictor works with targets with NaNs _test_nans_in_target(predictor, dataset) @@ -265,10 +267,12 @@ def test_npts_forecaster( use_seasonal_model=use_seasonal_model, ) - dataset = [{ - "start": pd.Period(train_ts.index[0], freq=freq), - "target": train_ts.values, - }] + dataset = [ + { + "start": pd.Period(train_ts.index[0], freq=freq), + "target": train_ts.values, + } + ] # validate that the predictor works with targets with NaNs _test_nans_in_target(predictor, dataset) @@ -413,9 +417,12 @@ def test_npts_custom_features( freq=train_ts.index.freq, ) # Dummy feature defining 52 seasons - feat_dynamic_real = [[ - (ix % 52) / 51.0 - 0.5 for ix, timestamp in enumerate(full_time_index) - ]] + feat_dynamic_real = [ + [ + (ix % 52) / 51.0 - 0.5 + for ix, timestamp in enumerate(full_time_index) + ] + ] predictor = NPTSPredictor( prediction_length=pred_length, @@ -426,11 +433,13 @@ def test_npts_custom_features( use_default_time_features=False, # disable default time features ) - dataset = [{ - "start": pd.Period(train_ts.index[0], freq=freq), - "target": train_ts.values, - "feat_dynamic_real": np.array(feat_dynamic_real), - }] + dataset = [ + { + "start": pd.Period(train_ts.index[0], freq=freq), + "target": train_ts.values, + "feat_dynamic_real": np.array(feat_dynamic_real), + } + ] # validate that the predictor works with targets with NaNs _test_nans_in_target(predictor, dataset) diff --git a/test/mx/block/test_scaler.py b/test/mx/block/test_scaler.py index a949bdab94..4edec3483d 100644 --- a/test/mx/block/test_scaler.py +++ b/test/mx/block/test_scaler.py @@ -20,118 +20,144 @@ test_cases = [ ( scaler.MeanScaler(), - mx.nd.array([ - [1.0] * 50, - [0.0] * 25 + [3.0] * 25, - [2.0] * 49 + [1.5] * 1, - [0.0] * 50, - [1.0] * 50, - ]), - mx.nd.array([ - [1.0] * 50, - [0.0] * 25 + [1.0] * 25, - [0.0] * 49 + [1.0] * 1, - [1.0] * 50, - [0.0] * 50, - ]), + mx.nd.array( + [ + [1.0] * 50, + [0.0] * 25 + [3.0] * 25, + [2.0] * 49 + [1.5] * 1, + [0.0] * 50, + [1.0] * 50, + ] + ), + mx.nd.array( + [ + [1.0] * 50, + [0.0] * 25 + [1.0] * 25, + [0.0] * 49 + [1.0] * 1, + [1.0] * 50, + [0.0] * 50, + ] + ), mx.nd.array([1.0, 3.0, 1.5, 1.00396824, 1.00396824]), ), ( scaler.MeanScaler(default_scale=0.5), - mx.nd.array([ - [1.0] * 50, - [0.0] * 25 + [3.0] * 25, - [2.0] * 49 + [1.5] * 1, - [0.0] * 50, - [1.0] * 50, - ]), - mx.nd.array([ - [0.0] * 50, - [0.0] * 25 + [1.0] * 25, - [0.0] * 49 + [1.0] * 1, - [1.0] * 50, - [0.0] * 50, - ]), + mx.nd.array( + [ + [1.0] * 50, + [0.0] * 25 + [3.0] * 25, + [2.0] * 49 + [1.5] * 1, + [0.0] * 50, + [1.0] * 50, + ] + ), + mx.nd.array( + [ + [0.0] * 50, + [0.0] * 25 + [1.0] * 25, + [0.0] * 49 + [1.0] * 1, + [1.0] * 50, + [0.0] * 50, + ] + ), mx.nd.array([0.5, 3.0, 1.5, 0.5, 0.5]), ), ( scaler.MeanScaler(keepdims=True), - mx.nd.array([ - [1.0] * 50, - [0.0] * 25 + [3.0] * 25, - [2.0] * 49 + [1.5] * 1, - [0.0] * 50, - [1.0] * 50, - ]), - mx.nd.array([ - [1.0] * 50, - [0.0] * 25 + [1.0] * 25, - [0.0] * 49 + [1.0] * 1, - [1.0] * 50, - [0.0] * 50, - ]), + mx.nd.array( + [ + [1.0] * 50, + [0.0] * 25 + [3.0] * 25, + [2.0] * 49 + [1.5] * 1, + [0.0] * 50, + [1.0] * 50, + ] + ), + mx.nd.array( + [ + [1.0] * 50, + [0.0] * 25 + [1.0] * 25, + [0.0] * 49 + [1.0] * 1, + [1.0] * 50, + [0.0] * 50, + ] + ), mx.nd.array([1.0, 3.0, 1.5, 1.00396824, 1.00396824]).expand_dims( axis=1 ), ), ( scaler.MeanScaler(), - mx.nd.array([ - [[1.0]] * 50, - [[0.0]] * 25 + [[3.0]] * 25, - [[2.0]] * 49 + [[1.5]] * 1, - [[0.0]] * 50, - [[1.0]] * 50, - ]), - mx.nd.array([ - [[1.0]] * 50, - [[0.0]] * 25 + [[1.0]] * 25, - [[0.0]] * 49 + [[1.0]] * 1, - [[1.0]] * 50, - [[0.0]] * 50, - ]), + mx.nd.array( + [ + [[1.0]] * 50, + [[0.0]] * 25 + [[3.0]] * 25, + [[2.0]] * 49 + [[1.5]] * 1, + [[0.0]] * 50, + [[1.0]] * 50, + ] + ), + mx.nd.array( + [ + [[1.0]] * 50, + [[0.0]] * 25 + [[1.0]] * 25, + [[0.0]] * 49 + [[1.0]] * 1, + [[1.0]] * 50, + [[0.0]] * 50, + ] + ), mx.nd.array([1.0, 3.0, 1.5, 1.00396824, 1.00396824]).expand_dims( axis=1 ), ), ( scaler.MeanScaler(minimum_scale=1e-8), - mx.nd.array([ - [[1.0, 2.0]] * 50, - [[0.0, 0.0]] * 25 + [[3.0, 6.0]] * 25, - [[2.0, 4.0]] * 49 + [[1.5, 3.0]] * 1, - [[0.0, 0.0]] * 50, - [[1.0, 2.0]] * 50, - ]), - mx.nd.array([ - [[1.0, 1.0]] * 50, - [[0.0, 1.0]] * 25 + [[1.0, 0.0]] * 25, - [[1.0, 0.0]] * 49 + [[0.0, 1.0]] * 1, - [[1.0, 0.0]] * 50, - [[0.0, 1.0]] * 50, - ]), - mx.nd.array([ - [1.0, 2.0], - [3.0, 1.61111116], - [2.0, 3.0], - [1.28160918, 1.61111116], - [1.28160918, 2.0], - ]), + mx.nd.array( + [ + [[1.0, 2.0]] * 50, + [[0.0, 0.0]] * 25 + [[3.0, 6.0]] * 25, + [[2.0, 4.0]] * 49 + [[1.5, 3.0]] * 1, + [[0.0, 0.0]] * 50, + [[1.0, 2.0]] * 50, + ] + ), + mx.nd.array( + [ + [[1.0, 1.0]] * 50, + [[0.0, 1.0]] * 25 + [[1.0, 0.0]] * 25, + [[1.0, 0.0]] * 49 + [[0.0, 1.0]] * 1, + [[1.0, 0.0]] * 50, + [[0.0, 1.0]] * 50, + ] + ), + mx.nd.array( + [ + [1.0, 2.0], + [3.0, 1.61111116], + [2.0, 3.0], + [1.28160918, 1.61111116], + [1.28160918, 2.0], + ] + ), ), ( scaler.MeanScaler(), - mx.nd.array([ - [120.0] * 25 + [150.0] * 25, - [0.0] * 10 + [3.0] * 20 + [61.0] * 20, - [0.0] * 50, - [2e-2] * 10 + [0.0] * 30 + [3e-2] * 10, - ]), - mx.nd.array([ - [1.0] * 25 + [1.0] * 25, - [0.0] * 10 + [1.0] * 20 + [1.0] * 20, - [0.0] * 50, - [1.0] * 10 + [0.0] * 30 + [1.0] * 10, - ]), + mx.nd.array( + [ + [120.0] * 25 + [150.0] * 25, + [0.0] * 10 + [3.0] * 20 + [61.0] * 20, + [0.0] * 50, + [2e-2] * 10 + [0.0] * 30 + [3e-2] * 10, + ] + ), + mx.nd.array( + [ + [1.0] * 25 + [1.0] * 25, + [0.0] * 10 + [1.0] * 20 + [1.0] * 20, + [0.0] * 50, + [1.0] * 10 + [0.0] * 30 + [1.0] * 10, + ] + ), mx.nd.array([135.0, 32.0, 73.00454712, 2.5e-2]), ), ( @@ -181,93 +207,129 @@ test_minmax = [ ( scaler.MinMax(), - mx.nd.array([ - [1.0, 2.0, 3.0], - [4.0, 5.0, 6.0], - ]), - mx.nd.array([ - [1.0, 1.0, 1.0], - [1.0, 1.0, 1.0], - ]), - mx.nd.array([ - [0.0, 0.5, 1.0], - [0.0, 0.5, 1.0], - ]), + mx.nd.array( + [ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + ] + ), + mx.nd.array( + [ + [1.0, 1.0, 1.0], + [1.0, 1.0, 1.0], + ] + ), + mx.nd.array( + [ + [0.0, 0.5, 1.0], + [0.0, 0.5, 1.0], + ] + ), ), ( scaler.MinMax(), - mx.nd.array([ - [1.0, 2.0, 3.0], - [4.0, 5.0, 6.0], - ]), - mx.nd.array([ - [0.0, 1.0, 1.0], - [1.0, 1.0, 0.0], - ]), - mx.nd.array([ - [0.0, 0, 1.0], - [0.0, 1.0, 0.0], - ]), + mx.nd.array( + [ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + ] + ), + mx.nd.array( + [ + [0.0, 1.0, 1.0], + [1.0, 1.0, 0.0], + ] + ), + mx.nd.array( + [ + [0.0, 0, 1.0], + [0.0, 1.0, 0.0], + ] + ), ), ( scaler.MinMax(), - mx.nd.array([ - [9.0, 9.0, 9.0], - [4.0, 5.0, 6.0], - ]), - mx.nd.array([ - [1.0, 1.0, 1.0], - [1.0, 1.0, 1.0], - ]), - mx.nd.array([ - [1.0, 1.0, 1.0], - [0.0, 0.5, 1.0], - ]), + mx.nd.array( + [ + [9.0, 9.0, 9.0], + [4.0, 5.0, 6.0], + ] + ), + mx.nd.array( + [ + [1.0, 1.0, 1.0], + [1.0, 1.0, 1.0], + ] + ), + mx.nd.array( + [ + [1.0, 1.0, 1.0], + [0.0, 0.5, 1.0], + ] + ), ), ( scaler.MinMax(), - mx.nd.array([ - [9.0, 9.0, 9.0], - [4.0, 5.0, 6.0], - ]), - mx.nd.array([ - [0.0, 1.0, 1.0], - [1.0, 1.0, 1.0], - ]), - mx.nd.array([ - [0.0, 1.0, 1.0], - [0.0, 0.5, 1.0], - ]), + mx.nd.array( + [ + [9.0, 9.0, 9.0], + [4.0, 5.0, 6.0], + ] + ), + mx.nd.array( + [ + [0.0, 1.0, 1.0], + [1.0, 1.0, 1.0], + ] + ), + mx.nd.array( + [ + [0.0, 1.0, 1.0], + [0.0, 0.5, 1.0], + ] + ), ), ( scaler.MinMax(), - mx.nd.array([ - [0.0, 0.0, 0.0], - [4.0, 5.0, 6.0], - ]), - mx.nd.array([ - [1.0, 1.0, 1.0], - [1.0, 1.0, 1.0], - ]), - mx.nd.array([ - [0.0, 0.0, 0.0], - [0.0, 0.5, 1.0], - ]), + mx.nd.array( + [ + [0.0, 0.0, 0.0], + [4.0, 5.0, 6.0], + ] + ), + mx.nd.array( + [ + [1.0, 1.0, 1.0], + [1.0, 1.0, 1.0], + ] + ), + mx.nd.array( + [ + [0.0, 0.0, 0.0], + [0.0, 0.5, 1.0], + ] + ), ), ( scaler.MinMax(axis=0), - mx.nd.array([ - [1.0, 2.0, 3.0], - [4.0, 5.0, 6.0], - ]), - mx.nd.array([ - [1.0, 1.0, 1.0], - [1.0, 1.0, 1.0], - ]), - mx.nd.array([ - [0.0, 0.0, 0.0], - [1.0, 1.0, 1.0], - ]), + mx.nd.array( + [ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + ] + ), + mx.nd.array( + [ + [1.0, 1.0, 1.0], + [1.0, 1.0, 1.0], + ] + ), + mx.nd.array( + [ + [0.0, 0.0, 0.0], + [1.0, 1.0, 1.0], + ] + ), ), ] diff --git a/test/mx/distribution/test_distribution_methods.py b/test/mx/distribution/test_distribution_methods.py index 9ada4f7bd6..9266fd0428 100644 --- a/test/mx/distribution/test_distribution_methods.py +++ b/test/mx/distribution/test_distribution_methods.py @@ -147,10 +147,12 @@ ( EmpiricalDistribution, { - "samples": mx.nd.stack(*[ - mx.nd.arange(start=0, stop=20, step=2), - mx.nd.arange(start=100, stop=0, step=-10), - ]).transpose(), + "samples": mx.nd.stack( + *[ + mx.nd.arange(start=0, stop=20, step=2), + mx.nd.arange(start=100, stop=0, step=-10), + ] + ).transpose(), "event_dim": 1, }, ), @@ -252,10 +254,12 @@ ( EmpiricalDistribution, { - "samples": mx.nd.stack(*[ - mx.nd.arange(start=0, stop=20, step=2), - mx.nd.arange(start=100, stop=0, step=-10), - ]).transpose(), + "samples": mx.nd.stack( + *[ + mx.nd.arange(start=0, stop=20, step=2), + mx.nd.arange(start=100, stop=0, step=-10), + ] + ).transpose(), "event_dim": 1, }, ), diff --git a/test/mx/distribution/test_distribution_output_shapes.py b/test/mx/distribution/test_distribution_output_shapes.py index c55783397c..7f8ffb857f 100644 --- a/test/mx/distribution/test_distribution_output_shapes.py +++ b/test/mx/distribution/test_distribution_output_shapes.py @@ -157,10 +157,12 @@ TEST_CASES_WITHOUT_VARIANCE = [ ( - MixtureDistributionOutput([ - MultivariateGaussianOutput(dim=5), - MultivariateGaussianOutput(dim=5), - ]), + MixtureDistributionOutput( + [ + MultivariateGaussianOutput(dim=5), + MultivariateGaussianOutput(dim=5), + ] + ), mx.nd.random.normal(shape=(3, 4, 10)), [None, mx.nd.ones(shape=(3, 4, 5))], [None], diff --git a/test/mx/distribution/test_distribution_sampling.py b/test/mx/distribution/test_distribution_sampling.py index e827221f7f..80b61882ce 100644 --- a/test/mx/distribution/test_distribution_sampling.py +++ b/test/mx/distribution/test_distribution_sampling.py @@ -233,18 +233,20 @@ def test_multivariate_sampling(distr, params, dim, serialize_fn) -> None: ) -test_cases_pwl_sqf = [( - PiecewiseLinear, - { - "gamma": mx.nd.array([2]).repeat(axis=0, repeats=2), - "slopes": mx.nd.array([[3, 1, 3, 0.2, 5, 4]]).repeat( - axis=0, repeats=2 - ), - "knot_spacings": mx.nd.array( - [[0.3, 0.2, 0.2, 0.15, 0.1, 0.05]] - ).repeat(axis=0, repeats=2), - }, -)] +test_cases_pwl_sqf = [ + ( + PiecewiseLinear, + { + "gamma": mx.nd.array([2]).repeat(axis=0, repeats=2), + "slopes": mx.nd.array([[3, 1, 3, 0.2, 5, 4]]).repeat( + axis=0, repeats=2 + ), + "knot_spacings": mx.nd.array( + [[0.3, 0.2, 0.2, 0.15, 0.1, 0.05]] + ).repeat(axis=0, repeats=2), + }, + ) +] @pytest.mark.parametrize("distr, params", test_cases_pwl_sqf) diff --git a/test/mx/distribution/test_nan_mixture.py b/test/mx/distribution/test_nan_mixture.py index b3bdec423c..e806d82b7b 100644 --- a/test/mx/distribution/test_nan_mixture.py +++ b/test/mx/distribution/test_nan_mixture.py @@ -63,10 +63,12 @@ def diff(x: np.ndarray, y: np.ndarray) -> np.ndarray: sigma_grad_true[p == 1] = 0 params_gauss_grad = {"mu": mu_grad_true, "sigma": sigma_grad_true} -p_cat = np.array([ - [[[0.1, 0.9], [0.9, 0.1], [0.5, 0.5]]], - [[[0.9, 0.1], [0.05, 0.95], [0.45, 0.55]]], -]) +p_cat = np.array( + [ + [[[0.1, 0.9], [0.9, 0.1], [0.5, 0.5]]], + [[[0.9, 0.1], [0.05, 0.95], [0.45, 0.55]]], + ] +) params_cat = {"log_probs": mx.nd.array(np.log(p_cat))} x_cat = np.array([[[np.nan, 0, 1]], [[np.nan, 0, np.nan]]]) diff --git a/test/mx/distribution/test_piecewise_linear.py b/test/mx/distribution/test_piecewise_linear.py index 406ffb5a77..ea2a44c1ca 100644 --- a/test/mx/distribution/test_piecewise_linear.py +++ b/test/mx/distribution/test_piecewise_linear.py @@ -67,12 +67,12 @@ def test_values( ): distr = serialize_fn(distr) target = mx.nd.array(target).reshape(shape=(len(target),)) - expected_target_cdf = np.array(expected_target_cdf).reshape(( - len(expected_target_cdf), - )) - expected_target_crps = np.array(expected_target_crps).reshape(( - len(expected_target_crps), - )) + expected_target_cdf = np.array(expected_target_cdf).reshape( + (len(expected_target_cdf),) + ) + expected_target_crps = np.array(expected_target_crps).reshape( + (len(expected_target_crps),) + ) assert all(np.isclose(distr.cdf(target).asnumpy(), expected_target_cdf)) assert all(np.isclose(distr.crps(target).asnumpy(), expected_target_crps)) diff --git a/test/mx/kernels/test_periodic_kernel.py b/test/mx/kernels/test_periodic_kernel.py index 329eb8230e..7424f927ed 100644 --- a/test/mx/kernels/test_periodic_kernel.py +++ b/test/mx/kernels/test_periodic_kernel.py @@ -56,11 +56,13 @@ nd.array([[0, 1, 3], [2, -1, 1], [1, 0, -1], [-1, -2, 3]]), nd.array([3, 2.1, 4.2]), nd.array([1.3, 2.5, 3.2]), - nd.array([ - [[14, 2, 2, 14], [57, 41, 19, 83]], - [[40, 56, 24, 72], [84, 116, 172, 12]], - [[22, 42, 26, 38], [217, 249, 299, 155]], - ]), + nd.array( + [ + [[14, 2, 2, 14], [57, 41, 19, 83]], + [[40, 56, 24, 72], [84, 116, 172, 12]], + [[22, 42, 26, 38], [217, 249, 299, 155]], + ] + ), ), ] diff --git a/test/mx/kernels/test_rbf_kernel.py b/test/mx/kernels/test_rbf_kernel.py index 3e13d79408..e60a905e20 100644 --- a/test/mx/kernels/test_rbf_kernel.py +++ b/test/mx/kernels/test_rbf_kernel.py @@ -55,11 +55,13 @@ nd.array([[0, 1, 3], [2, -1, 1], [1, 0, -1], [-1, -2, 3]]), nd.array([3, 2.1, 4.2]), nd.array([1.3, 2.5, 3.2]), - nd.array([ - [[14, 2, 2, 14], [57, 41, 19, 83]], - [[40, 56, 24, 72], [84, 116, 172, 12]], - [[22, 42, 26, 38], [217, 249, 299, 155]], - ]), + nd.array( + [ + [[14, 2, 2, 14], [57, 41, 19, 83]], + [[40, 56, 24, 72], [84, 116, 172, 12]], + [[22, 42, 26, 38], [217, 249, 299, 155]], + ] + ), ), ] diff --git a/test/mx/model/deepvar_hierarchical/generate_hierarchical_dataset.py b/test/mx/model/deepvar_hierarchical/generate_hierarchical_dataset.py index 6e66978d60..8e1737abb2 100644 --- a/test/mx/model/deepvar_hierarchical/generate_hierarchical_dataset.py +++ b/test/mx/model/deepvar_hierarchical/generate_hierarchical_dataset.py @@ -73,11 +73,13 @@ def sine7(seq_length: int, prediction_length: int): ) train_dataset = ListDataset( - [{ - "start": index[0], - "item_id": "all_items", - "target": Y[:, :-prediction_length], - }], + [ + { + "start": index[0], + "item_id": "all_items", + "target": Y[:, :-prediction_length], + } + ], freq=index.freqstr, one_dim_target=False, ) diff --git a/test/mx/model/deepvar_hierarchical/test_train_prediction_with_hts.py b/test/mx/model/deepvar_hierarchical/test_train_prediction_with_hts.py index ddec3e01ea..22fc735217 100644 --- a/test/mx/model/deepvar_hierarchical/test_train_prediction_with_hts.py +++ b/test/mx/model/deepvar_hierarchical/test_train_prediction_with_hts.py @@ -89,7 +89,9 @@ def test_train_prediction(features_df: Optional[pd.DataFrame]): forecasts = list(predictor.predict(predictor_input)) assert len(forecasts) == len(dataset) - assert all([ - forecast.samples.shape == (100, PREDICTION_LENGTH, hts.num_ts) - for forecast in forecasts - ]) + assert all( + [ + forecast.samples.shape == (100, PREDICTION_LENGTH, hts.num_ts) + for forecast in forecasts + ] + ) diff --git a/test/mx/model/gp_forecaster/data.py b/test/mx/model/gp_forecaster/data.py index 3f1450fe8b..3b78ce1ae7 100644 --- a/test/mx/model/gp_forecaster/data.py +++ b/test/mx/model/gp_forecaster/data.py @@ -15,5471 +15,5481 @@ def load_gp_params(): - return nd.array([ - [412263.3050, 8.0703, 57.3620], - [387.5274, 5.0673, 4.2793], - [41625.4972, 6.7450, 9.5796], - [2639.2794, 3.6458, 5.1566], - [4423.1468, 6.0896, 5.1452], - [19065.0601, 4.2969, 13.6918], - [28449.8824, 4.7820, 10.6377], - [300837.0179, 6.8132, 41.2674], - [1374.7983, 4.3378, 4.0292], - [5304.0383, 6.1224, 1.1865], - ]).expand_dims(axis=2) + return nd.array( + [ + [412263.3050, 8.0703, 57.3620], + [387.5274, 5.0673, 4.2793], + [41625.4972, 6.7450, 9.5796], + [2639.2794, 3.6458, 5.1566], + [4423.1468, 6.0896, 5.1452], + [19065.0601, 4.2969, 13.6918], + [28449.8824, 4.7820, 10.6377], + [300837.0179, 6.8132, 41.2674], + [1374.7983, 4.3378, 4.0292], + [5304.0383, 6.1224, 1.1865], + ] + ).expand_dims(axis=2) def load_exact_mean(): - return nd.array([ - [ - 329.91, - 318.8, - 326.58, - 352.84, - 395.77, - 452.58, - 519.87, - 594.14, - 672.03, - 750.42, - 826.4, - 897.05, - 959.22, - 1009.3, - 1043.4, - 1057.1, - 1046.3, - 1007.4, - 938.3, - 838.59, - 710.16, - 557.29, - 386.44, - 370.18, - 341.38, - 330.33, - 338.16, - 364.32, - 406.87, - 462.85, - 528.8, - 601.17, - 676.64, - 752.21, - 825.11, - 892.63, - 951.82, - 999.28, - 1031.1, - 1043.2, - 1031.2, - 991.72, - 922.39, - 822.75, - 694.55, - 541.99, - 371.45, - 376.02, - 347.33, - 336.44, - 344.41, - 370.61, - 412.98, - 468.49, - 533.61, - 604.78, - 678.67, - 752.35, - 823.15, - 888.47, - 945.49, - 990.94, - 1021.1, - 1031.7, - 1018.8, - 978.69, - 909.16, - 809.66, - 681.91, - 529.99, - 360.27, - 377.02, - 348.76, - 338.32, - 346.72, - 373.3, - 415.97, - 471.67, - 536.82, - 607.84, - 681.37, - 754.49, - 824.52, - 888.89, - 944.83, - 989.13, - 1018.1, - 1027.7, - 1013.8, - 973.03, - 903.11, - 803.53, - 676.04, - 524.73, - 355.89, - 374.28, - 346.91, - 337.34, - 346.63, - 374.14, - 417.81, - 474.57, - 540.85, - 613.01, - 687.65, - 761.74, - 832.55, - 897.41, - 953.51, - 997.61, - 1026, - 1034.8, - 1020.1, - 978.36, - 907.64, - 807.53, - 679.87, - 528.79, - 360.61, - 368.96, - 343.04, - 334.91, - 345.67, - 374.79, - 420.26, - 479.07, - 547.66, - 622.33, - 699.59, - 776.26, - 849.41, - 916.22, - 973.7, - 1018.5, - 1047.1, - 1055.3, - 1039.6, - 996.64, - 924.63, - 823.41, - 694.98, - 543.61, - 375.65, - 362.13, - 338.31, - 332.23, - 345.12, - 376.54, - 424.62, - 486.44, - 558.45, - 636.92, - 718.19, - 798.86, - 875.77, - 945.81, - 1005.8, - 1052.2, - 1081.3, - 1089.2, - 1072.4, - 1027.9, - 954.1, - 851.17, - 721.38, - 569.17, - 400.96, - 358.65, - ], - [ - 26.239, - 26.612, - 26.31, - 25.441, - 24.318, - 23.394, - 23.112, - 23.72, - 25.148, - 27.001, - 28.699, - 29.716, - 29.81, - 29.155, - 28.283, - 27.875, - 28.464, - 30.184, - 32.67, - 35.15, - 36.697, - 36.53, - 34.26, - 26.079, - 26.356, - 26.123, - 25.377, - 24.269, - 23.133, - 22.412, - 22.513, - 23.625, - 25.603, - 27.972, - 30.08, - 31.355, - 31.549, - 30.869, - 29.916, - 29.444, - 30.048, - 31.882, - 34.561, - 37.256, - 38.964, - 38.84, - 36.463, - 27.34, - 27.022, - 26.272, - 25.145, - 23.828, - 22.661, - 22.071, - 22.427, - 23.876, - 26.22, - 28.935, - 31.326, - 32.793, - 33.079, - 32.405, - 31.404, - 30.873, - 31.438, - 33.276, - 35.992, - 38.735, - 40.468, - 40.314, - 37.834, - 28.936, - 28.179, - 27.044, - 25.629, - 24.14, - 22.91, - 22.342, - 22.777, - 24.33, - 26.785, - 29.604, - 32.084, - 33.617, - 33.942, - 33.274, - 32.242, - 31.649, - 32.13, - 33.871, - 36.492, - 39.148, - 40.809, - 40.599, - 38.075, - 30.565, - 29.545, - 28.174, - 26.574, - 24.953, - 23.624, - 22.961, - 23.278, - 24.681, - 26.963, - 29.608, - 31.943, - 33.377, - 33.648, - 32.951, - 31.886, - 31.224, - 31.584, - 33.16, - 35.596, - 38.088, - 39.641, - 39.41, - 36.961, - 31.817, - 30.693, - 29.225, - 27.539, - 25.826, - 24.371, - 23.51, - 23.538, - 24.562, - 26.408, - 28.617, - 30.572, - 31.727, - 31.828, - 31.044, - 29.925, - 29.185, - 29.41, - 30.789, - 33.006, - 35.314, - 36.788, - 36.628, - 34.415, - 32.283, - 31.18, - 29.725, - 28.034, - 26.271, - 24.68, - 23.561, - 23.184, - 23.667, - 24.883, - 26.449, - 27.832, - 28.549, - 28.369, - 27.435, - 26.235, - 25.413, - 25.505, - 26.693, - 28.702, - 30.862, - 32.327, - 32.358, - 30.553, - 25.266, - ], - [ - 128.95, - 122.39, - 120.64, - 123.55, - 130.36, - 139.91, - 150.88, - 162.03, - 172.42, - 181.48, - 189.02, - 195.16, - 200.19, - 204.44, - 208.13, - 211.28, - 213.65, - 214.75, - 213.92, - 210.41, - 203.54, - 192.86, - 178.23, - 163.31, - 155.31, - 151.08, - 150.99, - 154.77, - 161.58, - 170.28, - 179.68, - 188.74, - 196.76, - 203.4, - 208.68, - 212.85, - 216.27, - 219.23, - 221.88, - 224.12, - 225.6, - 225.74, - 223.82, - 219.07, - 210.87, - 198.8, - 182.79, - 174.14, - 167.52, - 164.42, - 165.1, - 169.16, - 175.75, - 183.73, - 191.99, - 199.65, - 206.14, - 211.3, - 215.28, - 218.42, - 221.08, - 223.54, - 225.85, - 227.84, - 229.03, - 228.78, - 226.33, - 220.92, - 211.96, - 199.07, - 182.29, - 173.98, - 167.47, - 164.48, - 165.21, - 169.22, - 175.62, - 183.29, - 191.13, - 198.31, - 204.34, - 209.12, - 212.84, - 215.87, - 218.56, - 221.15, - 223.64, - 225.77, - 227.02, - 226.7, - 224.03, - 218.3, - 208.92, - 195.62, - 178.46, - 166, - 158.67, - 155.08, - 155.45, - 159.35, - 165.87, - 173.86, - 182.19, - 189.98, - 196.7, - 202.2, - 206.66, - 210.38, - 213.7, - 216.82, - 219.71, - 222.07, - 223.4, - 223, - 220.13, - 214.1, - 204.4, - 190.8, - 173.44, - 154.11, - 145.46, - 140.96, - 140.92, - 144.95, - 152.12, - 161.23, - 171.03, - 180.51, - 189, - 196.23, - 202.23, - 207.24, - 211.54, - 215.29, - 218.49, - 220.88, - 221.98, - 221.18, - 217.81, - 211.23, - 201.01, - 186.99, - 169.34, - 142.18, - 132.2, - 126.9, - 126.73, - 131.35, - 139.83, - 150.86, - 163.06, - 175.18, - 186.35, - 196.06, - 204.16, - 210.78, - 216.11, - 220.35, - 223.51, - 225.44, - 225.78, - 224.02, - 219.59, - 211.98, - 200.83, - 186.04, - 167.83, - 139.73, - ], - [ - 70.415, - 63.185, - 60.465, - 60.591, - 60.143, - 57.038, - 52.504, - 50.145, - 52.795, - 59.874, - 67.577, - 71.705, - 70.705, - 66.507, - 62.829, - 62.672, - 66.912, - 74.551, - 83.744, - 92.514, - 98.891, - 100.95, - 97.216, - 79.011, - 69.417, - 63.16, - 61.764, - 63.012, - 62.935, - 59.324, - 53.838, - 50.763, - 53.311, - 60.684, - 68.47, - 72.03, - 69.915, - 64.583, - 60.254, - 60.045, - 64.612, - 72.707, - 82.445, - 91.957, - 99.291, - 102.29, - 99.127, - 78.038, - 68.621, - 62.845, - 62.133, - 63.938, - 63.956, - 59.934, - 53.875, - 50.547, - 53.377, - 61.309, - 69.423, - 72.78, - 70.092, - 64.277, - 59.88, - 59.941, - 64.754, - 72.791, - 82.217, - 91.458, - 98.799, - 102.05, - 99.099, - 79.084, - 69.517, - 63.545, - 62.667, - 64.294, - 64.063, - 59.786, - 53.71, - 50.862, - 54.69, - 63.816, - 72.863, - 76.62, - 73.907, - 67.976, - 63.596, - 63.71, - 68.311, - 75.671, - 84.059, - 92.247, - 98.83, - 101.68, - 98.532, - 82.217, - 72.301, - 65.592, - 63.804, - 64.591, - 63.807, - 59.46, - 53.943, - 52.302, - 57.788, - 68.623, - 78.974, - 83.346, - 80.57, - 74.14, - 69.092, - 68.491, - 72.306, - 78.769, - 86.214, - 93.562, - 99.537, - 102.01, - 98.59, - 85.922, - 75.755, - 68.142, - 65.06, - 64.61, - 63.112, - 58.883, - 54.354, - 54.334, - 61.656, - 74.067, - 85.316, - 89.653, - 85.894, - 77.776, - 70.782, - 68.475, - 71.193, - 77.305, - 85.039, - 93.058, - 99.781, - 102.83, - 99.672, - 87.851, - 77.836, - 69.584, - 65.269, - 63.539, - 61.355, - 57.414, - 54.058, - 55.603, - 64.259, - 77.282, - 88.15, - 91.02, - 84.779, - 73.49, - 63.277, - 58.54, - 60.333, - 67.279, - 77.276, - 88.242, - 97.786, - 102.97, - 101.07, - 80.395, - ], - [ - 41.641, - 40.194, - 40.186, - 42.23, - 46.594, - 53.121, - 61.229, - 70.012, - 78.439, - 85.592, - 90.892, - 94.251, - 96.067, - 97.076, - 98.062, - 99.52, - 101.37, - 102.81, - 102.42, - 98.456, - 89.282, - 73.864, - 52.127, - 46.345, - 44.621, - 43.544, - 43.912, - 46.325, - 51.036, - 57.864, - 66.199, - 75.106, - 83.524, - 90.514, - 95.492, - 98.379, - 99.607, - 99.965, - 100.3, - 101.19, - 102.58, - 103.72, - 103.16, - 99.119, - 89.886, - 74.352, - 52.377, - 47.029, - 45.418, - 44.464, - 44.971, - 47.544, - 52.44, - 59.48, - 68.047, - 77.196, - 85.846, - 93.031, - 98.137, - 101.06, - 102.23, - 102.43, - 102.55, - 103.18, - 104.36, - 105.36, - 104.77, - 100.79, - 91.703, - 76.343, - 54.515, - 46.417, - 44.725, - 43.69, - 44.133, - 46.678, - 51.605, - 58.757, - 67.538, - 77.004, - 86.064, - 93.719, - 99.307, - 102.67, - 104.17, - 104.59, - 104.77, - 105.37, - 106.44, - 107.34, - 106.73, - 102.87, - 94.051, - 79.104, - 57.779, - 45.332, - 43.46, - 42.227, - 42.471, - 44.847, - 49.673, - 56.832, - 65.762, - 75.54, - 85.07, - 93.32, - 99.567, - 103.57, - 105.62, - 106.39, - 106.74, - 107.3, - 108.21, - 108.9, - 108.17, - 104.36, - 95.857, - 81.497, - 60.997, - 44.582, - 42.54, - 41.092, - 41.092, - 43.223, - 47.848, - 54.899, - 63.862, - 73.848, - 83.763, - 92.548, - 99.412, - 104.02, - 106.56, - 107.63, - 108.02, - 108.36, - 108.88, - 109.12, - 108.03, - 104.13, - 95.9, - 82.242, - 62.844, - 44.76, - 42.661, - 41.081, - 40.875, - 42.75, - 47.106, - 53.93, - 62.761, - 72.751, - 82.824, - 91.901, - 99.134, - 104.1, - 106.86, - 107.94, - 108.05, - 107.85, - 107.62, - 107.08, - 105.35, - 101.14, - 93.104, - 80.215, - 62.151, - 43.724, - ], - [ - 118.98, - 118.77, - 114.81, - 110.94, - 111.76, - 120.81, - 139.03, - 164.13, - 191.14, - 214.26, - 229.24, - 235.31, - 235.48, - 234.91, - 237.59, - 243.43, - 247.56, - 242.8, - 224.16, - 192.83, - 156.88, - 128.03, - 116.11, - 114.56, - 120.47, - 121.48, - 119.8, - 119.4, - 124.67, - 138.61, - 161.35, - 189.66, - 217.79, - 239.51, - 250.63, - 250.97, - 244.63, - 238.01, - 236.21, - 239.67, - 243.29, - 239.03, - 220.99, - 189.83, - 153.87, - 125.62, - 115.76, - 113.73, - 119.34, - 120.43, - 119.35, - 120.07, - 126.81, - 142.28, - 166.27, - 195.25, - 223.31, - 244.16, - 253.68, - 251.93, - 243.39, - 234.99, - 232.27, - 235.87, - 240.45, - 237.39, - 220.08, - 188.86, - 152.21, - 123.36, - 113.8, - 112.5, - 117.39, - 117.83, - 116.14, - 116.19, - 122.09, - 136.48, - 159.24, - 187.11, - 214.47, - 235.32, - 245.61, - 245.23, - 238.46, - 231.97, - 231.22, - 236.76, - 243.08, - 241.27, - 224.36, - 192.43, - 154.12, - 123.17, - 111.77, - 111.84, - 116.07, - 115.81, - 113.28, - 112.16, - 116.42, - 128.77, - 149.41, - 175.62, - 202.44, - 224.24, - 237.01, - 240.27, - 237.61, - 235.01, - 237.45, - 245.27, - 252.98, - 251.67, - 234.34, - 200.95, - 160.12, - 125.78, - 110.61, - 111.8, - 115.85, - 115.53, - 112.91, - 111.46, - 114.95, - 126.08, - 145.37, - 170.67, - 197.62, - 221.06, - 236.98, - 244.41, - 246.13, - 247.26, - 252.18, - 261.08, - 268.7, - 266.51, - 247.75, - 212.36, - 168.66, - 130.37, - 110.27, - 111.49, - 115.98, - 116.58, - 115.25, - 115.16, - 119.79, - 131.68, - 151.41, - 177.19, - 205.1, - 230.34, - 248.92, - 259.5, - 264.14, - 267.18, - 272.48, - 280.25, - 285.71, - 281.08, - 260.09, - 222.65, - 176.54, - 134.76, - 109.72, - 113.52, - ], - [ - 193.76, - 175.38, - 162.27, - 160.36, - 172.25, - 195.97, - 225.74, - 254.47, - 276.83, - 291.1, - 299.03, - 303.88, - 308.06, - 311.91, - 314.25, - 314.13, - 312.28, - 310.93, - 311.97, - 314.65, - 314.29, - 303.4, - 274.7, - 224.26, - 207.31, - 188.71, - 175.68, - 174.45, - 187.74, - 213.48, - 245.57, - 276.56, - 300.69, - 315.98, - 324.06, - 328.17, - 330.76, - 332.29, - 331.78, - 328.62, - 323.95, - 320.4, - 320.06, - 321.98, - 320.97, - 308.8, - 277.58, - 228.89, - 211.58, - 193.09, - 180.66, - 180.38, - 194.76, - 221.52, - 254.41, - 285.92, - 310.33, - 325.71, - 333.69, - 337.44, - 339.35, - 339.79, - 337.84, - 333.07, - 326.89, - 322.18, - 321.12, - 322.62, - 321.14, - 308.02, - 275.04, - 226.62, - 208.9, - 190.91, - 179.59, - 180.54, - 195.75, - 222.62, - 254.92, - 285.45, - 308.91, - 323.71, - 331.58, - 335.58, - 337.83, - 338.47, - 336.46, - 331.43, - 324.91, - 319.91, - 318.63, - 319.91, - 318.05, - 304.27, - 270.31, - 220.88, - 202.94, - 185.9, - 176.17, - 178.56, - 194.36, - 220.62, - 251.31, - 279.83, - 301.61, - 315.58, - 323.64, - 328.59, - 332.13, - 334.02, - 333.01, - 328.67, - 322.59, - 317.79, - 316.49, - 317.49, - 315.17, - 300.91, - 266.66, - 215.01, - 197.33, - 181.72, - 173.93, - 177.86, - 194.01, - 219.16, - 247.65, - 273.68, - 293.62, - 306.99, - 315.73, - 322.31, - 327.86, - 331.65, - 332.18, - 328.96, - 323.58, - 319.06, - 317.57, - 318.01, - 314.99, - 300.29, - 266.31, - 211.08, - 194.33, - 180.62, - 174.97, - 180.35, - 196.57, - 220.29, - 246.36, - 269.94, - 288.33, - 301.58, - 311.6, - 320.31, - 328.22, - 334.09, - 336.21, - 334.05, - 329.23, - 324.7, - 322.64, - 322.07, - 318.01, - 302.85, - 269.56, - 210.47, - ], + return nd.array( [ - 379.51, - 379.84, - 395.91, - 427.92, - 473.69, - 529.07, - 588.7, - 647.08, - 699.6, - 743.3, - 777.29, - 802.52, - 821.09, - 835.21, - 846.12, - 853.17, - 853.52, - 842.42, - 814.09, - 763.04, - 685.5, - 580.62, - 451.07, - 410.65, - 399.71, - 402.4, - 421.14, - 455.94, - 504.31, - 561.69, - 622.3, - 680.28, - 730.75, - 770.76, - 799.62, - 818.74, - 830.83, - 838.79, - 844.48, - 847.73, - 845.89, - 834.15, - 806.4, - 756.67, - 680.62, - 576.86, - 447.69, - 414.98, - 404.12, - 407.11, - 426.47, - 462.19, - 511.7, - 570.26, - 631.84, - 690.32, - 740.61, - 779.61, - 806.68, - 823.39, - 832.79, - 838.15, - 841.71, - 843.6, - 841.34, - 830.09, - 803.55, - 755.43, - 681.02, - 578.58, - 450.14, - 410.03, - 397.89, - 399.64, - 418.02, - 453.23, - 502.83, - 562.11, - 624.99, - 685.17, - 737.35, - 778.23, - 806.94, - 824.96, - 835.31, - 841.33, - 845.37, - 847.71, - 845.99, - 835.41, - 809.68, - 762.41, - 688.78, - 586.93, - 458.78, - 401.4, - 387.35, - 387.06, - 403.58, - 437.47, - 486.6, - 546.53, - 611.24, - 674.4, - 730.49, - 775.86, - 809.27, - 831.79, - 846.08, - 855.24, - 861.51, - 865.15, - 863.91, - 853.15, - 826.78, - 778.64, - 704.06, - 601.31, - 472.38, - 394.53, - 378.79, - 376.46, - 391, - 423.41, - 472.04, - 532.82, - 600, - 667.24, - 728.84, - 780.73, - 821.08, - 850.34, - 870.52, - 884.23, - 893.36, - 898.07, - 896.23, - 883.52, - 854.27, - 802.83, - 725, - 619.43, - 488.37, - 393.66, - 377.18, - 373.6, - 386.72, - 418.08, - 466.56, - 528.58, - 598.73, - 670.77, - 738.81, - 798.32, - 846.78, - 883.81, - 910.59, - 929.03, - 940.55, - 945.15, - 940.84, - 923.78, - 888.98, - 831.54, - 748.02, - 637.69, - 503.18, - 392.42, - ], - [ - 41.737, - 33.076, - 24.388, - 16.425, - 10.052, - 6.4268, - 6.8474, - 12.206, - 22.299, - 35.43, - 48.68, - 58.883, - 63.898, - 63.561, - 59.771, - 55.585, - 53.758, - 55.447, - 59.744, - 64.226, - 66.185, - 63.872, - 57.142, - 51.474, - 43.482, - 34.651, - 25.854, - 17.828, - 11.468, - 8.0229, - 8.9065, - 15.065, - 26.203, - 40.365, - 54.275, - 64.445, - 68.619, - 66.805, - 61.329, - 55.764, - 53.26, - 55.074, - 60.078, - 65.417, - 67.947, - 65.656, - 58.396, - 51.897, - 43.901, - 35.098, - 26.31, - 18.242, - 11.813, - 8.3517, - 9.3741, - 15.887, - 27.558, - 42.285, - 56.579, - 66.77, - 70.543, - 68.02, - 61.774, - 55.659, - 53.001, - 55.044, - 60.453, - 66.077, - 68.523, - 65.69, - 57.541, - 51.467, - 43.599, - 34.896, - 26.14, - 18.03, - 11.515, - 7.9791, - 8.9882, - 15.567, - 27.362, - 42.223, - 56.617, - 66.867, - 70.691, - 68.278, - 62.263, - 56.511, - 54.28, - 56.674, - 62.18, - 67.505, - 69.216, - 65.307, - 55.944, - 50.993, - 43.278, - 34.645, - 25.895, - 17.776, - 11.289, - 7.8312, - 8.9214, - 15.506, - 27.173, - 41.794, - 55.957, - 66.162, - 70.288, - 68.58, - 63.578, - 58.925, - 57.59, - 60.424, - 65.773, - 70.355, - 70.884, - 65.589, - 54.893, - 51.015, - 43.422, - 34.798, - 26.04, - 18.023, - 11.818, - 8.7764, - 10.249, - 16.969, - 28.395, - 42.451, - 55.976, - 65.827, - 70.181, - 69.366, - 65.74, - 62.552, - 62.332, - 65.609, - 70.634, - 74.263, - 73.492, - 66.863, - 55.035, - 51.552, - 44.063, - 35.433, - 26.729, - 19.01, - 13.424, - 11.212, - 13.437, - 20.492, - 31.624, - 44.814, - 57.244, - 66.265, - 70.467, - 70.32, - 67.965, - 66.186, - 67.005, - 70.615, - 75.233, - 77.94, - 76.092, - 68.545, - 56.119, - 49.486, - ], - [ - 75.287, - 73.758, - 72.659, - 72.035, - 71.962, - 72.496, - 73.613, - 75.186, - 76.995, - 78.785, - 80.34, - 81.53, - 82.325, - 82.764, - 82.909, - 82.804, - 82.474, - 81.942, - 81.26, - 80.525, - 79.851, - 79.319, - 78.904, - 81.595, - 81.221, - 81.035, - 80.886, - 80.722, - 80.593, - 80.601, - 80.828, - 81.284, - 81.896, - 82.541, - 83.098, - 83.49, - 83.696, - 83.727, - 83.595, - 83.288, - 82.784, - 82.081, - 81.234, - 80.368, - 79.638, - 79.167, - 78.954, - 81.439, - 81.784, - 82.216, - 82.507, - 82.564, - 82.436, - 82.264, - 82.193, - 82.307, - 82.599, - 82.99, - 83.378, - 83.682, - 83.857, - 83.881, - 83.732, - 83.374, - 82.773, - 81.934, - 80.937, - 79.941, - 79.139, - 78.685, - 78.595, - 78.829, - 79.15, - 79.549, - 79.794, - 79.799, - 79.638, - 79.482, - 79.503, - 79.802, - 80.367, - 81.098, - 81.857, - 82.52, - 83.002, - 83.256, - 83.251, - 82.957, - 82.36, - 81.493, - 80.463, - 79.451, - 78.668, - 78.274, - 78.283, - 76.309, - 76.127, - 76.043, - 75.873, - 75.58, - 75.277, - 75.155, - 75.391, - 76.059, - 77.103, - 78.362, - 79.636, - 80.745, - 81.574, - 82.07, - 82.22, - 82.031, - 81.527, - 80.773, - 79.89, - 79.055, - 78.46, - 78.234, - 78.365, - 76.287, - 75.462, - 74.722, - 73.96, - 73.203, - 72.608, - 72.384, - 72.692, - 73.569, - 74.899, - 76.456, - 77.981, - 79.261, - 80.174, - 80.69, - 80.844, - 80.699, - 80.328, - 79.815, - 79.271, - 78.83, - 78.618, - 78.695, - 78.994, - 80.444, - 79.175, - 77.895, - 76.574, - 75.301, - 74.273, - 73.708, - 73.754, - 74.41, - 75.518, - 76.813, - 78.012, - 78.905, - 79.401, - 79.532, - 79.402, - 79.144, - 78.872, - 78.674, - 78.616, - 78.746, - 79.081, - 79.576, - 80.082, - 77.239, - ], - ]) + [ + 329.91, + 318.8, + 326.58, + 352.84, + 395.77, + 452.58, + 519.87, + 594.14, + 672.03, + 750.42, + 826.4, + 897.05, + 959.22, + 1009.3, + 1043.4, + 1057.1, + 1046.3, + 1007.4, + 938.3, + 838.59, + 710.16, + 557.29, + 386.44, + 370.18, + 341.38, + 330.33, + 338.16, + 364.32, + 406.87, + 462.85, + 528.8, + 601.17, + 676.64, + 752.21, + 825.11, + 892.63, + 951.82, + 999.28, + 1031.1, + 1043.2, + 1031.2, + 991.72, + 922.39, + 822.75, + 694.55, + 541.99, + 371.45, + 376.02, + 347.33, + 336.44, + 344.41, + 370.61, + 412.98, + 468.49, + 533.61, + 604.78, + 678.67, + 752.35, + 823.15, + 888.47, + 945.49, + 990.94, + 1021.1, + 1031.7, + 1018.8, + 978.69, + 909.16, + 809.66, + 681.91, + 529.99, + 360.27, + 377.02, + 348.76, + 338.32, + 346.72, + 373.3, + 415.97, + 471.67, + 536.82, + 607.84, + 681.37, + 754.49, + 824.52, + 888.89, + 944.83, + 989.13, + 1018.1, + 1027.7, + 1013.8, + 973.03, + 903.11, + 803.53, + 676.04, + 524.73, + 355.89, + 374.28, + 346.91, + 337.34, + 346.63, + 374.14, + 417.81, + 474.57, + 540.85, + 613.01, + 687.65, + 761.74, + 832.55, + 897.41, + 953.51, + 997.61, + 1026, + 1034.8, + 1020.1, + 978.36, + 907.64, + 807.53, + 679.87, + 528.79, + 360.61, + 368.96, + 343.04, + 334.91, + 345.67, + 374.79, + 420.26, + 479.07, + 547.66, + 622.33, + 699.59, + 776.26, + 849.41, + 916.22, + 973.7, + 1018.5, + 1047.1, + 1055.3, + 1039.6, + 996.64, + 924.63, + 823.41, + 694.98, + 543.61, + 375.65, + 362.13, + 338.31, + 332.23, + 345.12, + 376.54, + 424.62, + 486.44, + 558.45, + 636.92, + 718.19, + 798.86, + 875.77, + 945.81, + 1005.8, + 1052.2, + 1081.3, + 1089.2, + 1072.4, + 1027.9, + 954.1, + 851.17, + 721.38, + 569.17, + 400.96, + 358.65, + ], + [ + 26.239, + 26.612, + 26.31, + 25.441, + 24.318, + 23.394, + 23.112, + 23.72, + 25.148, + 27.001, + 28.699, + 29.716, + 29.81, + 29.155, + 28.283, + 27.875, + 28.464, + 30.184, + 32.67, + 35.15, + 36.697, + 36.53, + 34.26, + 26.079, + 26.356, + 26.123, + 25.377, + 24.269, + 23.133, + 22.412, + 22.513, + 23.625, + 25.603, + 27.972, + 30.08, + 31.355, + 31.549, + 30.869, + 29.916, + 29.444, + 30.048, + 31.882, + 34.561, + 37.256, + 38.964, + 38.84, + 36.463, + 27.34, + 27.022, + 26.272, + 25.145, + 23.828, + 22.661, + 22.071, + 22.427, + 23.876, + 26.22, + 28.935, + 31.326, + 32.793, + 33.079, + 32.405, + 31.404, + 30.873, + 31.438, + 33.276, + 35.992, + 38.735, + 40.468, + 40.314, + 37.834, + 28.936, + 28.179, + 27.044, + 25.629, + 24.14, + 22.91, + 22.342, + 22.777, + 24.33, + 26.785, + 29.604, + 32.084, + 33.617, + 33.942, + 33.274, + 32.242, + 31.649, + 32.13, + 33.871, + 36.492, + 39.148, + 40.809, + 40.599, + 38.075, + 30.565, + 29.545, + 28.174, + 26.574, + 24.953, + 23.624, + 22.961, + 23.278, + 24.681, + 26.963, + 29.608, + 31.943, + 33.377, + 33.648, + 32.951, + 31.886, + 31.224, + 31.584, + 33.16, + 35.596, + 38.088, + 39.641, + 39.41, + 36.961, + 31.817, + 30.693, + 29.225, + 27.539, + 25.826, + 24.371, + 23.51, + 23.538, + 24.562, + 26.408, + 28.617, + 30.572, + 31.727, + 31.828, + 31.044, + 29.925, + 29.185, + 29.41, + 30.789, + 33.006, + 35.314, + 36.788, + 36.628, + 34.415, + 32.283, + 31.18, + 29.725, + 28.034, + 26.271, + 24.68, + 23.561, + 23.184, + 23.667, + 24.883, + 26.449, + 27.832, + 28.549, + 28.369, + 27.435, + 26.235, + 25.413, + 25.505, + 26.693, + 28.702, + 30.862, + 32.327, + 32.358, + 30.553, + 25.266, + ], + [ + 128.95, + 122.39, + 120.64, + 123.55, + 130.36, + 139.91, + 150.88, + 162.03, + 172.42, + 181.48, + 189.02, + 195.16, + 200.19, + 204.44, + 208.13, + 211.28, + 213.65, + 214.75, + 213.92, + 210.41, + 203.54, + 192.86, + 178.23, + 163.31, + 155.31, + 151.08, + 150.99, + 154.77, + 161.58, + 170.28, + 179.68, + 188.74, + 196.76, + 203.4, + 208.68, + 212.85, + 216.27, + 219.23, + 221.88, + 224.12, + 225.6, + 225.74, + 223.82, + 219.07, + 210.87, + 198.8, + 182.79, + 174.14, + 167.52, + 164.42, + 165.1, + 169.16, + 175.75, + 183.73, + 191.99, + 199.65, + 206.14, + 211.3, + 215.28, + 218.42, + 221.08, + 223.54, + 225.85, + 227.84, + 229.03, + 228.78, + 226.33, + 220.92, + 211.96, + 199.07, + 182.29, + 173.98, + 167.47, + 164.48, + 165.21, + 169.22, + 175.62, + 183.29, + 191.13, + 198.31, + 204.34, + 209.12, + 212.84, + 215.87, + 218.56, + 221.15, + 223.64, + 225.77, + 227.02, + 226.7, + 224.03, + 218.3, + 208.92, + 195.62, + 178.46, + 166, + 158.67, + 155.08, + 155.45, + 159.35, + 165.87, + 173.86, + 182.19, + 189.98, + 196.7, + 202.2, + 206.66, + 210.38, + 213.7, + 216.82, + 219.71, + 222.07, + 223.4, + 223, + 220.13, + 214.1, + 204.4, + 190.8, + 173.44, + 154.11, + 145.46, + 140.96, + 140.92, + 144.95, + 152.12, + 161.23, + 171.03, + 180.51, + 189, + 196.23, + 202.23, + 207.24, + 211.54, + 215.29, + 218.49, + 220.88, + 221.98, + 221.18, + 217.81, + 211.23, + 201.01, + 186.99, + 169.34, + 142.18, + 132.2, + 126.9, + 126.73, + 131.35, + 139.83, + 150.86, + 163.06, + 175.18, + 186.35, + 196.06, + 204.16, + 210.78, + 216.11, + 220.35, + 223.51, + 225.44, + 225.78, + 224.02, + 219.59, + 211.98, + 200.83, + 186.04, + 167.83, + 139.73, + ], + [ + 70.415, + 63.185, + 60.465, + 60.591, + 60.143, + 57.038, + 52.504, + 50.145, + 52.795, + 59.874, + 67.577, + 71.705, + 70.705, + 66.507, + 62.829, + 62.672, + 66.912, + 74.551, + 83.744, + 92.514, + 98.891, + 100.95, + 97.216, + 79.011, + 69.417, + 63.16, + 61.764, + 63.012, + 62.935, + 59.324, + 53.838, + 50.763, + 53.311, + 60.684, + 68.47, + 72.03, + 69.915, + 64.583, + 60.254, + 60.045, + 64.612, + 72.707, + 82.445, + 91.957, + 99.291, + 102.29, + 99.127, + 78.038, + 68.621, + 62.845, + 62.133, + 63.938, + 63.956, + 59.934, + 53.875, + 50.547, + 53.377, + 61.309, + 69.423, + 72.78, + 70.092, + 64.277, + 59.88, + 59.941, + 64.754, + 72.791, + 82.217, + 91.458, + 98.799, + 102.05, + 99.099, + 79.084, + 69.517, + 63.545, + 62.667, + 64.294, + 64.063, + 59.786, + 53.71, + 50.862, + 54.69, + 63.816, + 72.863, + 76.62, + 73.907, + 67.976, + 63.596, + 63.71, + 68.311, + 75.671, + 84.059, + 92.247, + 98.83, + 101.68, + 98.532, + 82.217, + 72.301, + 65.592, + 63.804, + 64.591, + 63.807, + 59.46, + 53.943, + 52.302, + 57.788, + 68.623, + 78.974, + 83.346, + 80.57, + 74.14, + 69.092, + 68.491, + 72.306, + 78.769, + 86.214, + 93.562, + 99.537, + 102.01, + 98.59, + 85.922, + 75.755, + 68.142, + 65.06, + 64.61, + 63.112, + 58.883, + 54.354, + 54.334, + 61.656, + 74.067, + 85.316, + 89.653, + 85.894, + 77.776, + 70.782, + 68.475, + 71.193, + 77.305, + 85.039, + 93.058, + 99.781, + 102.83, + 99.672, + 87.851, + 77.836, + 69.584, + 65.269, + 63.539, + 61.355, + 57.414, + 54.058, + 55.603, + 64.259, + 77.282, + 88.15, + 91.02, + 84.779, + 73.49, + 63.277, + 58.54, + 60.333, + 67.279, + 77.276, + 88.242, + 97.786, + 102.97, + 101.07, + 80.395, + ], + [ + 41.641, + 40.194, + 40.186, + 42.23, + 46.594, + 53.121, + 61.229, + 70.012, + 78.439, + 85.592, + 90.892, + 94.251, + 96.067, + 97.076, + 98.062, + 99.52, + 101.37, + 102.81, + 102.42, + 98.456, + 89.282, + 73.864, + 52.127, + 46.345, + 44.621, + 43.544, + 43.912, + 46.325, + 51.036, + 57.864, + 66.199, + 75.106, + 83.524, + 90.514, + 95.492, + 98.379, + 99.607, + 99.965, + 100.3, + 101.19, + 102.58, + 103.72, + 103.16, + 99.119, + 89.886, + 74.352, + 52.377, + 47.029, + 45.418, + 44.464, + 44.971, + 47.544, + 52.44, + 59.48, + 68.047, + 77.196, + 85.846, + 93.031, + 98.137, + 101.06, + 102.23, + 102.43, + 102.55, + 103.18, + 104.36, + 105.36, + 104.77, + 100.79, + 91.703, + 76.343, + 54.515, + 46.417, + 44.725, + 43.69, + 44.133, + 46.678, + 51.605, + 58.757, + 67.538, + 77.004, + 86.064, + 93.719, + 99.307, + 102.67, + 104.17, + 104.59, + 104.77, + 105.37, + 106.44, + 107.34, + 106.73, + 102.87, + 94.051, + 79.104, + 57.779, + 45.332, + 43.46, + 42.227, + 42.471, + 44.847, + 49.673, + 56.832, + 65.762, + 75.54, + 85.07, + 93.32, + 99.567, + 103.57, + 105.62, + 106.39, + 106.74, + 107.3, + 108.21, + 108.9, + 108.17, + 104.36, + 95.857, + 81.497, + 60.997, + 44.582, + 42.54, + 41.092, + 41.092, + 43.223, + 47.848, + 54.899, + 63.862, + 73.848, + 83.763, + 92.548, + 99.412, + 104.02, + 106.56, + 107.63, + 108.02, + 108.36, + 108.88, + 109.12, + 108.03, + 104.13, + 95.9, + 82.242, + 62.844, + 44.76, + 42.661, + 41.081, + 40.875, + 42.75, + 47.106, + 53.93, + 62.761, + 72.751, + 82.824, + 91.901, + 99.134, + 104.1, + 106.86, + 107.94, + 108.05, + 107.85, + 107.62, + 107.08, + 105.35, + 101.14, + 93.104, + 80.215, + 62.151, + 43.724, + ], + [ + 118.98, + 118.77, + 114.81, + 110.94, + 111.76, + 120.81, + 139.03, + 164.13, + 191.14, + 214.26, + 229.24, + 235.31, + 235.48, + 234.91, + 237.59, + 243.43, + 247.56, + 242.8, + 224.16, + 192.83, + 156.88, + 128.03, + 116.11, + 114.56, + 120.47, + 121.48, + 119.8, + 119.4, + 124.67, + 138.61, + 161.35, + 189.66, + 217.79, + 239.51, + 250.63, + 250.97, + 244.63, + 238.01, + 236.21, + 239.67, + 243.29, + 239.03, + 220.99, + 189.83, + 153.87, + 125.62, + 115.76, + 113.73, + 119.34, + 120.43, + 119.35, + 120.07, + 126.81, + 142.28, + 166.27, + 195.25, + 223.31, + 244.16, + 253.68, + 251.93, + 243.39, + 234.99, + 232.27, + 235.87, + 240.45, + 237.39, + 220.08, + 188.86, + 152.21, + 123.36, + 113.8, + 112.5, + 117.39, + 117.83, + 116.14, + 116.19, + 122.09, + 136.48, + 159.24, + 187.11, + 214.47, + 235.32, + 245.61, + 245.23, + 238.46, + 231.97, + 231.22, + 236.76, + 243.08, + 241.27, + 224.36, + 192.43, + 154.12, + 123.17, + 111.77, + 111.84, + 116.07, + 115.81, + 113.28, + 112.16, + 116.42, + 128.77, + 149.41, + 175.62, + 202.44, + 224.24, + 237.01, + 240.27, + 237.61, + 235.01, + 237.45, + 245.27, + 252.98, + 251.67, + 234.34, + 200.95, + 160.12, + 125.78, + 110.61, + 111.8, + 115.85, + 115.53, + 112.91, + 111.46, + 114.95, + 126.08, + 145.37, + 170.67, + 197.62, + 221.06, + 236.98, + 244.41, + 246.13, + 247.26, + 252.18, + 261.08, + 268.7, + 266.51, + 247.75, + 212.36, + 168.66, + 130.37, + 110.27, + 111.49, + 115.98, + 116.58, + 115.25, + 115.16, + 119.79, + 131.68, + 151.41, + 177.19, + 205.1, + 230.34, + 248.92, + 259.5, + 264.14, + 267.18, + 272.48, + 280.25, + 285.71, + 281.08, + 260.09, + 222.65, + 176.54, + 134.76, + 109.72, + 113.52, + ], + [ + 193.76, + 175.38, + 162.27, + 160.36, + 172.25, + 195.97, + 225.74, + 254.47, + 276.83, + 291.1, + 299.03, + 303.88, + 308.06, + 311.91, + 314.25, + 314.13, + 312.28, + 310.93, + 311.97, + 314.65, + 314.29, + 303.4, + 274.7, + 224.26, + 207.31, + 188.71, + 175.68, + 174.45, + 187.74, + 213.48, + 245.57, + 276.56, + 300.69, + 315.98, + 324.06, + 328.17, + 330.76, + 332.29, + 331.78, + 328.62, + 323.95, + 320.4, + 320.06, + 321.98, + 320.97, + 308.8, + 277.58, + 228.89, + 211.58, + 193.09, + 180.66, + 180.38, + 194.76, + 221.52, + 254.41, + 285.92, + 310.33, + 325.71, + 333.69, + 337.44, + 339.35, + 339.79, + 337.84, + 333.07, + 326.89, + 322.18, + 321.12, + 322.62, + 321.14, + 308.02, + 275.04, + 226.62, + 208.9, + 190.91, + 179.59, + 180.54, + 195.75, + 222.62, + 254.92, + 285.45, + 308.91, + 323.71, + 331.58, + 335.58, + 337.83, + 338.47, + 336.46, + 331.43, + 324.91, + 319.91, + 318.63, + 319.91, + 318.05, + 304.27, + 270.31, + 220.88, + 202.94, + 185.9, + 176.17, + 178.56, + 194.36, + 220.62, + 251.31, + 279.83, + 301.61, + 315.58, + 323.64, + 328.59, + 332.13, + 334.02, + 333.01, + 328.67, + 322.59, + 317.79, + 316.49, + 317.49, + 315.17, + 300.91, + 266.66, + 215.01, + 197.33, + 181.72, + 173.93, + 177.86, + 194.01, + 219.16, + 247.65, + 273.68, + 293.62, + 306.99, + 315.73, + 322.31, + 327.86, + 331.65, + 332.18, + 328.96, + 323.58, + 319.06, + 317.57, + 318.01, + 314.99, + 300.29, + 266.31, + 211.08, + 194.33, + 180.62, + 174.97, + 180.35, + 196.57, + 220.29, + 246.36, + 269.94, + 288.33, + 301.58, + 311.6, + 320.31, + 328.22, + 334.09, + 336.21, + 334.05, + 329.23, + 324.7, + 322.64, + 322.07, + 318.01, + 302.85, + 269.56, + 210.47, + ], + [ + 379.51, + 379.84, + 395.91, + 427.92, + 473.69, + 529.07, + 588.7, + 647.08, + 699.6, + 743.3, + 777.29, + 802.52, + 821.09, + 835.21, + 846.12, + 853.17, + 853.52, + 842.42, + 814.09, + 763.04, + 685.5, + 580.62, + 451.07, + 410.65, + 399.71, + 402.4, + 421.14, + 455.94, + 504.31, + 561.69, + 622.3, + 680.28, + 730.75, + 770.76, + 799.62, + 818.74, + 830.83, + 838.79, + 844.48, + 847.73, + 845.89, + 834.15, + 806.4, + 756.67, + 680.62, + 576.86, + 447.69, + 414.98, + 404.12, + 407.11, + 426.47, + 462.19, + 511.7, + 570.26, + 631.84, + 690.32, + 740.61, + 779.61, + 806.68, + 823.39, + 832.79, + 838.15, + 841.71, + 843.6, + 841.34, + 830.09, + 803.55, + 755.43, + 681.02, + 578.58, + 450.14, + 410.03, + 397.89, + 399.64, + 418.02, + 453.23, + 502.83, + 562.11, + 624.99, + 685.17, + 737.35, + 778.23, + 806.94, + 824.96, + 835.31, + 841.33, + 845.37, + 847.71, + 845.99, + 835.41, + 809.68, + 762.41, + 688.78, + 586.93, + 458.78, + 401.4, + 387.35, + 387.06, + 403.58, + 437.47, + 486.6, + 546.53, + 611.24, + 674.4, + 730.49, + 775.86, + 809.27, + 831.79, + 846.08, + 855.24, + 861.51, + 865.15, + 863.91, + 853.15, + 826.78, + 778.64, + 704.06, + 601.31, + 472.38, + 394.53, + 378.79, + 376.46, + 391, + 423.41, + 472.04, + 532.82, + 600, + 667.24, + 728.84, + 780.73, + 821.08, + 850.34, + 870.52, + 884.23, + 893.36, + 898.07, + 896.23, + 883.52, + 854.27, + 802.83, + 725, + 619.43, + 488.37, + 393.66, + 377.18, + 373.6, + 386.72, + 418.08, + 466.56, + 528.58, + 598.73, + 670.77, + 738.81, + 798.32, + 846.78, + 883.81, + 910.59, + 929.03, + 940.55, + 945.15, + 940.84, + 923.78, + 888.98, + 831.54, + 748.02, + 637.69, + 503.18, + 392.42, + ], + [ + 41.737, + 33.076, + 24.388, + 16.425, + 10.052, + 6.4268, + 6.8474, + 12.206, + 22.299, + 35.43, + 48.68, + 58.883, + 63.898, + 63.561, + 59.771, + 55.585, + 53.758, + 55.447, + 59.744, + 64.226, + 66.185, + 63.872, + 57.142, + 51.474, + 43.482, + 34.651, + 25.854, + 17.828, + 11.468, + 8.0229, + 8.9065, + 15.065, + 26.203, + 40.365, + 54.275, + 64.445, + 68.619, + 66.805, + 61.329, + 55.764, + 53.26, + 55.074, + 60.078, + 65.417, + 67.947, + 65.656, + 58.396, + 51.897, + 43.901, + 35.098, + 26.31, + 18.242, + 11.813, + 8.3517, + 9.3741, + 15.887, + 27.558, + 42.285, + 56.579, + 66.77, + 70.543, + 68.02, + 61.774, + 55.659, + 53.001, + 55.044, + 60.453, + 66.077, + 68.523, + 65.69, + 57.541, + 51.467, + 43.599, + 34.896, + 26.14, + 18.03, + 11.515, + 7.9791, + 8.9882, + 15.567, + 27.362, + 42.223, + 56.617, + 66.867, + 70.691, + 68.278, + 62.263, + 56.511, + 54.28, + 56.674, + 62.18, + 67.505, + 69.216, + 65.307, + 55.944, + 50.993, + 43.278, + 34.645, + 25.895, + 17.776, + 11.289, + 7.8312, + 8.9214, + 15.506, + 27.173, + 41.794, + 55.957, + 66.162, + 70.288, + 68.58, + 63.578, + 58.925, + 57.59, + 60.424, + 65.773, + 70.355, + 70.884, + 65.589, + 54.893, + 51.015, + 43.422, + 34.798, + 26.04, + 18.023, + 11.818, + 8.7764, + 10.249, + 16.969, + 28.395, + 42.451, + 55.976, + 65.827, + 70.181, + 69.366, + 65.74, + 62.552, + 62.332, + 65.609, + 70.634, + 74.263, + 73.492, + 66.863, + 55.035, + 51.552, + 44.063, + 35.433, + 26.729, + 19.01, + 13.424, + 11.212, + 13.437, + 20.492, + 31.624, + 44.814, + 57.244, + 66.265, + 70.467, + 70.32, + 67.965, + 66.186, + 67.005, + 70.615, + 75.233, + 77.94, + 76.092, + 68.545, + 56.119, + 49.486, + ], + [ + 75.287, + 73.758, + 72.659, + 72.035, + 71.962, + 72.496, + 73.613, + 75.186, + 76.995, + 78.785, + 80.34, + 81.53, + 82.325, + 82.764, + 82.909, + 82.804, + 82.474, + 81.942, + 81.26, + 80.525, + 79.851, + 79.319, + 78.904, + 81.595, + 81.221, + 81.035, + 80.886, + 80.722, + 80.593, + 80.601, + 80.828, + 81.284, + 81.896, + 82.541, + 83.098, + 83.49, + 83.696, + 83.727, + 83.595, + 83.288, + 82.784, + 82.081, + 81.234, + 80.368, + 79.638, + 79.167, + 78.954, + 81.439, + 81.784, + 82.216, + 82.507, + 82.564, + 82.436, + 82.264, + 82.193, + 82.307, + 82.599, + 82.99, + 83.378, + 83.682, + 83.857, + 83.881, + 83.732, + 83.374, + 82.773, + 81.934, + 80.937, + 79.941, + 79.139, + 78.685, + 78.595, + 78.829, + 79.15, + 79.549, + 79.794, + 79.799, + 79.638, + 79.482, + 79.503, + 79.802, + 80.367, + 81.098, + 81.857, + 82.52, + 83.002, + 83.256, + 83.251, + 82.957, + 82.36, + 81.493, + 80.463, + 79.451, + 78.668, + 78.274, + 78.283, + 76.309, + 76.127, + 76.043, + 75.873, + 75.58, + 75.277, + 75.155, + 75.391, + 76.059, + 77.103, + 78.362, + 79.636, + 80.745, + 81.574, + 82.07, + 82.22, + 82.031, + 81.527, + 80.773, + 79.89, + 79.055, + 78.46, + 78.234, + 78.365, + 76.287, + 75.462, + 74.722, + 73.96, + 73.203, + 72.608, + 72.384, + 72.692, + 73.569, + 74.899, + 76.456, + 77.981, + 79.261, + 80.174, + 80.69, + 80.844, + 80.699, + 80.328, + 79.815, + 79.271, + 78.83, + 78.618, + 78.695, + 78.994, + 80.444, + 79.175, + 77.895, + 76.574, + 75.301, + 74.273, + 73.708, + 73.754, + 74.41, + 75.518, + 76.813, + 78.012, + 78.905, + 79.401, + 79.532, + 79.402, + 79.144, + 78.872, + 78.674, + 78.616, + 78.746, + 79.081, + 79.576, + 80.082, + 77.239, + ], + ] + ) def load_exact_std(): - return nd.array([ - [ - 62.202, - 61.116, - 60.813, - 60.657, - 60.515, - 60.403, - 60.333, - 60.297, - 60.282, - 60.276, - 60.275, - 60.275, - 60.276, - 60.282, - 60.297, - 60.333, - 60.403, - 60.515, - 60.657, - 60.813, - 61.116, - 62.202, - 65.704, - 62.155, - 59.724, - 59.199, - 59.129, - 59.061, - 58.969, - 58.9, - 58.866, - 58.85, - 58.838, - 58.826, - 58.819, - 58.819, - 58.826, - 58.838, - 58.85, - 58.866, - 58.9, - 58.969, - 59.061, - 59.129, - 59.199, - 59.724, - 62.155, - 61.374, - 59.422, - 59.048, - 58.973, - 58.87, - 58.764, - 58.706, - 58.692, - 58.689, - 58.68, - 58.665, - 58.654, - 58.654, - 58.665, - 58.68, - 58.689, - 58.692, - 58.706, - 58.764, - 58.87, - 58.973, - 59.048, - 59.422, - 61.374, - 61.321, - 59.489, - 59.128, - 59.023, - 58.892, - 58.773, - 58.716, - 58.708, - 58.711, - 58.704, - 58.688, - 58.675, - 58.675, - 58.688, - 58.704, - 58.711, - 58.708, - 58.716, - 58.773, - 58.892, - 59.023, - 59.128, - 59.489, - 61.321, - 61.374, - 59.422, - 59.048, - 58.973, - 58.87, - 58.764, - 58.706, - 58.692, - 58.689, - 58.68, - 58.665, - 58.654, - 58.654, - 58.665, - 58.68, - 58.689, - 58.692, - 58.706, - 58.764, - 58.87, - 58.973, - 59.048, - 59.422, - 61.374, - 62.155, - 59.724, - 59.199, - 59.129, - 59.061, - 58.969, - 58.9, - 58.866, - 58.85, - 58.838, - 58.826, - 58.819, - 58.819, - 58.826, - 58.838, - 58.85, - 58.866, - 58.9, - 58.969, - 59.061, - 59.129, - 59.199, - 59.724, - 62.155, - 65.704, - 62.202, - 61.116, - 60.813, - 60.657, - 60.515, - 60.403, - 60.333, - 60.297, - 60.282, - 60.276, - 60.275, - 60.275, - 60.276, - 60.282, - 60.297, - 60.333, - 60.403, - 60.515, - 60.657, - 60.813, - 61.116, - 62.202, - 65.704, - 65.704, - ], - [ - 4.6693, - 4.6057, - 4.5913, - 4.5803, - 4.5721, - 4.5684, - 4.5667, - 4.5655, - 4.5644, - 4.5637, - 4.5633, - 4.5633, - 4.5637, - 4.5644, - 4.5655, - 4.5667, - 4.5684, - 4.5721, - 4.5803, - 4.5913, - 4.6057, - 4.6693, - 4.9388, - 4.6701, - 4.4794, - 4.4525, - 4.4488, - 4.4409, - 4.4349, - 4.4328, - 4.4317, - 4.4306, - 4.4299, - 4.4295, - 4.4293, - 4.4293, - 4.4295, - 4.4299, - 4.4306, - 4.4317, - 4.4328, - 4.4349, - 4.4409, - 4.4488, - 4.4525, - 4.4794, - 4.6701, - 4.618, - 4.4607, - 4.44, - 4.4336, - 4.4245, - 4.4199, - 4.4191, - 4.4183, - 4.4169, - 4.4162, - 4.416, - 4.416, - 4.416, - 4.416, - 4.4162, - 4.4169, - 4.4183, - 4.4191, - 4.4199, - 4.4245, - 4.4336, - 4.44, - 4.4607, - 4.618, - 4.616, - 4.4654, - 4.4433, - 4.434, - 4.4239, - 4.4199, - 4.4198, - 4.419, - 4.4175, - 4.4167, - 4.4166, - 4.4166, - 4.4166, - 4.4166, - 4.4167, - 4.4175, - 4.419, - 4.4198, - 4.4199, - 4.4239, - 4.434, - 4.4433, - 4.4654, - 4.616, - 4.618, - 4.4607, - 4.44, - 4.4336, - 4.4245, - 4.4199, - 4.4191, - 4.4183, - 4.4169, - 4.4162, - 4.416, - 4.416, - 4.416, - 4.416, - 4.4162, - 4.4169, - 4.4183, - 4.4191, - 4.4199, - 4.4245, - 4.4336, - 4.44, - 4.4607, - 4.618, - 4.6701, - 4.4794, - 4.4525, - 4.4488, - 4.4409, - 4.4349, - 4.4328, - 4.4317, - 4.4306, - 4.4299, - 4.4295, - 4.4293, - 4.4293, - 4.4295, - 4.4299, - 4.4306, - 4.4317, - 4.4328, - 4.4349, - 4.4409, - 4.4488, - 4.4525, - 4.4794, - 4.6701, - 4.9388, - 4.6693, - 4.6057, - 4.5913, - 4.5803, - 4.5721, - 4.5684, - 4.5667, - 4.5655, - 4.5644, - 4.5637, - 4.5633, - 4.5633, - 4.5637, - 4.5644, - 4.5655, - 4.5667, - 4.5684, - 4.5721, - 4.5803, - 4.5913, - 4.6057, - 4.6693, - 4.9388, - 4.9388, - ], - [ - 10.6, - 10.437, - 10.384, - 10.335, - 10.304, - 10.29, - 10.28, - 10.271, - 10.264, - 10.261, - 10.259, - 10.259, - 10.261, - 10.264, - 10.271, - 10.28, - 10.29, - 10.304, - 10.335, - 10.384, - 10.437, - 10.6, - 11.514, - 10.618, - 10.069, - 10.028, - 10.002, - 9.9658, - 9.9464, - 9.9396, - 9.9341, - 9.9288, - 9.9256, - 9.9239, - 9.9228, - 9.9228, - 9.9239, - 9.9256, - 9.9288, - 9.9341, - 9.9396, - 9.9464, - 9.9658, - 10.002, - 10.028, - 10.069, - 10.618, - 10.505, - 10.051, - 10.003, - 9.9655, - 9.9329, - 9.9221, - 9.9184, - 9.9114, - 9.9045, - 9.9017, - 9.9012, - 9.9009, - 9.9009, - 9.9012, - 9.9017, - 9.9045, - 9.9114, - 9.9184, - 9.9221, - 9.9329, - 9.9655, - 10.003, - 10.051, - 10.505, - 10.492, - 10.043, - 9.9813, - 9.9381, - 9.9089, - 9.9031, - 9.9013, - 9.8936, - 9.8858, - 9.8829, - 9.8828, - 9.8828, - 9.8828, - 9.8828, - 9.8829, - 9.8858, - 9.8936, - 9.9013, - 9.9031, - 9.9089, - 9.9381, - 9.9813, - 10.043, - 10.492, - 10.505, - 10.051, - 10.003, - 9.9655, - 9.9329, - 9.9221, - 9.9184, - 9.9114, - 9.9045, - 9.9017, - 9.9012, - 9.9009, - 9.9009, - 9.9012, - 9.9017, - 9.9045, - 9.9114, - 9.9184, - 9.9221, - 9.9329, - 9.9655, - 10.003, - 10.051, - 10.505, - 10.618, - 10.069, - 10.028, - 10.002, - 9.9658, - 9.9464, - 9.9396, - 9.9341, - 9.9288, - 9.9256, - 9.9239, - 9.9228, - 9.9228, - 9.9239, - 9.9256, - 9.9288, - 9.9341, - 9.9396, - 9.9464, - 9.9658, - 10.002, - 10.028, - 10.069, - 10.618, - 11.514, - 10.6, - 10.437, - 10.384, - 10.335, - 10.304, - 10.29, - 10.28, - 10.271, - 10.264, - 10.261, - 10.259, - 10.259, - 10.261, - 10.264, - 10.271, - 10.28, - 10.29, - 10.304, - 10.335, - 10.384, - 10.437, - 10.6, - 11.514, - 11.514, - ], - [ - 5.8917, - 5.847, - 5.803, - 5.7866, - 5.7789, - 5.7729, - 5.7707, - 5.7687, - 5.7674, - 5.7671, - 5.7668, - 5.7668, - 5.7671, - 5.7674, - 5.7687, - 5.7707, - 5.7729, - 5.7789, - 5.7866, - 5.803, - 5.847, - 5.8917, - 6.4598, - 5.9005, - 5.5807, - 5.5595, - 5.5222, - 5.5167, - 5.5118, - 5.507, - 5.5059, - 5.5046, - 5.5038, - 5.5035, - 5.5033, - 5.5033, - 5.5035, - 5.5038, - 5.5046, - 5.5059, - 5.507, - 5.5118, - 5.5167, - 5.5222, - 5.5595, - 5.5807, - 5.9005, - 5.8645, - 5.566, - 5.5328, - 5.502, - 5.4976, - 5.4912, - 5.4879, - 5.487, - 5.4852, - 5.4848, - 5.4846, - 5.4843, - 5.4843, - 5.4846, - 5.4848, - 5.4852, - 5.487, - 5.4879, - 5.4912, - 5.4976, - 5.502, - 5.5328, - 5.566, - 5.8645, - 5.8339, - 5.534, - 5.507, - 5.4832, - 5.4774, - 5.4694, - 5.467, - 5.4663, - 5.4643, - 5.4639, - 5.4638, - 5.4634, - 5.4634, - 5.4638, - 5.4639, - 5.4643, - 5.4663, - 5.467, - 5.4694, - 5.4774, - 5.4832, - 5.507, - 5.534, - 5.8339, - 5.8645, - 5.566, - 5.5328, - 5.502, - 5.4976, - 5.4912, - 5.4879, - 5.487, - 5.4852, - 5.4848, - 5.4846, - 5.4843, - 5.4843, - 5.4846, - 5.4848, - 5.4852, - 5.487, - 5.4879, - 5.4912, - 5.4976, - 5.502, - 5.5328, - 5.566, - 5.8645, - 5.9005, - 5.5807, - 5.5595, - 5.5222, - 5.5167, - 5.5118, - 5.507, - 5.5059, - 5.5046, - 5.5038, - 5.5035, - 5.5033, - 5.5033, - 5.5035, - 5.5038, - 5.5046, - 5.5059, - 5.507, - 5.5118, - 5.5167, - 5.5222, - 5.5595, - 5.5807, - 5.9005, - 6.4598, - 5.8917, - 5.847, - 5.803, - 5.7866, - 5.7789, - 5.7729, - 5.7707, - 5.7687, - 5.7674, - 5.7671, - 5.7668, - 5.7668, - 5.7671, - 5.7674, - 5.7687, - 5.7707, - 5.7729, - 5.7789, - 5.7866, - 5.803, - 5.847, - 5.8917, - 6.4598, - 6.4598, - ], - [ - 5.6763, - 5.5919, - 5.567, - 5.5445, - 5.5302, - 5.5238, - 5.5196, - 5.5158, - 5.513, - 5.5114, - 5.5108, - 5.5108, - 5.5114, - 5.513, - 5.5158, - 5.5196, - 5.5238, - 5.5302, - 5.5445, - 5.567, - 5.5919, - 5.6763, - 6.1243, - 5.6824, - 5.4021, - 5.3783, - 5.3679, - 5.3509, - 5.3416, - 5.3385, - 5.336, - 5.3336, - 5.3322, - 5.3315, - 5.3309, - 5.3309, - 5.3315, - 5.3322, - 5.3336, - 5.336, - 5.3385, - 5.3416, - 5.3509, - 5.3679, - 5.3783, - 5.4021, - 5.6824, - 5.6208, - 5.3897, - 5.3653, - 5.3488, - 5.3329, - 5.3275, - 5.3261, - 5.3231, - 5.3201, - 5.3189, - 5.3187, - 5.3184, - 5.3184, - 5.3187, - 5.3189, - 5.3201, - 5.3231, - 5.3261, - 5.3275, - 5.3329, - 5.3488, - 5.3653, - 5.3897, - 5.6208, - 5.617, - 5.3904, - 5.3595, - 5.339, - 5.324, - 5.3207, - 5.3202, - 5.3169, - 5.3134, - 5.3122, - 5.3121, - 5.312, - 5.312, - 5.3121, - 5.3122, - 5.3134, - 5.3169, - 5.3202, - 5.3207, - 5.324, - 5.339, - 5.3595, - 5.3904, - 5.617, - 5.6208, - 5.3897, - 5.3653, - 5.3488, - 5.3329, - 5.3275, - 5.3261, - 5.3231, - 5.3201, - 5.3189, - 5.3187, - 5.3184, - 5.3184, - 5.3187, - 5.3189, - 5.3201, - 5.3231, - 5.3261, - 5.3275, - 5.3329, - 5.3488, - 5.3653, - 5.3897, - 5.6208, - 5.6824, - 5.4021, - 5.3783, - 5.3679, - 5.3509, - 5.3416, - 5.3385, - 5.336, - 5.3336, - 5.3322, - 5.3315, - 5.3309, - 5.3309, - 5.3315, - 5.3322, - 5.3336, - 5.336, - 5.3385, - 5.3416, - 5.3509, - 5.3679, - 5.3783, - 5.4021, - 5.6824, - 6.1243, - 5.6763, - 5.5919, - 5.567, - 5.5445, - 5.5302, - 5.5238, - 5.5196, - 5.5158, - 5.513, - 5.5114, - 5.5108, - 5.5108, - 5.5114, - 5.513, - 5.5158, - 5.5196, - 5.5238, - 5.5302, - 5.5445, - 5.567, - 5.5919, - 5.6763, - 6.1243, - 6.1243, - ], + return nd.array( [ - 15.404, - 15.259, - 15.168, - 15.109, - 15.087, - 15.07, - 15.059, - 15.055, - 15.051, - 15.049, - 15.047, - 15.047, - 15.049, - 15.051, - 15.055, - 15.059, - 15.07, - 15.087, - 15.109, - 15.168, - 15.259, - 15.404, - 16.812, - 15.43, - 14.606, - 14.573, - 14.493, - 14.455, - 14.448, - 14.437, - 14.429, - 14.426, - 14.424, - 14.422, - 14.422, - 14.422, - 14.422, - 14.424, - 14.426, - 14.429, - 14.437, - 14.448, - 14.455, - 14.493, - 14.573, - 14.606, - 15.43, - 15.313, - 14.587, - 14.524, - 14.447, - 14.422, - 14.413, - 14.399, - 14.395, - 14.394, - 14.39, - 14.388, - 14.389, - 14.389, - 14.388, - 14.39, - 14.394, - 14.395, - 14.399, - 14.413, - 14.422, - 14.447, - 14.524, - 14.587, - 15.313, - 15.264, - 14.528, - 14.456, - 14.393, - 14.377, - 14.366, - 14.349, - 14.346, - 14.345, - 14.341, - 14.339, - 14.34, - 14.34, - 14.339, - 14.341, - 14.345, - 14.346, - 14.349, - 14.366, - 14.377, - 14.393, - 14.456, - 14.528, - 15.264, - 15.313, - 14.587, - 14.524, - 14.447, - 14.422, - 14.413, - 14.399, - 14.395, - 14.394, - 14.39, - 14.388, - 14.389, - 14.389, - 14.388, - 14.39, - 14.394, - 14.395, - 14.399, - 14.413, - 14.422, - 14.447, - 14.524, - 14.587, - 15.313, - 15.43, - 14.606, - 14.573, - 14.493, - 14.455, - 14.448, - 14.437, - 14.429, - 14.426, - 14.424, - 14.422, - 14.422, - 14.422, - 14.422, - 14.424, - 14.426, - 14.429, - 14.437, - 14.448, - 14.455, - 14.493, - 14.573, - 14.606, - 15.43, - 16.812, - 15.404, - 15.259, - 15.168, - 15.109, - 15.087, - 15.07, - 15.059, - 15.055, - 15.051, - 15.049, - 15.047, - 15.047, - 15.049, - 15.051, - 15.055, - 15.059, - 15.07, - 15.087, - 15.109, - 15.168, - 15.259, - 15.404, - 16.812, - 16.812, - ], - [ - 11.994, - 11.873, - 11.788, - 11.733, - 11.712, - 11.695, - 11.683, - 11.678, - 11.675, - 11.672, - 11.67, - 11.67, - 11.672, - 11.675, - 11.678, - 11.683, - 11.695, - 11.712, - 11.733, - 11.788, - 11.873, - 11.994, - 13.178, - 12.023, - 11.358, - 11.328, - 11.255, - 11.221, - 11.214, - 11.204, - 11.196, - 11.194, - 11.191, - 11.189, - 11.189, - 11.189, - 11.189, - 11.191, - 11.194, - 11.196, - 11.204, - 11.214, - 11.221, - 11.255, - 11.328, - 11.358, - 12.023, - 11.93, - 11.341, - 11.283, - 11.215, - 11.193, - 11.185, - 11.171, - 11.166, - 11.165, - 11.162, - 11.16, - 11.16, - 11.16, - 11.16, - 11.162, - 11.165, - 11.166, - 11.171, - 11.185, - 11.193, - 11.215, - 11.283, - 11.341, - 11.93, - 11.881, - 11.284, - 11.222, - 11.168, - 11.154, - 11.144, - 11.128, - 11.124, - 11.124, - 11.12, - 11.118, - 11.118, - 11.118, - 11.118, - 11.12, - 11.124, - 11.124, - 11.128, - 11.144, - 11.154, - 11.168, - 11.222, - 11.284, - 11.881, - 11.93, - 11.341, - 11.283, - 11.215, - 11.193, - 11.185, - 11.171, - 11.166, - 11.165, - 11.162, - 11.16, - 11.16, - 11.16, - 11.16, - 11.162, - 11.165, - 11.166, - 11.171, - 11.185, - 11.193, - 11.215, - 11.283, - 11.341, - 11.93, - 12.023, - 11.358, - 11.328, - 11.255, - 11.221, - 11.214, - 11.204, - 11.196, - 11.194, - 11.191, - 11.189, - 11.189, - 11.189, - 11.189, - 11.191, - 11.194, - 11.196, - 11.204, - 11.214, - 11.221, - 11.255, - 11.328, - 11.358, - 12.023, - 13.178, - 11.994, - 11.873, - 11.788, - 11.733, - 11.712, - 11.695, - 11.683, - 11.678, - 11.675, - 11.672, - 11.67, - 11.67, - 11.672, - 11.675, - 11.678, - 11.683, - 11.695, - 11.712, - 11.733, - 11.788, - 11.873, - 11.994, - 13.178, - 13.178, - ], - [ - 45.261, - 44.52, - 44.323, - 44.162, - 44.037, - 43.973, - 43.944, - 43.922, - 43.9, - 43.881, - 43.871, - 43.871, - 43.881, - 43.9, - 43.922, - 43.944, - 43.973, - 44.037, - 44.162, - 44.323, - 44.52, - 45.261, - 48.529, - 45.281, - 43.171, - 42.914, - 42.863, - 42.755, - 42.67, - 42.636, - 42.62, - 42.603, - 42.588, - 42.581, - 42.58, - 42.58, - 42.581, - 42.588, - 42.603, - 42.62, - 42.636, - 42.67, - 42.755, - 42.863, - 42.914, - 43.171, - 45.281, - 44.753, - 43.038, - 42.82, - 42.723, - 42.601, - 42.538, - 42.526, - 42.517, - 42.495, - 42.476, - 42.471, - 42.475, - 42.475, - 42.471, - 42.476, - 42.495, - 42.517, - 42.526, - 42.538, - 42.601, - 42.723, - 42.82, - 43.038, - 44.753, - 44.736, - 43.082, - 42.828, - 42.692, - 42.561, - 42.506, - 42.505, - 42.499, - 42.475, - 42.454, - 42.449, - 42.453, - 42.453, - 42.449, - 42.454, - 42.475, - 42.499, - 42.505, - 42.506, - 42.561, - 42.692, - 42.828, - 43.082, - 44.736, - 44.753, - 43.038, - 42.82, - 42.723, - 42.601, - 42.538, - 42.526, - 42.517, - 42.495, - 42.476, - 42.471, - 42.475, - 42.475, - 42.471, - 42.476, - 42.495, - 42.517, - 42.526, - 42.538, - 42.601, - 42.723, - 42.82, - 43.038, - 44.753, - 45.281, - 43.171, - 42.914, - 42.863, - 42.755, - 42.67, - 42.636, - 42.62, - 42.603, - 42.588, - 42.581, - 42.58, - 42.58, - 42.581, - 42.588, - 42.603, - 42.62, - 42.636, - 42.67, - 42.755, - 42.863, - 42.914, - 43.171, - 45.281, - 48.529, - 45.261, - 44.52, - 44.323, - 44.162, - 44.037, - 43.973, - 43.944, - 43.922, - 43.9, - 43.881, - 43.871, - 43.871, - 43.881, - 43.9, - 43.922, - 43.944, - 43.973, - 44.037, - 44.162, - 44.323, - 44.52, - 45.261, - 48.529, - 48.529, - ], - [ - 4.5187, - 4.4744, - 4.4501, - 4.433, - 4.4268, - 4.4222, - 4.4187, - 4.4174, - 4.4167, - 4.4159, - 4.4154, - 4.4154, - 4.4159, - 4.4167, - 4.4174, - 4.4187, - 4.4222, - 4.4268, - 4.433, - 4.4501, - 4.4744, - 4.5187, - 4.9177, - 4.5256, - 4.2875, - 4.2783, - 4.2574, - 4.2455, - 4.2434, - 4.2405, - 4.2382, - 4.2374, - 4.2368, - 4.2362, - 4.2361, - 4.2361, - 4.2362, - 4.2368, - 4.2374, - 4.2382, - 4.2405, - 4.2434, - 4.2455, - 4.2574, - 4.2783, - 4.2875, - 4.5256, - 4.4897, - 4.2819, - 4.2649, - 4.2439, - 4.2359, - 4.2338, - 4.2301, - 4.2283, - 4.2281, - 4.2273, - 4.2267, - 4.2267, - 4.2267, - 4.2267, - 4.2273, - 4.2281, - 4.2283, - 4.2301, - 4.2338, - 4.2359, - 4.2439, - 4.2649, - 4.2819, - 4.4897, - 4.4781, - 4.2679, - 4.2471, - 4.2289, - 4.2238, - 4.2214, - 4.2168, - 4.2152, - 4.2153, - 4.2144, - 4.2137, - 4.2138, - 4.2138, - 4.2137, - 4.2144, - 4.2153, - 4.2152, - 4.2168, - 4.2214, - 4.2238, - 4.2289, - 4.2471, - 4.2679, - 4.4781, - 4.4897, - 4.2819, - 4.2649, - 4.2439, - 4.2359, - 4.2338, - 4.2301, - 4.2283, - 4.2281, - 4.2273, - 4.2267, - 4.2267, - 4.2267, - 4.2267, - 4.2273, - 4.2281, - 4.2283, - 4.2301, - 4.2338, - 4.2359, - 4.2439, - 4.2649, - 4.2819, - 4.4897, - 4.5256, - 4.2875, - 4.2783, - 4.2574, - 4.2455, - 4.2434, - 4.2405, - 4.2382, - 4.2374, - 4.2368, - 4.2362, - 4.2361, - 4.2361, - 4.2362, - 4.2368, - 4.2374, - 4.2382, - 4.2405, - 4.2434, - 4.2455, - 4.2574, - 4.2783, - 4.2875, - 4.5256, - 4.9177, - 4.5187, - 4.4744, - 4.4501, - 4.433, - 4.4268, - 4.4222, - 4.4187, - 4.4174, - 4.4167, - 4.4159, - 4.4154, - 4.4154, - 4.4159, - 4.4167, - 4.4174, - 4.4187, - 4.4222, - 4.4268, - 4.433, - 4.4501, - 4.4744, - 4.5187, - 4.9177, - 4.9177, - ], - [ - 1.3486, - 1.3332, - 1.3187, - 1.3103, - 1.3068, - 1.3037, - 1.3015, - 1.3006, - 1.2999, - 1.2991, - 1.2986, - 1.2986, - 1.2991, - 1.2999, - 1.3006, - 1.3015, - 1.3037, - 1.3068, - 1.3103, - 1.3187, - 1.3332, - 1.3486, - 1.506, - 1.3566, - 1.2758, - 1.2693, - 1.2573, - 1.2531, - 1.2522, - 1.2503, - 1.2489, - 1.2484, - 1.2479, - 1.2475, - 1.2474, - 1.2474, - 1.2475, - 1.2479, - 1.2484, - 1.2489, - 1.2503, - 1.2522, - 1.2531, - 1.2573, - 1.2693, - 1.2758, - 1.3566, - 1.345, - 1.2709, - 1.2607, - 1.2507, - 1.248, - 1.2463, - 1.2442, - 1.2434, - 1.2432, - 1.2425, - 1.242, - 1.242, - 1.242, - 1.242, - 1.2425, - 1.2432, - 1.2434, - 1.2442, - 1.2463, - 1.248, - 1.2507, - 1.2607, - 1.2709, - 1.345, - 1.335, - 1.2613, - 1.2528, - 1.2457, - 1.2438, - 1.2416, - 1.239, - 1.2384, - 1.2384, - 1.2377, - 1.2372, - 1.2373, - 1.2373, - 1.2372, - 1.2377, - 1.2384, - 1.2384, - 1.239, - 1.2416, - 1.2438, - 1.2457, - 1.2528, - 1.2613, - 1.335, - 1.345, - 1.2709, - 1.2607, - 1.2507, - 1.248, - 1.2463, - 1.2442, - 1.2434, - 1.2432, - 1.2425, - 1.242, - 1.242, - 1.242, - 1.242, - 1.2425, - 1.2432, - 1.2434, - 1.2442, - 1.2463, - 1.248, - 1.2507, - 1.2607, - 1.2709, - 1.345, - 1.3566, - 1.2758, - 1.2693, - 1.2573, - 1.2531, - 1.2522, - 1.2503, - 1.2489, - 1.2484, - 1.2479, - 1.2475, - 1.2474, - 1.2474, - 1.2475, - 1.2479, - 1.2484, - 1.2489, - 1.2503, - 1.2522, - 1.2531, - 1.2573, - 1.2693, - 1.2758, - 1.3566, - 1.506, - 1.3486, - 1.3332, - 1.3187, - 1.3103, - 1.3068, - 1.3037, - 1.3015, - 1.3006, - 1.2999, - 1.2991, - 1.2986, - 1.2986, - 1.2991, - 1.2999, - 1.3006, - 1.3015, - 1.3037, - 1.3068, - 1.3103, - 1.3187, - 1.3332, - 1.3486, - 1.506, - 1.506, - ], - ]) + [ + 62.202, + 61.116, + 60.813, + 60.657, + 60.515, + 60.403, + 60.333, + 60.297, + 60.282, + 60.276, + 60.275, + 60.275, + 60.276, + 60.282, + 60.297, + 60.333, + 60.403, + 60.515, + 60.657, + 60.813, + 61.116, + 62.202, + 65.704, + 62.155, + 59.724, + 59.199, + 59.129, + 59.061, + 58.969, + 58.9, + 58.866, + 58.85, + 58.838, + 58.826, + 58.819, + 58.819, + 58.826, + 58.838, + 58.85, + 58.866, + 58.9, + 58.969, + 59.061, + 59.129, + 59.199, + 59.724, + 62.155, + 61.374, + 59.422, + 59.048, + 58.973, + 58.87, + 58.764, + 58.706, + 58.692, + 58.689, + 58.68, + 58.665, + 58.654, + 58.654, + 58.665, + 58.68, + 58.689, + 58.692, + 58.706, + 58.764, + 58.87, + 58.973, + 59.048, + 59.422, + 61.374, + 61.321, + 59.489, + 59.128, + 59.023, + 58.892, + 58.773, + 58.716, + 58.708, + 58.711, + 58.704, + 58.688, + 58.675, + 58.675, + 58.688, + 58.704, + 58.711, + 58.708, + 58.716, + 58.773, + 58.892, + 59.023, + 59.128, + 59.489, + 61.321, + 61.374, + 59.422, + 59.048, + 58.973, + 58.87, + 58.764, + 58.706, + 58.692, + 58.689, + 58.68, + 58.665, + 58.654, + 58.654, + 58.665, + 58.68, + 58.689, + 58.692, + 58.706, + 58.764, + 58.87, + 58.973, + 59.048, + 59.422, + 61.374, + 62.155, + 59.724, + 59.199, + 59.129, + 59.061, + 58.969, + 58.9, + 58.866, + 58.85, + 58.838, + 58.826, + 58.819, + 58.819, + 58.826, + 58.838, + 58.85, + 58.866, + 58.9, + 58.969, + 59.061, + 59.129, + 59.199, + 59.724, + 62.155, + 65.704, + 62.202, + 61.116, + 60.813, + 60.657, + 60.515, + 60.403, + 60.333, + 60.297, + 60.282, + 60.276, + 60.275, + 60.275, + 60.276, + 60.282, + 60.297, + 60.333, + 60.403, + 60.515, + 60.657, + 60.813, + 61.116, + 62.202, + 65.704, + 65.704, + ], + [ + 4.6693, + 4.6057, + 4.5913, + 4.5803, + 4.5721, + 4.5684, + 4.5667, + 4.5655, + 4.5644, + 4.5637, + 4.5633, + 4.5633, + 4.5637, + 4.5644, + 4.5655, + 4.5667, + 4.5684, + 4.5721, + 4.5803, + 4.5913, + 4.6057, + 4.6693, + 4.9388, + 4.6701, + 4.4794, + 4.4525, + 4.4488, + 4.4409, + 4.4349, + 4.4328, + 4.4317, + 4.4306, + 4.4299, + 4.4295, + 4.4293, + 4.4293, + 4.4295, + 4.4299, + 4.4306, + 4.4317, + 4.4328, + 4.4349, + 4.4409, + 4.4488, + 4.4525, + 4.4794, + 4.6701, + 4.618, + 4.4607, + 4.44, + 4.4336, + 4.4245, + 4.4199, + 4.4191, + 4.4183, + 4.4169, + 4.4162, + 4.416, + 4.416, + 4.416, + 4.416, + 4.4162, + 4.4169, + 4.4183, + 4.4191, + 4.4199, + 4.4245, + 4.4336, + 4.44, + 4.4607, + 4.618, + 4.616, + 4.4654, + 4.4433, + 4.434, + 4.4239, + 4.4199, + 4.4198, + 4.419, + 4.4175, + 4.4167, + 4.4166, + 4.4166, + 4.4166, + 4.4166, + 4.4167, + 4.4175, + 4.419, + 4.4198, + 4.4199, + 4.4239, + 4.434, + 4.4433, + 4.4654, + 4.616, + 4.618, + 4.4607, + 4.44, + 4.4336, + 4.4245, + 4.4199, + 4.4191, + 4.4183, + 4.4169, + 4.4162, + 4.416, + 4.416, + 4.416, + 4.416, + 4.4162, + 4.4169, + 4.4183, + 4.4191, + 4.4199, + 4.4245, + 4.4336, + 4.44, + 4.4607, + 4.618, + 4.6701, + 4.4794, + 4.4525, + 4.4488, + 4.4409, + 4.4349, + 4.4328, + 4.4317, + 4.4306, + 4.4299, + 4.4295, + 4.4293, + 4.4293, + 4.4295, + 4.4299, + 4.4306, + 4.4317, + 4.4328, + 4.4349, + 4.4409, + 4.4488, + 4.4525, + 4.4794, + 4.6701, + 4.9388, + 4.6693, + 4.6057, + 4.5913, + 4.5803, + 4.5721, + 4.5684, + 4.5667, + 4.5655, + 4.5644, + 4.5637, + 4.5633, + 4.5633, + 4.5637, + 4.5644, + 4.5655, + 4.5667, + 4.5684, + 4.5721, + 4.5803, + 4.5913, + 4.6057, + 4.6693, + 4.9388, + 4.9388, + ], + [ + 10.6, + 10.437, + 10.384, + 10.335, + 10.304, + 10.29, + 10.28, + 10.271, + 10.264, + 10.261, + 10.259, + 10.259, + 10.261, + 10.264, + 10.271, + 10.28, + 10.29, + 10.304, + 10.335, + 10.384, + 10.437, + 10.6, + 11.514, + 10.618, + 10.069, + 10.028, + 10.002, + 9.9658, + 9.9464, + 9.9396, + 9.9341, + 9.9288, + 9.9256, + 9.9239, + 9.9228, + 9.9228, + 9.9239, + 9.9256, + 9.9288, + 9.9341, + 9.9396, + 9.9464, + 9.9658, + 10.002, + 10.028, + 10.069, + 10.618, + 10.505, + 10.051, + 10.003, + 9.9655, + 9.9329, + 9.9221, + 9.9184, + 9.9114, + 9.9045, + 9.9017, + 9.9012, + 9.9009, + 9.9009, + 9.9012, + 9.9017, + 9.9045, + 9.9114, + 9.9184, + 9.9221, + 9.9329, + 9.9655, + 10.003, + 10.051, + 10.505, + 10.492, + 10.043, + 9.9813, + 9.9381, + 9.9089, + 9.9031, + 9.9013, + 9.8936, + 9.8858, + 9.8829, + 9.8828, + 9.8828, + 9.8828, + 9.8828, + 9.8829, + 9.8858, + 9.8936, + 9.9013, + 9.9031, + 9.9089, + 9.9381, + 9.9813, + 10.043, + 10.492, + 10.505, + 10.051, + 10.003, + 9.9655, + 9.9329, + 9.9221, + 9.9184, + 9.9114, + 9.9045, + 9.9017, + 9.9012, + 9.9009, + 9.9009, + 9.9012, + 9.9017, + 9.9045, + 9.9114, + 9.9184, + 9.9221, + 9.9329, + 9.9655, + 10.003, + 10.051, + 10.505, + 10.618, + 10.069, + 10.028, + 10.002, + 9.9658, + 9.9464, + 9.9396, + 9.9341, + 9.9288, + 9.9256, + 9.9239, + 9.9228, + 9.9228, + 9.9239, + 9.9256, + 9.9288, + 9.9341, + 9.9396, + 9.9464, + 9.9658, + 10.002, + 10.028, + 10.069, + 10.618, + 11.514, + 10.6, + 10.437, + 10.384, + 10.335, + 10.304, + 10.29, + 10.28, + 10.271, + 10.264, + 10.261, + 10.259, + 10.259, + 10.261, + 10.264, + 10.271, + 10.28, + 10.29, + 10.304, + 10.335, + 10.384, + 10.437, + 10.6, + 11.514, + 11.514, + ], + [ + 5.8917, + 5.847, + 5.803, + 5.7866, + 5.7789, + 5.7729, + 5.7707, + 5.7687, + 5.7674, + 5.7671, + 5.7668, + 5.7668, + 5.7671, + 5.7674, + 5.7687, + 5.7707, + 5.7729, + 5.7789, + 5.7866, + 5.803, + 5.847, + 5.8917, + 6.4598, + 5.9005, + 5.5807, + 5.5595, + 5.5222, + 5.5167, + 5.5118, + 5.507, + 5.5059, + 5.5046, + 5.5038, + 5.5035, + 5.5033, + 5.5033, + 5.5035, + 5.5038, + 5.5046, + 5.5059, + 5.507, + 5.5118, + 5.5167, + 5.5222, + 5.5595, + 5.5807, + 5.9005, + 5.8645, + 5.566, + 5.5328, + 5.502, + 5.4976, + 5.4912, + 5.4879, + 5.487, + 5.4852, + 5.4848, + 5.4846, + 5.4843, + 5.4843, + 5.4846, + 5.4848, + 5.4852, + 5.487, + 5.4879, + 5.4912, + 5.4976, + 5.502, + 5.5328, + 5.566, + 5.8645, + 5.8339, + 5.534, + 5.507, + 5.4832, + 5.4774, + 5.4694, + 5.467, + 5.4663, + 5.4643, + 5.4639, + 5.4638, + 5.4634, + 5.4634, + 5.4638, + 5.4639, + 5.4643, + 5.4663, + 5.467, + 5.4694, + 5.4774, + 5.4832, + 5.507, + 5.534, + 5.8339, + 5.8645, + 5.566, + 5.5328, + 5.502, + 5.4976, + 5.4912, + 5.4879, + 5.487, + 5.4852, + 5.4848, + 5.4846, + 5.4843, + 5.4843, + 5.4846, + 5.4848, + 5.4852, + 5.487, + 5.4879, + 5.4912, + 5.4976, + 5.502, + 5.5328, + 5.566, + 5.8645, + 5.9005, + 5.5807, + 5.5595, + 5.5222, + 5.5167, + 5.5118, + 5.507, + 5.5059, + 5.5046, + 5.5038, + 5.5035, + 5.5033, + 5.5033, + 5.5035, + 5.5038, + 5.5046, + 5.5059, + 5.507, + 5.5118, + 5.5167, + 5.5222, + 5.5595, + 5.5807, + 5.9005, + 6.4598, + 5.8917, + 5.847, + 5.803, + 5.7866, + 5.7789, + 5.7729, + 5.7707, + 5.7687, + 5.7674, + 5.7671, + 5.7668, + 5.7668, + 5.7671, + 5.7674, + 5.7687, + 5.7707, + 5.7729, + 5.7789, + 5.7866, + 5.803, + 5.847, + 5.8917, + 6.4598, + 6.4598, + ], + [ + 5.6763, + 5.5919, + 5.567, + 5.5445, + 5.5302, + 5.5238, + 5.5196, + 5.5158, + 5.513, + 5.5114, + 5.5108, + 5.5108, + 5.5114, + 5.513, + 5.5158, + 5.5196, + 5.5238, + 5.5302, + 5.5445, + 5.567, + 5.5919, + 5.6763, + 6.1243, + 5.6824, + 5.4021, + 5.3783, + 5.3679, + 5.3509, + 5.3416, + 5.3385, + 5.336, + 5.3336, + 5.3322, + 5.3315, + 5.3309, + 5.3309, + 5.3315, + 5.3322, + 5.3336, + 5.336, + 5.3385, + 5.3416, + 5.3509, + 5.3679, + 5.3783, + 5.4021, + 5.6824, + 5.6208, + 5.3897, + 5.3653, + 5.3488, + 5.3329, + 5.3275, + 5.3261, + 5.3231, + 5.3201, + 5.3189, + 5.3187, + 5.3184, + 5.3184, + 5.3187, + 5.3189, + 5.3201, + 5.3231, + 5.3261, + 5.3275, + 5.3329, + 5.3488, + 5.3653, + 5.3897, + 5.6208, + 5.617, + 5.3904, + 5.3595, + 5.339, + 5.324, + 5.3207, + 5.3202, + 5.3169, + 5.3134, + 5.3122, + 5.3121, + 5.312, + 5.312, + 5.3121, + 5.3122, + 5.3134, + 5.3169, + 5.3202, + 5.3207, + 5.324, + 5.339, + 5.3595, + 5.3904, + 5.617, + 5.6208, + 5.3897, + 5.3653, + 5.3488, + 5.3329, + 5.3275, + 5.3261, + 5.3231, + 5.3201, + 5.3189, + 5.3187, + 5.3184, + 5.3184, + 5.3187, + 5.3189, + 5.3201, + 5.3231, + 5.3261, + 5.3275, + 5.3329, + 5.3488, + 5.3653, + 5.3897, + 5.6208, + 5.6824, + 5.4021, + 5.3783, + 5.3679, + 5.3509, + 5.3416, + 5.3385, + 5.336, + 5.3336, + 5.3322, + 5.3315, + 5.3309, + 5.3309, + 5.3315, + 5.3322, + 5.3336, + 5.336, + 5.3385, + 5.3416, + 5.3509, + 5.3679, + 5.3783, + 5.4021, + 5.6824, + 6.1243, + 5.6763, + 5.5919, + 5.567, + 5.5445, + 5.5302, + 5.5238, + 5.5196, + 5.5158, + 5.513, + 5.5114, + 5.5108, + 5.5108, + 5.5114, + 5.513, + 5.5158, + 5.5196, + 5.5238, + 5.5302, + 5.5445, + 5.567, + 5.5919, + 5.6763, + 6.1243, + 6.1243, + ], + [ + 15.404, + 15.259, + 15.168, + 15.109, + 15.087, + 15.07, + 15.059, + 15.055, + 15.051, + 15.049, + 15.047, + 15.047, + 15.049, + 15.051, + 15.055, + 15.059, + 15.07, + 15.087, + 15.109, + 15.168, + 15.259, + 15.404, + 16.812, + 15.43, + 14.606, + 14.573, + 14.493, + 14.455, + 14.448, + 14.437, + 14.429, + 14.426, + 14.424, + 14.422, + 14.422, + 14.422, + 14.422, + 14.424, + 14.426, + 14.429, + 14.437, + 14.448, + 14.455, + 14.493, + 14.573, + 14.606, + 15.43, + 15.313, + 14.587, + 14.524, + 14.447, + 14.422, + 14.413, + 14.399, + 14.395, + 14.394, + 14.39, + 14.388, + 14.389, + 14.389, + 14.388, + 14.39, + 14.394, + 14.395, + 14.399, + 14.413, + 14.422, + 14.447, + 14.524, + 14.587, + 15.313, + 15.264, + 14.528, + 14.456, + 14.393, + 14.377, + 14.366, + 14.349, + 14.346, + 14.345, + 14.341, + 14.339, + 14.34, + 14.34, + 14.339, + 14.341, + 14.345, + 14.346, + 14.349, + 14.366, + 14.377, + 14.393, + 14.456, + 14.528, + 15.264, + 15.313, + 14.587, + 14.524, + 14.447, + 14.422, + 14.413, + 14.399, + 14.395, + 14.394, + 14.39, + 14.388, + 14.389, + 14.389, + 14.388, + 14.39, + 14.394, + 14.395, + 14.399, + 14.413, + 14.422, + 14.447, + 14.524, + 14.587, + 15.313, + 15.43, + 14.606, + 14.573, + 14.493, + 14.455, + 14.448, + 14.437, + 14.429, + 14.426, + 14.424, + 14.422, + 14.422, + 14.422, + 14.422, + 14.424, + 14.426, + 14.429, + 14.437, + 14.448, + 14.455, + 14.493, + 14.573, + 14.606, + 15.43, + 16.812, + 15.404, + 15.259, + 15.168, + 15.109, + 15.087, + 15.07, + 15.059, + 15.055, + 15.051, + 15.049, + 15.047, + 15.047, + 15.049, + 15.051, + 15.055, + 15.059, + 15.07, + 15.087, + 15.109, + 15.168, + 15.259, + 15.404, + 16.812, + 16.812, + ], + [ + 11.994, + 11.873, + 11.788, + 11.733, + 11.712, + 11.695, + 11.683, + 11.678, + 11.675, + 11.672, + 11.67, + 11.67, + 11.672, + 11.675, + 11.678, + 11.683, + 11.695, + 11.712, + 11.733, + 11.788, + 11.873, + 11.994, + 13.178, + 12.023, + 11.358, + 11.328, + 11.255, + 11.221, + 11.214, + 11.204, + 11.196, + 11.194, + 11.191, + 11.189, + 11.189, + 11.189, + 11.189, + 11.191, + 11.194, + 11.196, + 11.204, + 11.214, + 11.221, + 11.255, + 11.328, + 11.358, + 12.023, + 11.93, + 11.341, + 11.283, + 11.215, + 11.193, + 11.185, + 11.171, + 11.166, + 11.165, + 11.162, + 11.16, + 11.16, + 11.16, + 11.16, + 11.162, + 11.165, + 11.166, + 11.171, + 11.185, + 11.193, + 11.215, + 11.283, + 11.341, + 11.93, + 11.881, + 11.284, + 11.222, + 11.168, + 11.154, + 11.144, + 11.128, + 11.124, + 11.124, + 11.12, + 11.118, + 11.118, + 11.118, + 11.118, + 11.12, + 11.124, + 11.124, + 11.128, + 11.144, + 11.154, + 11.168, + 11.222, + 11.284, + 11.881, + 11.93, + 11.341, + 11.283, + 11.215, + 11.193, + 11.185, + 11.171, + 11.166, + 11.165, + 11.162, + 11.16, + 11.16, + 11.16, + 11.16, + 11.162, + 11.165, + 11.166, + 11.171, + 11.185, + 11.193, + 11.215, + 11.283, + 11.341, + 11.93, + 12.023, + 11.358, + 11.328, + 11.255, + 11.221, + 11.214, + 11.204, + 11.196, + 11.194, + 11.191, + 11.189, + 11.189, + 11.189, + 11.189, + 11.191, + 11.194, + 11.196, + 11.204, + 11.214, + 11.221, + 11.255, + 11.328, + 11.358, + 12.023, + 13.178, + 11.994, + 11.873, + 11.788, + 11.733, + 11.712, + 11.695, + 11.683, + 11.678, + 11.675, + 11.672, + 11.67, + 11.67, + 11.672, + 11.675, + 11.678, + 11.683, + 11.695, + 11.712, + 11.733, + 11.788, + 11.873, + 11.994, + 13.178, + 13.178, + ], + [ + 45.261, + 44.52, + 44.323, + 44.162, + 44.037, + 43.973, + 43.944, + 43.922, + 43.9, + 43.881, + 43.871, + 43.871, + 43.881, + 43.9, + 43.922, + 43.944, + 43.973, + 44.037, + 44.162, + 44.323, + 44.52, + 45.261, + 48.529, + 45.281, + 43.171, + 42.914, + 42.863, + 42.755, + 42.67, + 42.636, + 42.62, + 42.603, + 42.588, + 42.581, + 42.58, + 42.58, + 42.581, + 42.588, + 42.603, + 42.62, + 42.636, + 42.67, + 42.755, + 42.863, + 42.914, + 43.171, + 45.281, + 44.753, + 43.038, + 42.82, + 42.723, + 42.601, + 42.538, + 42.526, + 42.517, + 42.495, + 42.476, + 42.471, + 42.475, + 42.475, + 42.471, + 42.476, + 42.495, + 42.517, + 42.526, + 42.538, + 42.601, + 42.723, + 42.82, + 43.038, + 44.753, + 44.736, + 43.082, + 42.828, + 42.692, + 42.561, + 42.506, + 42.505, + 42.499, + 42.475, + 42.454, + 42.449, + 42.453, + 42.453, + 42.449, + 42.454, + 42.475, + 42.499, + 42.505, + 42.506, + 42.561, + 42.692, + 42.828, + 43.082, + 44.736, + 44.753, + 43.038, + 42.82, + 42.723, + 42.601, + 42.538, + 42.526, + 42.517, + 42.495, + 42.476, + 42.471, + 42.475, + 42.475, + 42.471, + 42.476, + 42.495, + 42.517, + 42.526, + 42.538, + 42.601, + 42.723, + 42.82, + 43.038, + 44.753, + 45.281, + 43.171, + 42.914, + 42.863, + 42.755, + 42.67, + 42.636, + 42.62, + 42.603, + 42.588, + 42.581, + 42.58, + 42.58, + 42.581, + 42.588, + 42.603, + 42.62, + 42.636, + 42.67, + 42.755, + 42.863, + 42.914, + 43.171, + 45.281, + 48.529, + 45.261, + 44.52, + 44.323, + 44.162, + 44.037, + 43.973, + 43.944, + 43.922, + 43.9, + 43.881, + 43.871, + 43.871, + 43.881, + 43.9, + 43.922, + 43.944, + 43.973, + 44.037, + 44.162, + 44.323, + 44.52, + 45.261, + 48.529, + 48.529, + ], + [ + 4.5187, + 4.4744, + 4.4501, + 4.433, + 4.4268, + 4.4222, + 4.4187, + 4.4174, + 4.4167, + 4.4159, + 4.4154, + 4.4154, + 4.4159, + 4.4167, + 4.4174, + 4.4187, + 4.4222, + 4.4268, + 4.433, + 4.4501, + 4.4744, + 4.5187, + 4.9177, + 4.5256, + 4.2875, + 4.2783, + 4.2574, + 4.2455, + 4.2434, + 4.2405, + 4.2382, + 4.2374, + 4.2368, + 4.2362, + 4.2361, + 4.2361, + 4.2362, + 4.2368, + 4.2374, + 4.2382, + 4.2405, + 4.2434, + 4.2455, + 4.2574, + 4.2783, + 4.2875, + 4.5256, + 4.4897, + 4.2819, + 4.2649, + 4.2439, + 4.2359, + 4.2338, + 4.2301, + 4.2283, + 4.2281, + 4.2273, + 4.2267, + 4.2267, + 4.2267, + 4.2267, + 4.2273, + 4.2281, + 4.2283, + 4.2301, + 4.2338, + 4.2359, + 4.2439, + 4.2649, + 4.2819, + 4.4897, + 4.4781, + 4.2679, + 4.2471, + 4.2289, + 4.2238, + 4.2214, + 4.2168, + 4.2152, + 4.2153, + 4.2144, + 4.2137, + 4.2138, + 4.2138, + 4.2137, + 4.2144, + 4.2153, + 4.2152, + 4.2168, + 4.2214, + 4.2238, + 4.2289, + 4.2471, + 4.2679, + 4.4781, + 4.4897, + 4.2819, + 4.2649, + 4.2439, + 4.2359, + 4.2338, + 4.2301, + 4.2283, + 4.2281, + 4.2273, + 4.2267, + 4.2267, + 4.2267, + 4.2267, + 4.2273, + 4.2281, + 4.2283, + 4.2301, + 4.2338, + 4.2359, + 4.2439, + 4.2649, + 4.2819, + 4.4897, + 4.5256, + 4.2875, + 4.2783, + 4.2574, + 4.2455, + 4.2434, + 4.2405, + 4.2382, + 4.2374, + 4.2368, + 4.2362, + 4.2361, + 4.2361, + 4.2362, + 4.2368, + 4.2374, + 4.2382, + 4.2405, + 4.2434, + 4.2455, + 4.2574, + 4.2783, + 4.2875, + 4.5256, + 4.9177, + 4.5187, + 4.4744, + 4.4501, + 4.433, + 4.4268, + 4.4222, + 4.4187, + 4.4174, + 4.4167, + 4.4159, + 4.4154, + 4.4154, + 4.4159, + 4.4167, + 4.4174, + 4.4187, + 4.4222, + 4.4268, + 4.433, + 4.4501, + 4.4744, + 4.5187, + 4.9177, + 4.9177, + ], + [ + 1.3486, + 1.3332, + 1.3187, + 1.3103, + 1.3068, + 1.3037, + 1.3015, + 1.3006, + 1.2999, + 1.2991, + 1.2986, + 1.2986, + 1.2991, + 1.2999, + 1.3006, + 1.3015, + 1.3037, + 1.3068, + 1.3103, + 1.3187, + 1.3332, + 1.3486, + 1.506, + 1.3566, + 1.2758, + 1.2693, + 1.2573, + 1.2531, + 1.2522, + 1.2503, + 1.2489, + 1.2484, + 1.2479, + 1.2475, + 1.2474, + 1.2474, + 1.2475, + 1.2479, + 1.2484, + 1.2489, + 1.2503, + 1.2522, + 1.2531, + 1.2573, + 1.2693, + 1.2758, + 1.3566, + 1.345, + 1.2709, + 1.2607, + 1.2507, + 1.248, + 1.2463, + 1.2442, + 1.2434, + 1.2432, + 1.2425, + 1.242, + 1.242, + 1.242, + 1.242, + 1.2425, + 1.2432, + 1.2434, + 1.2442, + 1.2463, + 1.248, + 1.2507, + 1.2607, + 1.2709, + 1.345, + 1.335, + 1.2613, + 1.2528, + 1.2457, + 1.2438, + 1.2416, + 1.239, + 1.2384, + 1.2384, + 1.2377, + 1.2372, + 1.2373, + 1.2373, + 1.2372, + 1.2377, + 1.2384, + 1.2384, + 1.239, + 1.2416, + 1.2438, + 1.2457, + 1.2528, + 1.2613, + 1.335, + 1.345, + 1.2709, + 1.2607, + 1.2507, + 1.248, + 1.2463, + 1.2442, + 1.2434, + 1.2432, + 1.2425, + 1.242, + 1.242, + 1.242, + 1.242, + 1.2425, + 1.2432, + 1.2434, + 1.2442, + 1.2463, + 1.248, + 1.2507, + 1.2607, + 1.2709, + 1.345, + 1.3566, + 1.2758, + 1.2693, + 1.2573, + 1.2531, + 1.2522, + 1.2503, + 1.2489, + 1.2484, + 1.2479, + 1.2475, + 1.2474, + 1.2474, + 1.2475, + 1.2479, + 1.2484, + 1.2489, + 1.2503, + 1.2522, + 1.2531, + 1.2573, + 1.2693, + 1.2758, + 1.3566, + 1.506, + 1.3486, + 1.3332, + 1.3187, + 1.3103, + 1.3068, + 1.3037, + 1.3015, + 1.3006, + 1.2999, + 1.2991, + 1.2986, + 1.2986, + 1.2991, + 1.2999, + 1.3006, + 1.3015, + 1.3037, + 1.3068, + 1.3103, + 1.3187, + 1.3332, + 1.3486, + 1.506, + 1.506, + ], + ] + ) def load_xfull(): - return nd.array([ - [0.0, 1.0], - [0.0, 2.0], - [0.0, 3.0], - [0.0, 4.0], - [0.0, 5.0], - [0.0, 6.0], - [0.0, 7.0], - [0.0, 8.0], - [0.0, 9.0], - [0.0, 10.0], - [0.0, 11.0], - [0.0, 12.0], - [0.0, 13.0], - [0.0, 14.0], - [0.0, 15.0], - [0.0, 16.0], - [0.0, 17.0], - [0.0, 18.0], - [0.0, 19.0], - [0.0, 20.0], - [0.0, 21.0], - [0.0, 22.0], - [0.0, 23.0], - [1.0, 0.0], - [1.0, 1.0], - [1.0, 2.0], - [1.0, 3.0], - [1.0, 4.0], - [1.0, 5.0], - [1.0, 6.0], - [1.0, 7.0], - [1.0, 8.0], - [1.0, 9.0], - [1.0, 10.0], - [1.0, 11.0], - [1.0, 12.0], - [1.0, 13.0], - [1.0, 14.0], - [1.0, 15.0], - [1.0, 16.0], - [1.0, 17.0], - [1.0, 18.0], - [1.0, 19.0], - [1.0, 20.0], - [1.0, 21.0], - [1.0, 22.0], - [1.0, 23.0], - [2.0, 0.0], - [2.0, 1.0], - [2.0, 2.0], - [2.0, 3.0], - [2.0, 4.0], - [2.0, 5.0], - [2.0, 6.0], - [2.0, 7.0], - [2.0, 8.0], - [2.0, 9.0], - [2.0, 10.0], - [2.0, 11.0], - [2.0, 12.0], - [2.0, 13.0], - [2.0, 14.0], - [2.0, 15.0], - [2.0, 16.0], - [2.0, 17.0], - [2.0, 18.0], - [2.0, 19.0], - [2.0, 20.0], - [2.0, 21.0], - [2.0, 22.0], - [2.0, 23.0], - [3.0, 0.0], - [3.0, 1.0], - [3.0, 2.0], - [3.0, 3.0], - [3.0, 4.0], - [3.0, 5.0], - [3.0, 6.0], - [3.0, 7.0], - [3.0, 8.0], - [3.0, 9.0], - [3.0, 10.0], - [3.0, 11.0], - [3.0, 12.0], - [3.0, 13.0], - [3.0, 14.0], - [3.0, 15.0], - [3.0, 16.0], - [3.0, 17.0], - [3.0, 18.0], - [3.0, 19.0], - [3.0, 20.0], - [3.0, 21.0], - [3.0, 22.0], - [3.0, 23.0], - [4.0, 0.0], - [4.0, 1.0], - [4.0, 2.0], - [4.0, 3.0], - [4.0, 4.0], - [4.0, 5.0], - [4.0, 6.0], - [4.0, 7.0], - [4.0, 8.0], - [4.0, 9.0], - [4.0, 10.0], - [4.0, 11.0], - [4.0, 12.0], - [4.0, 13.0], - [4.0, 14.0], - [4.0, 15.0], - [4.0, 16.0], - [4.0, 17.0], - [4.0, 18.0], - [4.0, 19.0], - [4.0, 20.0], - [4.0, 21.0], - [4.0, 22.0], - [4.0, 23.0], - [5.0, 0.0], - [5.0, 1.0], - [5.0, 2.0], - [5.0, 3.0], - [5.0, 4.0], - [5.0, 5.0], - [5.0, 6.0], - [5.0, 7.0], - [5.0, 8.0], - [5.0, 9.0], - [5.0, 10.0], - [5.0, 11.0], - [5.0, 12.0], - [5.0, 13.0], - [5.0, 14.0], - [5.0, 15.0], - [5.0, 16.0], - [5.0, 17.0], - [5.0, 18.0], - [5.0, 19.0], - [5.0, 20.0], - [5.0, 21.0], - [5.0, 22.0], - [5.0, 23.0], - [6.0, 0.0], - [6.0, 1.0], - [6.0, 2.0], - [6.0, 3.0], - [6.0, 4.0], - [6.0, 5.0], - [6.0, 6.0], - [6.0, 7.0], - [6.0, 8.0], - [6.0, 9.0], - [6.0, 10.0], - [6.0, 11.0], - [6.0, 12.0], - [6.0, 13.0], - [6.0, 14.0], - [6.0, 15.0], - [6.0, 16.0], - [6.0, 17.0], - [6.0, 18.0], - [6.0, 19.0], - [6.0, 20.0], - [6.0, 21.0], - [6.0, 22.0], - [6.0, 23.0], - [0.0, 0.0], - [0.0, 1.0], - [0.0, 2.0], - [0.0, 3.0], - [0.0, 4.0], - [0.0, 5.0], - [0.0, 6.0], - [0.0, 7.0], - [0.0, 8.0], - [0.0, 9.0], - [0.0, 10.0], - [0.0, 11.0], - [0.0, 12.0], - [0.0, 13.0], - [0.0, 14.0], - [0.0, 15.0], - [0.0, 16.0], - [0.0, 17.0], - [0.0, 18.0], - [0.0, 19.0], - [0.0, 20.0], - [0.0, 21.0], - [0.0, 22.0], - [0.0, 23.0], - [1.0, 0.0], - [1.0, 1.0], - [1.0, 2.0], - [1.0, 3.0], - [1.0, 4.0], - [1.0, 5.0], - [1.0, 6.0], - [1.0, 7.0], - [1.0, 8.0], - [1.0, 9.0], - [1.0, 10.0], - [1.0, 11.0], - [1.0, 12.0], - [1.0, 13.0], - [1.0, 14.0], - [1.0, 15.0], - [1.0, 16.0], - [1.0, 17.0], - [1.0, 18.0], - [1.0, 19.0], - [1.0, 20.0], - [1.0, 21.0], - [1.0, 22.0], - [1.0, 23.0], - [2.0, 0.0], - [2.0, 1.0], - [2.0, 2.0], - [2.0, 3.0], - [2.0, 4.0], - [2.0, 5.0], - [2.0, 6.0], - [2.0, 7.0], - [2.0, 8.0], - [2.0, 9.0], - [2.0, 10.0], - [2.0, 11.0], - [2.0, 12.0], - [2.0, 13.0], - [2.0, 14.0], - [2.0, 15.0], - [2.0, 16.0], - [2.0, 17.0], - [2.0, 18.0], - [2.0, 19.0], - [2.0, 20.0], - [2.0, 21.0], - [2.0, 22.0], - [2.0, 23.0], - [3.0, 0.0], - [3.0, 1.0], - [3.0, 2.0], - [3.0, 3.0], - [3.0, 4.0], - [3.0, 5.0], - [3.0, 6.0], - [3.0, 7.0], - [3.0, 8.0], - [3.0, 9.0], - [3.0, 10.0], - [3.0, 11.0], - [3.0, 12.0], - [3.0, 13.0], - [3.0, 14.0], - [3.0, 15.0], - [3.0, 16.0], - [3.0, 17.0], - [3.0, 18.0], - [3.0, 19.0], - [3.0, 20.0], - [3.0, 21.0], - [3.0, 22.0], - [3.0, 23.0], - [4.0, 0.0], - [4.0, 1.0], - [4.0, 2.0], - [4.0, 3.0], - [4.0, 4.0], - [4.0, 5.0], - [4.0, 6.0], - [4.0, 7.0], - [4.0, 8.0], - [4.0, 9.0], - [4.0, 10.0], - [4.0, 11.0], - [4.0, 12.0], - [4.0, 13.0], - [4.0, 14.0], - [4.0, 15.0], - [4.0, 16.0], - [4.0, 17.0], - [4.0, 18.0], - [4.0, 19.0], - [4.0, 20.0], - [4.0, 21.0], - [4.0, 22.0], - [4.0, 23.0], - [5.0, 0.0], - [5.0, 1.0], - [5.0, 2.0], - [5.0, 3.0], - [5.0, 4.0], - [5.0, 5.0], - [5.0, 6.0], - [5.0, 7.0], - [5.0, 8.0], - [5.0, 9.0], - [5.0, 10.0], - [5.0, 11.0], - [5.0, 12.0], - [5.0, 13.0], - [5.0, 14.0], - [5.0, 15.0], - [5.0, 16.0], - [5.0, 17.0], - [5.0, 18.0], - [5.0, 19.0], - [5.0, 20.0], - [5.0, 21.0], - [5.0, 22.0], - [5.0, 23.0], - [6.0, 0.0], - [6.0, 1.0], - [6.0, 2.0], - [6.0, 3.0], - [6.0, 4.0], - [6.0, 5.0], - [6.0, 6.0], - [6.0, 7.0], - [6.0, 8.0], - [6.0, 9.0], - [6.0, 10.0], - [6.0, 11.0], - [6.0, 12.0], - [6.0, 13.0], - [6.0, 14.0], - [6.0, 15.0], - [6.0, 16.0], - [6.0, 17.0], - [6.0, 18.0], - [6.0, 19.0], - [6.0, 20.0], - [6.0, 21.0], - [6.0, 22.0], - [6.0, 23.0], - [0.0, 0.0], - ]).expand_dims(axis=0) + return nd.array( + [ + [0.0, 1.0], + [0.0, 2.0], + [0.0, 3.0], + [0.0, 4.0], + [0.0, 5.0], + [0.0, 6.0], + [0.0, 7.0], + [0.0, 8.0], + [0.0, 9.0], + [0.0, 10.0], + [0.0, 11.0], + [0.0, 12.0], + [0.0, 13.0], + [0.0, 14.0], + [0.0, 15.0], + [0.0, 16.0], + [0.0, 17.0], + [0.0, 18.0], + [0.0, 19.0], + [0.0, 20.0], + [0.0, 21.0], + [0.0, 22.0], + [0.0, 23.0], + [1.0, 0.0], + [1.0, 1.0], + [1.0, 2.0], + [1.0, 3.0], + [1.0, 4.0], + [1.0, 5.0], + [1.0, 6.0], + [1.0, 7.0], + [1.0, 8.0], + [1.0, 9.0], + [1.0, 10.0], + [1.0, 11.0], + [1.0, 12.0], + [1.0, 13.0], + [1.0, 14.0], + [1.0, 15.0], + [1.0, 16.0], + [1.0, 17.0], + [1.0, 18.0], + [1.0, 19.0], + [1.0, 20.0], + [1.0, 21.0], + [1.0, 22.0], + [1.0, 23.0], + [2.0, 0.0], + [2.0, 1.0], + [2.0, 2.0], + [2.0, 3.0], + [2.0, 4.0], + [2.0, 5.0], + [2.0, 6.0], + [2.0, 7.0], + [2.0, 8.0], + [2.0, 9.0], + [2.0, 10.0], + [2.0, 11.0], + [2.0, 12.0], + [2.0, 13.0], + [2.0, 14.0], + [2.0, 15.0], + [2.0, 16.0], + [2.0, 17.0], + [2.0, 18.0], + [2.0, 19.0], + [2.0, 20.0], + [2.0, 21.0], + [2.0, 22.0], + [2.0, 23.0], + [3.0, 0.0], + [3.0, 1.0], + [3.0, 2.0], + [3.0, 3.0], + [3.0, 4.0], + [3.0, 5.0], + [3.0, 6.0], + [3.0, 7.0], + [3.0, 8.0], + [3.0, 9.0], + [3.0, 10.0], + [3.0, 11.0], + [3.0, 12.0], + [3.0, 13.0], + [3.0, 14.0], + [3.0, 15.0], + [3.0, 16.0], + [3.0, 17.0], + [3.0, 18.0], + [3.0, 19.0], + [3.0, 20.0], + [3.0, 21.0], + [3.0, 22.0], + [3.0, 23.0], + [4.0, 0.0], + [4.0, 1.0], + [4.0, 2.0], + [4.0, 3.0], + [4.0, 4.0], + [4.0, 5.0], + [4.0, 6.0], + [4.0, 7.0], + [4.0, 8.0], + [4.0, 9.0], + [4.0, 10.0], + [4.0, 11.0], + [4.0, 12.0], + [4.0, 13.0], + [4.0, 14.0], + [4.0, 15.0], + [4.0, 16.0], + [4.0, 17.0], + [4.0, 18.0], + [4.0, 19.0], + [4.0, 20.0], + [4.0, 21.0], + [4.0, 22.0], + [4.0, 23.0], + [5.0, 0.0], + [5.0, 1.0], + [5.0, 2.0], + [5.0, 3.0], + [5.0, 4.0], + [5.0, 5.0], + [5.0, 6.0], + [5.0, 7.0], + [5.0, 8.0], + [5.0, 9.0], + [5.0, 10.0], + [5.0, 11.0], + [5.0, 12.0], + [5.0, 13.0], + [5.0, 14.0], + [5.0, 15.0], + [5.0, 16.0], + [5.0, 17.0], + [5.0, 18.0], + [5.0, 19.0], + [5.0, 20.0], + [5.0, 21.0], + [5.0, 22.0], + [5.0, 23.0], + [6.0, 0.0], + [6.0, 1.0], + [6.0, 2.0], + [6.0, 3.0], + [6.0, 4.0], + [6.0, 5.0], + [6.0, 6.0], + [6.0, 7.0], + [6.0, 8.0], + [6.0, 9.0], + [6.0, 10.0], + [6.0, 11.0], + [6.0, 12.0], + [6.0, 13.0], + [6.0, 14.0], + [6.0, 15.0], + [6.0, 16.0], + [6.0, 17.0], + [6.0, 18.0], + [6.0, 19.0], + [6.0, 20.0], + [6.0, 21.0], + [6.0, 22.0], + [6.0, 23.0], + [0.0, 0.0], + [0.0, 1.0], + [0.0, 2.0], + [0.0, 3.0], + [0.0, 4.0], + [0.0, 5.0], + [0.0, 6.0], + [0.0, 7.0], + [0.0, 8.0], + [0.0, 9.0], + [0.0, 10.0], + [0.0, 11.0], + [0.0, 12.0], + [0.0, 13.0], + [0.0, 14.0], + [0.0, 15.0], + [0.0, 16.0], + [0.0, 17.0], + [0.0, 18.0], + [0.0, 19.0], + [0.0, 20.0], + [0.0, 21.0], + [0.0, 22.0], + [0.0, 23.0], + [1.0, 0.0], + [1.0, 1.0], + [1.0, 2.0], + [1.0, 3.0], + [1.0, 4.0], + [1.0, 5.0], + [1.0, 6.0], + [1.0, 7.0], + [1.0, 8.0], + [1.0, 9.0], + [1.0, 10.0], + [1.0, 11.0], + [1.0, 12.0], + [1.0, 13.0], + [1.0, 14.0], + [1.0, 15.0], + [1.0, 16.0], + [1.0, 17.0], + [1.0, 18.0], + [1.0, 19.0], + [1.0, 20.0], + [1.0, 21.0], + [1.0, 22.0], + [1.0, 23.0], + [2.0, 0.0], + [2.0, 1.0], + [2.0, 2.0], + [2.0, 3.0], + [2.0, 4.0], + [2.0, 5.0], + [2.0, 6.0], + [2.0, 7.0], + [2.0, 8.0], + [2.0, 9.0], + [2.0, 10.0], + [2.0, 11.0], + [2.0, 12.0], + [2.0, 13.0], + [2.0, 14.0], + [2.0, 15.0], + [2.0, 16.0], + [2.0, 17.0], + [2.0, 18.0], + [2.0, 19.0], + [2.0, 20.0], + [2.0, 21.0], + [2.0, 22.0], + [2.0, 23.0], + [3.0, 0.0], + [3.0, 1.0], + [3.0, 2.0], + [3.0, 3.0], + [3.0, 4.0], + [3.0, 5.0], + [3.0, 6.0], + [3.0, 7.0], + [3.0, 8.0], + [3.0, 9.0], + [3.0, 10.0], + [3.0, 11.0], + [3.0, 12.0], + [3.0, 13.0], + [3.0, 14.0], + [3.0, 15.0], + [3.0, 16.0], + [3.0, 17.0], + [3.0, 18.0], + [3.0, 19.0], + [3.0, 20.0], + [3.0, 21.0], + [3.0, 22.0], + [3.0, 23.0], + [4.0, 0.0], + [4.0, 1.0], + [4.0, 2.0], + [4.0, 3.0], + [4.0, 4.0], + [4.0, 5.0], + [4.0, 6.0], + [4.0, 7.0], + [4.0, 8.0], + [4.0, 9.0], + [4.0, 10.0], + [4.0, 11.0], + [4.0, 12.0], + [4.0, 13.0], + [4.0, 14.0], + [4.0, 15.0], + [4.0, 16.0], + [4.0, 17.0], + [4.0, 18.0], + [4.0, 19.0], + [4.0, 20.0], + [4.0, 21.0], + [4.0, 22.0], + [4.0, 23.0], + [5.0, 0.0], + [5.0, 1.0], + [5.0, 2.0], + [5.0, 3.0], + [5.0, 4.0], + [5.0, 5.0], + [5.0, 6.0], + [5.0, 7.0], + [5.0, 8.0], + [5.0, 9.0], + [5.0, 10.0], + [5.0, 11.0], + [5.0, 12.0], + [5.0, 13.0], + [5.0, 14.0], + [5.0, 15.0], + [5.0, 16.0], + [5.0, 17.0], + [5.0, 18.0], + [5.0, 19.0], + [5.0, 20.0], + [5.0, 21.0], + [5.0, 22.0], + [5.0, 23.0], + [6.0, 0.0], + [6.0, 1.0], + [6.0, 2.0], + [6.0, 3.0], + [6.0, 4.0], + [6.0, 5.0], + [6.0, 6.0], + [6.0, 7.0], + [6.0, 8.0], + [6.0, 9.0], + [6.0, 10.0], + [6.0, 11.0], + [6.0, 12.0], + [6.0, 13.0], + [6.0, 14.0], + [6.0, 15.0], + [6.0, 16.0], + [6.0, 17.0], + [6.0, 18.0], + [6.0, 19.0], + [6.0, 20.0], + [6.0, 21.0], + [6.0, 22.0], + [6.0, 23.0], + [0.0, 0.0], + ] + ).expand_dims(axis=0) def load_ytrain(): - return nd.array([ - [ - 3.392694091796875000e02, - 3.198630065917968750e02, - 3.210045776367187500e02, - 3.175798950195312500e02, - 3.130137023925781250e02, - 4.615182800292968750e02, - 5.757534179687500000e02, - 5.871689453125000000e02, - 6.362899780273437500e02, - 6.785616455078125000e02, - 7.677511596679687500e02, - 9.501369628906250000e02, - 1.041723754882812500e03, - 1.016529663085937500e03, - 9.112100219726562500e02, - 1.019965759277343750e03, - 1.019977172851562500e03, - 1.055456665039062500e03, - 1.033710083007812500e03, - 7.803538818359375000e02, - 7.322830810546875000e02, - 4.934931640625000000e02, - 3.872488708496093750e02, - 3.655593566894531250e02, - 3.747031860351562500e02, - 3.518378906250000000e02, - 3.621347045898437500e02, - 3.427054748535156250e02, - 3.461187133789062500e02, - 4.992009277343750000e02, - 6.488584594726562500e02, - 6.145890502929687500e02, - 6.888356323242187500e02, - 7.624086914062500000e02, - 7.517237548828125000e02, - 9.593036499023437500e02, - 9.810502319335937500e02, - 1.038287719726562500e03, - 9.249543457031250000e02, - 1.033710083007812500e03, - 1.079497680664062500e03, - 1.050879028320312500e03, - 9.673059082031250000e02, - 7.490639038085937500e02, - 7.319063720703125000e02, - 4.854908752441406250e02, - 3.826940612792968750e02, - 3.644178161621093750e02, - 3.667009277343750000e02, - 3.506963500976562500e02, - 3.438470458984375000e02, - 3.267123413085937500e02, - 3.347031860351562500e02, - 4.923515930175781250e02, - 6.385730590820312500e02, - 6.180136718750000000e02, - 6.766895141601562500e02, - 6.880822143554687500e02, - 7.643150634765625000e02, - 1.010810485839843750e03, - 1.053173461914062500e03, - 1.042876708984375000e03, - 9.570091552734375000e02, - 1.054303710937500000e03, - 1.118424682617187500e03, - 1.089805908203125000e03, - 9.661757812500000000e02, - 7.852968139648437500e02, - 7.608903808593750000e02, - 4.912100524902343750e02, - 3.803995361328125000e02, - 3.586986389160156250e02, - 3.667009277343750000e02, - 3.461187133789062500e02, - 3.609817199707031250e02, - 3.449771728515625000e02, - 3.404109497070312500e02, - 4.843721313476562500e02, - 6.560958862304687500e02, - 6.214497680664062500e02, - 6.937899780273437500e02, - 7.357191772460937500e02, - 7.421917724609375000e02, - 9.673173217773437500e02, - 9.375456542968750000e02, - 8.986301269531250000e02, - 9.787671508789062500e02, - 1.026849365234375000e03, - 9.684703369140625000e02, - 9.066552734375000000e02, - 8.253767089843750000e02, - 7.277055053710937500e02, - 7.002625732421875000e02, - 4.809360656738281250e02, - 3.712785339355468750e02, - 3.518378906250000000e02, - 3.552625427246093750e02, - 3.426940612792968750e02, - 3.438470458984375000e02, - 3.324201049804687500e02, - 3.278538818359375000e02, - 4.923401794433593750e02, - 6.351483764648437500e02, - 6.054566040039062500e02, - 6.796917724609375000e02, - 6.983789672851562500e02, - 7.273287963867187500e02, - 9.192237548828125000e02, - 9.512785644531250000e02, - 9.547260131835937500e02, - 9.169406127929687500e02, - 1.023401855468750000e03, - 1.011940612792968750e03, - 1.018824218750000000e03, - 9.146575317382812500e02, - 8.070548095703125000e02, - 7.444863281250000000e02, - 5.072031860351562500e02, - 3.998287658691406250e02, - 3.575570678710937500e02, - 3.598401794433593750e02, - 3.438356018066406250e02, - 3.472716979980468750e02, - 3.381278686523437500e02, - 3.426940612792968750e02, - 6.206963500976562500e02, - 5.837442626953125000e02, - 6.111643676757812500e02, - 7.471689453125000000e02, - 7.593721313476562500e02, - 8.070548095703125000e02, - 9.604451904296875000e02, - 1.025707763671875000e03, - 9.741781005859375000e02, - 9.364155273437500000e02, - 1.019965759277343750e03, - 1.085216918945312500e03, - 9.970662231445312500e02, - 8.929109497070312500e02, - 7.597374267578125000e02, - 7.505822143554687500e02, - 5.083447570800781250e02, - 3.986872253417968750e02, - 3.689954223632812500e02, - 3.769862976074218750e02, - 3.506849365234375000e02, - 3.552739868164062500e02, - 3.404223632812500000e02, - 3.312785339355468750e02, - 5.711757812500000000e02, - 5.403310546875000000e02, - 6.317237548828125000e02, - 7.059703369140625000e02, - 7.517579956054687500e02, - 8.024657592773437500e02, - 1.046301391601562500e03, - 1.109269409179687500e03, - 1.121860717773437500e03, - 1.015388122558593750e03, - 1.094383544921875000e03, - 1.151620971679687500e03, - 1.062317382812500000e03, - 1.077191772460937500e03, - 8.036187133789062500e02, - 7.216096191406250000e02, - 5.003424682617187500e02, - 3.964041137695312500e02, - 3.735616455078125000e02, - ], - [ - 2.263374519348144531e01, - 2.932098770141601562e01, - 2.932098770141601562e01, - 2.949245452880859375e01, - 2.880658531188964844e01, - 2.657750320434570312e01, - 2.263374519348144531e01, - 1.886145401000976562e01, - 2.434842300415039062e01, - 3.137860107421875000e01, - 2.400548744201660156e01, - 2.897805213928222656e01, - 3.377914810180664062e01, - 3.360768127441406250e01, - 3.240740585327148438e01, - 3.275034332275390625e01, - 2.880658531188964844e01, - 2.743484306335449219e01, - 3.755144119262695312e01, - 3.635116577148437500e01, - 3.703703689575195312e01, - 4.098079681396484375e01, - 3.275034332275390625e01, - 2.897805213928222656e01, - 3.086419677734375000e01, - 2.726337432861328125e01, - 1.920438957214355469e01, - 2.777777862548828125e01, - 2.846364974975585938e01, - 2.331961631774902344e01, - 1.406035709381103516e01, - 2.640603637695312500e01, - 2.537722969055175781e01, - 2.572016525268554688e01, - 3.240740585327148438e01, - 2.743484306335449219e01, - 2.897805213928222656e01, - 2.726337432861328125e01, - 2.006172752380371094e01, - 3.069272994995117188e01, - 3.000685882568359375e01, - 2.503429412841796875e01, - 3.000685882568359375e01, - 3.858024597167968750e01, - 3.789437484741210938e01, - 3.240740585327148438e01, - 3.480795669555664062e01, - 2.589163208007812500e01, - 2.263374519348144531e01, - 1.954732513427734375e01, - 1.920438957214355469e01, - 1.817558288574218750e01, - 2.006172752380371094e01, - 1.543209838867187500e01, - 1.406035709381103516e01, - 2.280521202087402344e01, - 3.275034332275390625e01, - 3.412208557128906250e01, - 3.360768127441406250e01, - 3.446501922607421875e01, - 3.515089035034179688e01, - 3.395061874389648438e01, - 2.846364974975585938e01, - 2.949245452880859375e01, - 3.017832565307617188e01, - 2.777777862548828125e01, - 3.446501922607421875e01, - 4.526749038696289062e01, - 3.943758392333984375e01, - 4.818244171142578125e01, - 3.737997436523437500e01, - 3.480795669555664062e01, - 2.400548744201660156e01, - 2.623456764221191406e01, - 2.760630989074707031e01, - 2.846364974975585938e01, - 2.932098770141601562e01, - 1.851851844787597656e01, - 2.383401870727539062e01, - 2.914951896667480469e01, - 2.897805213928222656e01, - 2.709190750122070312e01, - 3.343621444702148438e01, - 3.292181015014648438e01, - 3.772290802001953125e01, - 3.858024597167968750e01, - 3.326474761962890625e01, - 2.589163208007812500e01, - 3.172153663635253906e01, - 3.377914810180664062e01, - 4.406721496582031250e01, - 4.749657058715820312e01, - 4.406721496582031250e01, - 3.446501922607421875e01, - 3.960905456542968750e01, - 3.463648986816406250e01, - 3.189300346374511719e01, - 3.034979438781738281e01, - 3.069272994995117188e01, - 2.880658531188964844e01, - 2.897805213928222656e01, - 2.143346977233886719e01, - 2.331961631774902344e01, - 2.520576095581054688e01, - 2.109053421020507812e01, - 2.777777862548828125e01, - 3.000685882568359375e01, - 3.017832565307617188e01, - 3.343621444702148438e01, - 3.223593902587890625e01, - 2.503429412841796875e01, - 3.446501922607421875e01, - 3.343621444702148438e01, - 2.812071418762207031e01, - 2.897805213928222656e01, - 3.995198822021484375e01, - 3.377914810180664062e01, - 3.515089035034179688e01, - 3.703703689575195312e01, - 2.897805213928222656e01, - 2.366255187988281250e01, - 2.812071418762207031e01, - 2.760630989074707031e01, - 2.143346977233886719e01, - 2.743484306335449219e01, - 2.589163208007812500e01, - 2.177640533447265625e01, - 2.057613182067871094e01, - 3.292181015014648438e01, - 2.486282539367675781e01, - 2.589163208007812500e01, - 4.200960159301757812e01, - 3.086419677734375000e01, - 3.412208557128906250e01, - 3.840877914428710938e01, - 3.155006790161132812e01, - 3.600823211669921875e01, - 3.669410324096679688e01, - 3.446501922607421875e01, - 3.720850372314453125e01, - 4.406721496582031250e01, - 3.326474761962890625e01, - 3.532236099243164062e01, - 3.858024597167968750e01, - 3.275034332275390625e01, - 2.194787406921386719e01, - 2.897805213928222656e01, - 2.863511657714843750e01, - 2.469135856628417969e01, - 2.331961631774902344e01, - 2.554869651794433594e01, - 2.143346977233886719e01, - 2.486282539367675781e01, - 2.743484306335449219e01, - 2.383401870727539062e01, - 3.069272994995117188e01, - 2.794924545288085938e01, - 2.812071418762207031e01, - 2.434842300415039062e01, - 1.989026069641113281e01, - 1.971879196166992188e01, - 2.194787406921386719e01, - 2.486282539367675781e01, - 3.360768127441406250e01, - 3.343621444702148438e01, - 3.206447219848632812e01, - 3.034979438781738281e01, - 2.349108314514160156e01, - ], - [ - 1.214824981689453125e02, - 1.192518844604492188e02, - 1.103294448852539062e02, - 1.139327392578125000e02, - 1.137611541748046875e02, - 1.456760406494140625e02, - 1.597460479736328125e02, - 1.443033599853515625e02, - 1.691832580566406250e02, - 1.798215484619140625e02, - 1.918325347900390625e02, - 1.990391235351562500e02, - 1.932052154541015625e02, - 2.007549743652343750e02, - 2.069320526123046875e02, - 2.084763183593750000e02, - 2.134523010253906250e02, - 2.148249816894531250e02, - 2.132807159423828125e02, - 2.143102264404296875e02, - 2.086479034423828125e02, - 1.935483856201171875e02, - 1.743308105468750000e02, - 1.717570343017578125e02, - 1.662662963867187500e02, - 1.657515411376953125e02, - 1.580301971435546875e02, - 1.626630096435546875e02, - 1.640356903076171875e02, - 1.841111907958984375e02, - 2.000686340332031250e02, - 1.853122863769531250e02, - 2.057309570312500000e02, - 2.074468078613281250e02, - 2.151681518554687500e02, - 2.095058288574218750e02, - 2.088194885253906250e02, - 2.156829071044921875e02, - 2.189430389404296875e02, - 2.215168151855468750e02, - 2.287234039306640625e02, - 2.292381591796875000e02, - 2.230610809326171875e02, - 2.177419281005859375e02, - 2.115648651123046875e02, - 1.921757049560546875e02, - 1.741592254638671875e02, - 1.535689697265625000e02, - 1.590597076416015625e02, - 1.606039733886718750e02, - 1.568291015625000000e02, - 1.571722717285156250e02, - 1.575154418945312500e02, - 1.911461944580078125e02, - 1.866849670410156250e02, - 1.885724029541015625e02, - 2.033287506103515625e02, - 2.115648651123046875e02, - 2.316403503417968750e02, - 2.280370635986328125e02, - 2.302676696777343750e02, - 2.395332946777343750e02, - 2.343857269287109375e02, - 2.278654785156250000e02, - 2.285518188476562500e02, - 2.268359680175781250e02, - 2.184282836914062500e02, - 2.112216949462890625e02, - 2.155113220214843750e02, - 2.148249816894531250e02, - 1.988675384521484375e02, - 1.842827758789062500e02, - 1.726149597167968750e02, - 1.806794738769531250e02, - 1.762182617187500000e02, - 1.774193572998046875e02, - 1.835964355468750000e02, - 1.964653472900390625e02, - 2.240905914306640625e02, - 1.781056976318359375e02, - 1.878860626220703125e02, - 1.969801025390625000e02, - 2.091626586914062500e02, - 2.098489990234375000e02, - 2.163692474365234375e02, - 2.210020599365234375e02, - 2.251201171875000000e02, - 2.242621765136718750e02, - 2.196293792724609375e02, - 2.276938934326171875e02, - 2.225463256835937500e02, - 2.210020599365234375e02, - 2.047014465332031250e02, - 1.944063110351562500e02, - 1.722717895507812500e02, - 1.523678741455078125e02, - 1.559711761474609375e02, - 1.551132507324218750e02, - 1.513383636474609375e02, - 1.544269104003906250e02, - 1.530542144775390625e02, - 1.871997222900390625e02, - 1.799931335449218750e02, - 1.717570343017578125e02, - 1.935483856201171875e02, - 1.969801025390625000e02, - 2.067604675292968750e02, - 2.228894958496093750e02, - 2.105353393554687500e02, - 2.149965667724609375e02, - 2.024708251953125000e02, - 2.409059753417968750e02, - 2.491420745849609375e02, - 2.283802337646484375e02, - 2.213452301025390625e02, - 2.198009643554687500e02, - 2.007549743652343750e02, - 1.859986267089843750e02, - 1.739876403808593750e02, - 1.583733673095703125e02, - 1.376115264892578125e02, - 1.377831115722656250e02, - 1.362388458251953125e02, - 1.353809204101562500e02, - 1.431022644042968750e02, - 1.695264282226562500e02, - 1.822237548828125000e02, - 1.703843536376953125e02, - 1.896019287109375000e02, - 1.983527832031250000e02, - 2.048730316162109375e02, - 2.052162017822265625e02, - 2.062457122802734375e02, - 2.081331481933593750e02, - 2.047014465332031250e02, - 2.149965667724609375e02, - 2.047014465332031250e02, - 2.055593719482421875e02, - 2.227179107666015625e02, - 2.113932800292968750e02, - 1.998970489501953125e02, - 1.817089843750000000e02, - 1.657515411376953125e02, - 1.499656829833984375e02, - 1.364104309082031250e02, - 1.312628631591796875e02, - 1.283459167480468750e02, - 1.250857925415039062e02, - 1.264584732055664062e02, - 1.654083709716796875e02, - 1.798215484619140625e02, - 1.643788604736328125e02, - 1.904598541259765625e02, - 1.921757049560546875e02, - 2.069320526123046875e02, - 2.215168151855468750e02, - 2.185998687744140625e02, - 2.246053466796875000e02, - 2.259780426025390625e02, - 2.304392547607421875e02, - 2.283802337646484375e02, - 2.264927978515625000e02, - 2.222031555175781250e02, - 2.115648651123046875e02, - 2.010981445312500000e02, - 1.944063110351562500e02, - 1.676389770507812500e02, - 1.460192108154296875e02, - ], - [ - 6.946183013916015625e01, - 6.226533126831054688e01, - 6.007509231567382812e01, - 5.913642120361328125e01, - 5.788485717773437500e01, - 6.289111328125000000e01, - 4.787234115600585938e01, - 5.287859725952148438e01, - 5.287859725952148438e01, - 6.038798522949218750e01, - 6.570713043212890625e01, - 7.196495819091796875e01, - 7.352941131591796875e01, - 6.007509231567382812e01, - 5.851063919067382812e01, - 6.695870208740234375e01, - 6.289111328125000000e01, - 7.259073638916015625e01, - 9.355444335937500000e01, - 8.573216247558593750e01, - 9.793492126464843750e01, - 1.013767242431640625e02, - 1.004380493164062500e02, - 7.884856414794921875e01, - 6.727159118652343750e01, - 6.351689529418945312e01, - 6.351689529418945312e01, - 6.226533126831054688e01, - 6.289111328125000000e01, - 6.821026611328125000e01, - 4.630788421630859375e01, - 5.162703323364257812e01, - 5.037546920776367188e01, - 6.070087432861328125e01, - 6.821026611328125000e01, - 7.446808624267578125e01, - 7.415519714355468750e01, - 6.758448028564453125e01, - 6.539424133300781250e01, - 6.602002716064453125e01, - 6.226533126831054688e01, - 7.133917236328125000e01, - 8.854818725585937500e01, - 8.823529052734375000e01, - 1.041927413940429688e02, - 1.029411773681640625e02, - 9.699624633789062500e01, - 7.634542846679687500e01, - 6.758448028564453125e01, - 6.445556640625000000e01, - 6.351689529418945312e01, - 6.101376724243164062e01, - 5.976220321655273438e01, - 6.789736938476562500e01, - 5.225281524658203125e01, - 4.724655914306640625e01, - 5.131414413452148438e01, - 6.508135223388671875e01, - 7.478097534179687500e01, - 7.384230041503906250e01, - 6.163954925537109375e01, - 6.476846313476562500e01, - 6.414267730712890625e01, - 5.694618225097656250e01, - 6.508135223388671875e01, - 6.914893341064453125e01, - 7.947434234619140625e01, - 8.604505920410156250e01, - 1.001251602172851562e02, - 9.793492126464843750e01, - 9.167709350585937500e01, - 7.790988922119140625e01, - 6.789736938476562500e01, - 6.633291625976562500e01, - 6.351689529418945312e01, - 6.226533126831054688e01, - 6.195244216918945312e01, - 6.476846313476562500e01, - 5.225281524658203125e01, - 5.068836212158203125e01, - 5.413016128540039062e01, - 5.757196426391601562e01, - 7.352941131591796875e01, - 7.321652221679687500e01, - 7.196495819091796875e01, - 6.163954925537109375e01, - 5.882352828979492188e01, - 6.414267730712890625e01, - 6.476846313476562500e01, - 7.790988922119140625e01, - 9.824781036376953125e01, - 8.604505920410156250e01, - 1.026282882690429688e02, - 1.091989974975585938e02, - 1.073216552734375000e02, - 8.541927337646484375e01, - 7.415519714355468750e01, - 6.852315521240234375e01, - 6.289111328125000000e01, - 6.070087432861328125e01, - 6.101376724243164062e01, - 6.414267730712890625e01, - 5.256570816040039062e01, - 4.943679428100585938e01, - 6.070087432861328125e01, - 7.571965026855468750e01, - 8.792240142822265625e01, - 8.698372650146484375e01, - 7.822277832031250000e01, - 7.133917236328125000e01, - 7.227784729003906250e01, - 6.602002716064453125e01, - 6.883604431152343750e01, - 6.977471923828125000e01, - 7.978723144531250000e01, - 9.042552947998046875e01, - 1.057571945190429688e02, - 1.007509384155273438e02, - 9.762202453613281250e01, - 8.479349517822265625e01, - 7.165206146240234375e01, - 6.821026611328125000e01, - 6.226533126831054688e01, - 6.195244216918945312e01, - 7.478097534179687500e01, - 5.882352828979492188e01, - 4.974968719482421875e01, - 5.381727218627929688e01, - 6.007509231567382812e01, - 7.259073638916015625e01, - 7.603253936767578125e01, - 8.573216247558593750e01, - 8.698372650146484375e01, - 8.041301727294921875e01, - 7.853566741943359375e01, - 8.197747039794921875e01, - 7.603253936767578125e01, - 8.948686218261718750e01, - 8.948686218261718750e01, - 9.230287933349609375e01, - 9.230287933349609375e01, - 9.511889648437500000e01, - 9.543179321289062500e01, - 8.823529052734375000e01, - 8.166458129882812500e01, - 7.133917236328125000e01, - 6.195244216918945312e01, - 6.038798522949218750e01, - 5.976220321655273438e01, - 6.883604431152343750e01, - 4.849812316894531250e01, - 5.319149017333984375e01, - 6.382978820800781250e01, - 7.634542846679687500e01, - 9.261576843261718750e01, - 1.001251602172851562e02, - 8.573216247558593750e01, - 6.539424133300781250e01, - 5.600751113891601562e01, - 5.694618225097656250e01, - 5.538172531127929688e01, - 6.382978820800781250e01, - 7.790988922119140625e01, - 8.197747039794921875e01, - 1.029411773681640625e02, - 1.135794754028320312e02, - 1.004380493164062500e02, - 8.479349517822265625e01, - ], - [ - 4.069791030883789062e01, - 4.158940505981445312e01, - 3.891365432739257812e01, - 3.942435073852539062e01, - 4.057055664062500000e01, - 6.334054946899414062e01, - 5.866785430908203125e01, - 6.113219451904296875e01, - 7.438232421875000000e01, - 8.779418945312500000e01, - 9.533239746093750000e01, - 9.737519073486328125e01, - 9.609780883789062500e01, - 9.724783325195312500e01, - 9.839658355712890625e01, - 1.000573120117187500e02, - 1.004406509399414062e02, - 1.027394256591796875e02, - 9.839658355712890625e01, - 1.037621002197265625e02, - 9.520503997802734375e01, - 7.690779113769531250e01, - 4.783494567871093750e01, - 4.464849853515625000e01, - 4.490448379516601562e01, - 4.630667495727539062e01, - 4.362837600708007812e01, - 4.528655242919921875e01, - 4.452114105224609375e01, - 6.675624084472656250e01, - 7.067881011962890625e01, - 6.830870819091796875e01, - 7.859780883789062500e01, - 9.124427032470703125e01, - 9.673586273193359375e01, - 9.367167663574218750e01, - 9.852394104003906250e01, - 9.699057769775390625e01, - 9.647988128662109375e01, - 1.008239974975585938e02, - 1.017180328369140625e02, - 1.015906753540039062e02, - 1.017193069458007812e02, - 9.545848083496093750e01, - 9.341441345214843750e01, - 7.792536926269531250e01, - 5.165945053100585938e01, - 4.707080841064453125e01, - 4.719816589355468750e01, - 4.643275451660156250e01, - 4.579597473144531250e01, - 4.656138610839843750e01, - 4.732552337646484375e01, - 7.882450103759765625e01, - 7.118950653076171875e01, - 7.361691284179687500e01, - 8.600611114501953125e01, - 1.018466644287109375e02, - 9.929190063476562500e01, - 1.051681137084960938e02, - 1.092536926269531250e02, - 1.017180328369140625e02, - 1.055514526367187500e02, - 1.087442703247070312e02, - 1.073382568359375000e02, - 1.014620513916015625e02, - 1.061895065307617188e02, - 9.954534149169921875e01, - 9.418109893798828125e01, - 7.562786865234375000e01, - 4.923586273193359375e01, - 4.630540084838867188e01, - 4.643275451660156250e01, - 4.643275451660156250e01, - 4.477585220336914062e01, - 4.541263198852539062e01, - 4.388181304931640625e01, - 6.921293640136718750e01, - 7.013372039794921875e01, - 6.910977935791015625e01, - 8.690015411376953125e01, - 1.012073364257812500e02, - 1.023560867309570312e02, - 1.006953659057617188e02, - 9.967396545410156250e01, - 9.980132293701171875e01, - 1.022287292480468750e02, - 1.095096817016601562e02, - 1.096370315551757812e02, - 1.024847183227539062e02, - 1.031240463256835938e02, - 1.026120758056640625e02, - 9.277508544921875000e01, - 7.511716461181640625e01, - 4.872771453857421875e01, - 4.273560714721679688e01, - 4.311894989013671875e01, - 4.107997894287109375e01, - 3.980641937255859375e01, - 4.006113052368164062e01, - 3.853158569335937500e01, - 6.669383239746093750e01, - 6.282093811035156250e01, - 6.205807495117187500e01, - 8.000127410888671875e01, - 9.022160339355468750e01, - 9.839531707763671875e01, - 9.941798400878906250e01, - 1.023573608398437500e02, - 1.008239974975585938e02, - 1.035061111450195312e02, - 1.044001541137695312e02, - 1.088716278076171875e02, - 1.124477844238281250e02, - 1.075929718017578125e02, - 1.008239974975585938e02, - 9.622644042968750000e01, - 9.073229980468750000e01, - 6.257386779785156250e01, - 4.311894989013671875e01, - 4.299032211303710938e01, - 4.273815536499023438e01, - 4.095262527465820312e01, - 4.082526779174804688e01, - 4.146204757690429688e01, - 7.409832000732421875e01, - 6.358507537841796875e01, - 6.754330444335937500e01, - 8.396331787109375000e01, - 9.405374145507812500e01, - 1.013346939086914062e02, - 1.063155899047851562e02, - 1.123204269409179688e02, - 1.139798812866210938e02, - 1.114251174926757812e02, - 1.038894577026367188e02, - 1.175560379028320312e02, - 1.093823242187500000e02, - 1.097656631469726562e02, - 1.035061111450195312e02, - 9.890728759765625000e01, - 9.341568756103515625e01, - 6.959373474121093750e01, - 4.503183746337890625e01, - 4.515919494628906250e01, - 4.388308715820312500e01, - 4.069791030883789062e01, - 4.082526779174804688e01, - 4.069791030883789062e01, - 6.576541137695312500e01, - 6.329724884033203125e01, - 6.716250610351562500e01, - 8.204534149169921875e01, - 9.405374145507812500e01, - 1.003120193481445312e02, - 1.045287857055664062e02, - 1.082335739135742188e02, - 1.067002029418945312e02, - 1.054215469360351562e02, - 1.070835418701171875e02, - 1.100216522216796875e02, - 1.037608261108398438e02, - 9.865257263183593750e01, - 9.775852966308593750e01, - 9.711793518066406250e01, - 8.137290191650390625e01, - 4.974656295776367188e01, - 4.464849853515625000e01, - ], - [ - 1.140194625854492188e02, - 1.181511993408203125e02, - 1.222866744995117188e02, - 1.147717056274414062e02, - 1.102619781494140625e02, - 1.170284423828125000e02, - 1.421856231689453125e02, - 1.713136291503906250e02, - 1.703555450439453125e02, - 1.906886291503906250e02, - 2.291916198730468750e02, - 2.488061370849609375e02, - 2.491841278076171875e02, - 2.295696105957031250e02, - 2.412574920654296875e02, - 2.488061370849609375e02, - 2.601235046386718750e02, - 2.408869781494140625e02, - 2.397492523193359375e02, - 2.078667602539062500e02, - 1.358046417236328125e02, - 1.249139251708984375e02, - 1.219124221801757812e02, - 1.155202102661132812e02, - 1.162761993408203125e02, - 1.196519470214843750e02, - 1.245359268188476562e02, - 1.189034423828125000e02, - 1.162761993408203125e02, - 1.410591278076171875e02, - 1.515793457031250000e02, - 2.093974609375000000e02, - 2.220247039794921875e02, - 2.397567291259765625e02, - 2.518263397216796875e02, - 2.555950622558593750e02, - 2.435254516601562500e02, - 2.242926635742187500e02, - 1.980613708496093750e02, - 2.208907165527343750e02, - 2.337163238525390625e02, - 2.310778503417968750e02, - 1.957934112548828125e02, - 1.893824920654296875e02, - 1.403068847656250000e02, - 1.264146728515625000e02, - 1.207821884155273438e02, - 1.155239486694335938e02, - 1.166504516601562500e02, - 1.189071884155273438e02, - 1.252881698608398438e02, - 1.192814407348632812e02, - 1.158982009887695312e02, - 1.414371185302734375e02, - 1.583345794677734375e02, - 2.112799377441406250e02, - 2.442814331054687500e02, - 2.499401245117187500e02, - 2.529528503417968750e02, - 2.495583801269531250e02, - 2.552208099365234375e02, - 2.488061370849609375e02, - 2.465456542968750000e02, - 2.503181152343750000e02, - 2.552208099365234375e02, - 2.559693145751953125e02, - 2.484281463623046875e02, - 2.325860748291015625e02, - 1.423989562988281250e02, - 1.241654205322265625e02, - 1.192739486694335938e02, - 1.155202102661132812e02, - 1.158944625854492188e02, - 1.185254516601562500e02, - 1.256624221801757812e02, - 1.189034423828125000e02, - 1.155239486694335938e02, - 1.403106231689453125e02, - 1.538323364257812500e02, - 2.082634735107421875e02, - 2.152357788085937500e02, - 2.340943145751953125e02, - 2.457896728515625000e02, - 2.461676635742187500e02, - 2.337163238525390625e02, - 2.363622741699218750e02, - 2.261751556396484375e02, - 2.431474609375000000e02, - 2.427694549560546875e02, - 2.205127258300781250e02, - 1.759880218505859375e02, - 1.863398132324218750e02, - 1.354266510009765625e02, - 1.200299377441406250e02, - 1.158982009887695312e02, - 1.110142211914062500e02, - 1.098877258300781250e02, - 1.125187149047851562e02, - 1.181549377441406250e02, - 1.113884735107421875e02, - 1.065119781494140625e02, - 1.335516510009765625e02, - 1.425636291503906250e02, - 1.694386291503906250e02, - 1.718562927246093750e02, - 1.916429595947265625e02, - 2.276871185302734375e02, - 2.446556854248046875e02, - 2.371145172119140625e02, - 2.333420715332031250e02, - 2.186302337646484375e02, - 2.431474609375000000e02, - 2.593675231933593750e02, - 2.529565887451171875e02, - 2.537088317871093750e02, - 2.352245483398437500e02, - 1.369311370849609375e02, - 1.219049377441406250e02, - 1.173989486694335938e02, - 1.132709579467773438e02, - 1.125187149047851562e02, - 1.151459579467773438e02, - 1.204079360961914062e02, - 1.140194625854492188e02, - 1.102619781494140625e02, - 1.327994079589843750e02, - 1.433121185302734375e02, - 1.943188629150390625e02, - 1.980651245117187500e02, - 2.193824920654296875e02, - 2.363622741699218750e02, - 2.540868225097656250e02, - 2.544648132324218750e02, - 2.593712463378906250e02, - 2.537125701904296875e02, - 2.457896728515625000e02, - 2.578592834472656250e02, - 2.601235046386718750e02, - 2.533308410644531250e02, - 2.446594238281250000e02, - 1.469086761474609375e02, - 1.222829360961914062e02, - 1.185291900634765625e02, - 1.147717056274414062e02, - 1.136414642333984375e02, - 1.151459579467773438e02, - 1.215344314575195312e02, - 1.136414642333984375e02, - 1.102619781494140625e02, - 1.331736602783203125e02, - 1.436901245117187500e02, - 1.984767913818359375e02, - 1.948652648925781250e02, - 2.276796417236328125e02, - 2.597492370605468750e02, - 2.631399841308593750e02, - 2.593675231933593750e02, - 2.604977416992187500e02, - 2.805651245117187500e02, - 2.882223205566406250e02, - 2.730014953613281250e02, - 2.786601867675781250e02, - 2.714932556152343750e02, - 2.461676635742187500e02, - 1.472866821289062500e02, - 1.256661682128906250e02, - 1.222829360961914062e02, - 1.166504516601562500e02, - ], - [ - 1.792779235839843750e02, - 1.680313415527343750e02, - 1.694005432128906250e02, - 1.666689300537109375e02, - 1.663317413330078125e02, - 1.772343292236328125e02, - 2.585524597167968750e02, - 2.466008148193359375e02, - 2.790497131347656250e02, - 2.776839294433593750e02, - 2.957867736816406250e02, - 3.036410217285156250e02, - 2.995436096191406250e02, - 3.166212463378906250e02, - 3.152520446777343750e02, - 3.149114379882812500e02, - 3.258412780761718750e02, - 3.009093933105468750e02, - 3.087636108398437500e02, - 3.152520446777343750e02, - 3.179870605468750000e02, - 3.080824279785156250e02, - 2.722173156738281250e02, - 2.232629394531250000e02, - 2.017779235839843750e02, - 1.908719329833984375e02, - 1.860967254638671875e02, - 1.802997283935546875e02, - 1.782561340332031250e02, - 1.860933227539062500e02, - 2.691450805664062500e02, - 2.848569335937500000e02, - 3.053474121093750000e02, - 3.067132263183593750e02, - 3.442847290039062500e02, - 3.296015014648437500e02, - 3.268699035644531250e02, - 3.381403198242187500e02, - 3.203746643066406250e02, - 3.319856872558593750e02, - 3.405279235839843750e02, - 3.094482421875000000e02, - 3.073978271484375000e02, - 3.220844726562500000e02, - 3.265258789062500000e02, - 3.121764221191406250e02, - 2.705075073242187500e02, - 2.285047760009765625e02, - 2.080279235839843750e02, - 1.952997283935546875e02, - 1.932561340332031250e02, - 1.891689300537109375e02, - 1.867813415527343750e02, - 1.939373321533203125e02, - 2.598126831054687500e02, - 2.913453674316406250e02, - 3.084230346679687500e02, - 3.128610229492187500e02, - 3.299421081542968750e02, - 3.436069335937500000e02, - 3.336954956054687500e02, - 3.504325561523437500e02, - 3.330143127441406250e02, - 3.220810546875000000e02, - 3.391621398925781250e02, - 3.183310546875000000e02, - 3.050068054199218750e02, - 3.261818847656250000e02, - 3.299421081542968750e02, - 3.125204467773437500e02, - 2.691450805664062500e02, - 2.244005432128906250e02, - 2.059877319335937500e02, - 1.949625396728515625e02, - 1.925749359130859375e02, - 1.888283386230468750e02, - 1.884877319335937500e02, - 2.018937377929687500e02, - 2.729019165039062500e02, - 2.940769653320312500e02, - 3.244788818359375000e02, - 3.173024597167968750e02, - 3.381403198242187500e02, - 3.272104797363281250e02, - 3.466791687011718750e02, - 3.354053039550781250e02, - 3.289169006347656250e02, - 3.401873168945312500e02, - 3.425749206542968750e02, - 3.149114379882812500e02, - 3.039850158691406250e02, - 3.227656555175781250e02, - 3.237908630371093750e02, - 3.128610229492187500e02, - 2.708480834960937500e02, - 2.315735626220703125e02, - 1.884809265136718750e02, - 1.803031311035156250e02, - 1.813249359130859375e02, - 1.809809265136718750e02, - 1.799591217041015625e02, - 2.236273803710937500e02, - 2.558242492675781250e02, - 2.780245361328125000e02, - 3.012465820312500000e02, - 3.026158142089843750e02, - 3.309673156738281250e02, - 3.227690734863281250e02, - 3.282322998046875000e02, - 3.265258789062500000e02, - 3.357459106445312500e02, - 3.347207031250000000e02, - 3.381403198242187500e02, - 3.149148559570312500e02, - 3.162806396484375000e02, - 3.237908630371093750e02, - 3.097922363281250000e02, - 3.036410217285156250e02, - 2.612874755859375000e02, - 2.131403198242187500e02, - 1.905313415527343750e02, - 1.799625396728515625e02, - 1.782561340332031250e02, - 1.799591217041015625e02, - 1.779155273437500000e02, - 2.352384185791015625e02, - 2.548024597167968750e02, - 2.619686584472656250e02, - 2.992030029296875000e02, - 2.961273803710937500e02, - 3.138862304687500000e02, - 3.343800964355468750e02, - 3.330143127441406250e02, - 3.227690734863281250e02, - 3.152520446777343750e02, - 3.234502868652343750e02, - 3.323330993652343750e02, - 3.145708312988281250e02, - 3.002247924804687500e02, - 3.142302551269531250e02, - 3.193528747558593750e02, - 3.036376037597656250e02, - 2.551464538574218750e02, - 2.176907348632812500e02, - 1.946185302734375000e02, - 1.799625396728515625e02, - 1.802997283935546875e02, - 1.772343292236328125e02, - 1.809809265136718750e02, - 2.389918212890625000e02, - 2.565054626464843750e02, - 2.534298400878906250e02, - 2.937363891601562500e02, - 2.971525878906250000e02, - 3.091076354980468750e02, - 3.309638977050781250e02, - 3.336954956054687500e02, - 3.323297119140625000e02, - 3.251566772460937500e02, - 3.504359741210937500e02, - 3.449693603515625000e02, - 3.149114379882812500e02, - 3.104734191894531250e02, - 3.268630676269531250e02, - 3.268664855957031250e02, - 3.094482421875000000e02, - 2.691450805664062500e02, - 2.232697601318359375e02, - ], - [ - 3.947381896972656250e02, - 3.778378295898437500e02, - 3.702280273437500000e02, - 3.660050659179687500e02, - 5.198479614257812500e02, - 5.274493408203125000e02, - 5.730996704101562500e02, - 5.959121704101562500e02, - 6.664611206054687500e02, - 7.358361206054687500e02, - 7.853547363281250000e02, - 8.268750000000000000e02, - 8.395861206054687500e02, - 8.141723022460937500e02, - 7.938428955078125000e02, - 8.760134887695312500e02, - 8.827871704101562500e02, - 8.395861206054687500e02, - 7.845185546875000000e02, - 7.489273681640625000e02, - 7.396115112304687500e02, - 5.734966430664062500e02, - 4.420692443847656250e02, - 4.124830932617187500e02, - 4.260135192871093750e02, - 4.116385192871093750e02, - 4.048817443847656250e02, - 4.031925659179687500e02, - 5.739357910156250000e02, - 5.967567749023437500e02, - 6.246536865234375000e02, - 6.195861206054687500e02, - 6.803969726562500000e02, - 7.751942749023437500e02, - 8.345017089843750000e02, - 7.743496704101562500e02, - 8.057009887695312500e02, - 8.124746704101562500e02, - 8.209459228515625000e02, - 8.539949340820312500e02, - 8.624661865234375000e02, - 8.192482910156250000e02, - 7.955321044921875000e02, - 7.692736206054687500e02, - 7.235134887695312500e02, - 5.726942749023437500e02, - 4.370101318359375000e02, - 4.133446044921875000e02, - 4.243243103027343750e02, - 4.133361511230468750e02, - 4.175675659179687500e02, - 4.226351318359375000e02, - 6.288598022460937500e02, - 6.711232910156250000e02, - 6.356503295898437500e02, - 6.880321044921875000e02, - 7.989189453125000000e02, - 8.972044067382812500e02, - 9.039780273437500000e02, - 8.709290771484375000e02, - 8.870354614257812500e02, - 8.734797363281250000e02, - 8.522973022460937500e02, - 8.921199340820312500e02, - 8.904138793945312500e02, - 8.607685546875000000e02, - 8.175591430664062500e02, - 7.299240112304687500e02, - 7.065709228515625000e02, - 5.743750000000000000e02, - 4.446114807128906250e02, - 3.998057556152343750e02, - 4.184121704101562500e02, - 3.998057556152343750e02, - 4.048733215332031250e02, - 4.065709533691406250e02, - 5.190033569335937500e02, - 5.959121704101562500e02, - 6.178969726562500000e02, - 6.588767089843750000e02, - 7.070354614257812500e02, - 7.896115112304687500e02, - 7.921536865234375000e02, - 7.870607910156250000e02, - 8.260303955078125000e02, - 7.921452636718750000e02, - 8.031588134765625000e02, - 8.683868408203125000e02, - 8.311148681640625000e02, - 7.896030273437500000e02, - 7.565625000000000000e02, - 7.112838134765625000e02, - 6.664696044921875000e02, - 5.337584228515625000e02, - 3.998057556152343750e02, - 3.803716125488281250e02, - 3.854391784667968750e02, - 3.676858215332031250e02, - 3.634628295898437500e02, - 3.685473022460937500e02, - 4.919510192871093750e02, - 5.536571044921875000e02, - 5.578800659179687500e02, - 5.790033569335937500e02, - 6.474830932617187500e02, - 7.463935546875000000e02, - 7.879053955078125000e02, - 8.090878295898437500e02, - 8.073901977539062500e02, - 7.692651977539062500e02, - 8.277280273437500000e02, - 8.700844726562500000e02, - 8.794088134765625000e02, - 8.345101318359375000e02, - 8.667060546875000000e02, - 7.523226318359375000e02, - 7.353800659179687500e02, - 6.554982910156250000e02, - 5.232178955078125000e02, - 3.938935852050781250e02, - 4.116469726562500000e02, - 3.896790466308593750e02, - 3.820608215332031250e02, - 3.854391784667968750e02, - 5.274408569335937500e02, - 6.001351318359375000e02, - 6.212753295898437500e02, - 6.322634887695312500e02, - 7.045017089843750000e02, - 8.387500000000000000e02, - 8.997381591796875000e02, - 8.912669067382812500e02, - 8.988851318359375000e02, - 8.548310546875000000e02, - 8.912669067382812500e02, - 9.107601318359375000e02, - 9.132939453125000000e02, - 8.497550659179687500e02, - 9.056672363281250000e02, - 8.285726318359375000e02, - 7.514780273437500000e02, - 6.580405273437500000e02, - 5.257601318359375000e02, - 3.837500000000000000e02, - 3.981250000000000000e02, - 3.964358215332031250e02, - 3.719172363281250000e02, - 3.829053955078125000e02, - 5.206756591796875000e02, - 5.705574340820312500e02, - 5.553463134765625000e02, - 6.233530273437500000e02, - 6.994172363281250000e02, - 8.336571044921875000e02, - 9.226182250976562500e02, - 9.056672363281250000e02, - 9.276942749023437500e02, - 8.683952636718750000e02, - 8.929560546875000000e02, - 1.004788879394531250e03, - 9.675169067382812500e02, - 9.285473022460937500e02, - 8.827955932617187500e02, - 8.014611206054687500e02, - 7.599493408203125000e02, - 5.938006591796875000e02, - 4.463006896972656250e02, - 3.896790466308593750e02, - ], - [ - 4.006647109985351562e01, - 3.545051574707031250e01, - 2.289512634277343750e01, - 1.772525787353515625e01, - 1.070901012420654297e01, - 9.231905937194824219e00, - 7.016248226165771484e00, - 1.070901012420654297e01, - 1.920236396789550781e01, - 3.766617584228515625e01, - 4.523633575439453125e01, - 5.409896469116210938e01, - 6.776218414306640625e01, - 6.720827484130859375e01, - 5.816100311279296875e01, - 5.243722152709960938e01, - 5.649925994873046875e01, - 5.243722152709960938e01, - 6.037666320800781250e01, - 6.573117065429687500e01, - 6.831610107421875000e01, - 6.185376739501953125e01, - 5.391432952880859375e01, - 4.911373519897460938e01, - 4.689807891845703125e01, - 2.861890602111816406e01, - 2.437223052978515625e01, - 1.901772499084472656e01, - 1.052437210083007812e01, - 8.124076843261718750e00, - 7.385524272918701172e00, - 1.107828617095947266e01, - 2.621861076354980469e01, - 4.671343994140625000e01, - 5.539143371582031250e01, - 5.889955520629882812e01, - 7.200886535644531250e01, - 6.776218414306640625e01, - 6.093057632446289062e01, - 5.742245101928710938e01, - 5.391432952880859375e01, - 5.612998580932617188e01, - 5.723781204223632812e01, - 6.794682312011718750e01, - 6.443869781494140625e01, - 6.277695846557617188e01, - 6.517725372314453125e01, - 5.225258636474609375e01, - 5.040620422363281250e01, - 3.489660263061523438e01, - 2.381831550598144531e01, - 1.790989685058593750e01, - 1.181683921813964844e01, - 8.493352890014648438e00, - 8.493352890014648438e00, - 1.827917289733886719e01, - 3.175775527954101562e01, - 5.003692626953125000e01, - 5.649925994873046875e01, - 6.739290618896484375e01, - 7.090103149414062500e01, - 6.517725372314453125e01, - 5.889955520629882812e01, - 5.760708999633789062e01, - 5.760708999633789062e01, - 5.594534683227539062e01, - 5.889955520629882812e01, - 7.348596954345703125e01, - 7.256277465820312500e01, - 6.517725372314453125e01, - 5.982274627685546875e01, - 5.631462478637695312e01, - 4.689807891845703125e01, - 3.101920318603515625e01, - 2.511078262329101562e01, - 2.344903945922851562e01, - 1.144756317138671875e01, - 7.570162296295166016e00, - 7.200886249542236328e00, - 1.347858238220214844e01, - 2.566469764709472656e01, - 4.209748840332031250e01, - 4.745199584960937500e01, - 6.425405883789062500e01, - 7.920974731445312500e01, - 6.905464935302734375e01, - 5.539143371582031250e01, - 5.280649948120117188e01, - 5.059084320068359375e01, - 5.077547836303710938e01, - 5.631462478637695312e01, - 6.462333679199218750e01, - 7.163958740234375000e01, - 5.631462478637695312e01, - 5.668389892578125000e01, - 4.412850952148437500e01, - 4.080502319335937500e01, - 3.157311630249023438e01, - 2.234121131896972656e01, - 2.012555313110351562e01, - 1.532496261596679688e01, - 9.047266960144042969e00, - 9.970458030700683594e00, - 1.366322040557861328e01, - 2.677252578735351562e01, - 4.375923156738281250e01, - 5.483751678466796875e01, - 6.351551055908203125e01, - 7.994830322265625000e01, - 7.330133056640625000e01, - 6.240768051147460938e01, - 6.351551055908203125e01, - 5.834564208984375000e01, - 5.631462478637695312e01, - 7.200886535644531250e01, - 7.477843475341796875e01, - 7.071639251708984375e01, - 6.517725372314453125e01, - 5.409896469116210938e01, - 4.966765213012695312e01, - 4.560561370849609375e01, - 3.526588058471679688e01, - 2.307976341247558594e01, - 2.455686759948730469e01, - 1.366322040557861328e01, - 8.862628936767578125e00, - 7.016248226165771484e00, - 1.403249645233154297e01, - 2.437223052978515625e01, - 4.394387054443359375e01, - 5.132939529418945312e01, - 6.277695846557617188e01, - 7.293205261230468750e01, - 6.591580200195312500e01, - 6.351551055908203125e01, - 6.296159362792968750e01, - 6.333087158203125000e01, - 6.628507995605468750e01, - 7.182422637939453125e01, - 7.754800415039062500e01, - 7.644017791748046875e01, - 6.259231948852539062e01, - 5.280649948120117188e01, - 5.409896469116210938e01, - 4.911373519897460938e01, - 3.120384025573730469e01, - 2.474150657653808594e01, - 2.049483108520507812e01, - 1.310930538177490234e01, - 9.601181983947753906e00, - 8.493352890014648438e00, - 1.643279266357421875e01, - 5.096011734008789062e01, - 4.837518310546875000e01, - 4.966765213012695312e01, - 6.333087158203125000e01, - 7.440915679931640625e01, - 7.274741363525390625e01, - 6.683899688720703125e01, - 6.517725372314453125e01, - 6.831610107421875000e01, - 6.702363586425781250e01, - 7.607089996337890625e01, - 7.662481689453125000e01, - 7.828656005859375000e01, - 6.499261474609375000e01, - 6.093057632446289062e01, - 5.040620422363281250e01, - ], + return nd.array( [ - 7.309510040283203125e01, - 7.324276733398437500e01, - 7.250443267822265625e01, - 7.250443267822265625e01, - 7.132309722900390625e01, - 7.265209960937500000e01, - 7.279976654052734375e01, - 7.545776367187500000e01, - 7.649143218994140625e01, - 7.870643615722656250e01, - 8.165977478027343750e01, - 8.254577636718750000e01, - 8.210277557373046875e01, - 8.313644409179687500e01, - 8.313644409179687500e01, - 8.298877716064453125e01, - 8.254577636718750000e01, - 8.225044250488281250e01, - 8.062610626220703125e01, - 8.106910705566406250e01, - 7.959243774414062500e01, - 7.974010467529296875e01, - 7.841110229492187500e01, - 8.033077239990234375e01, - 8.062610626220703125e01, - 8.254577636718750000e01, - 8.136444091796875000e01, - 8.092144012451171875e01, - 8.092144012451171875e01, - 8.033077239990234375e01, - 8.151210784912109375e01, - 8.136444091796875000e01, - 8.121677398681640625e01, - 8.225044250488281250e01, - 8.343177795410156250e01, - 8.313644409179687500e01, - 8.254577636718750000e01, - 8.210277557373046875e01, - 8.343177795410156250e01, - 8.254577636718750000e01, - 8.225044250488281250e01, - 8.151210784912109375e01, - 8.121677398681640625e01, - 8.062610626220703125e01, - 7.974010467529296875e01, - 7.885410308837890625e01, - 7.900177001953125000e01, - 8.018310546875000000e01, - 8.225044250488281250e01, - 8.225044250488281250e01, - 8.269344329833984375e01, - 8.239810943603515625e01, - 8.033077239990234375e01, - 8.136444091796875000e01, - 8.225044250488281250e01, - 8.284111022949218750e01, - 8.269344329833984375e01, - 8.239810943603515625e01, - 8.225044250488281250e01, - 8.328411102294921875e01, - 8.357944488525390625e01, - 8.343177795410156250e01, - 8.594211578369140625e01, - 8.417011260986328125e01, - 8.357944488525390625e01, - 8.269344329833984375e01, - 8.018310546875000000e01, - 8.106910705566406250e01, - 7.841110229492187500e01, - 7.988777160644531250e01, - 7.914943695068359375e01, - 7.988777160644531250e01, - 7.959243774414062500e01, - 8.136444091796875000e01, - 8.195510864257812500e01, - 8.239810943603515625e01, - 8.165977478027343750e01, - 8.180744171142578125e01, - 8.047843933105468750e01, - 8.121677398681640625e01, - 8.151210784912109375e01, - 8.298877716064453125e01, - 8.180744171142578125e01, - 8.313644409179687500e01, - 8.180744171142578125e01, - 8.328411102294921875e01, - 8.328411102294921875e01, - 8.313644409179687500e01, - 8.165977478027343750e01, - 8.151210784912109375e01, - 7.929710388183593750e01, - 7.900177001953125000e01, - 7.826343536376953125e01, - 7.796810150146484375e01, - 7.826343536376953125e01, - 7.708209991455078125e01, - 7.309510040283203125e01, - 7.368576812744140625e01, - 7.324276733398437500e01, - 7.250443267822265625e01, - 7.294743347167968750e01, - 7.206143188476562500e01, - 7.265209960937500000e01, - 7.442410278320312500e01, - 7.575309753417968750e01, - 7.870643615722656250e01, - 8.062610626220703125e01, - 8.106910705566406250e01, - 8.180744171142578125e01, - 8.313644409179687500e01, - 8.254577636718750000e01, - 8.180744171142578125e01, - 8.328411102294921875e01, - 8.254577636718750000e01, - 7.959243774414062500e01, - 7.782043457031250000e01, - 7.841110229492187500e01, - 7.767276763916015625e01, - 7.767276763916015625e01, - 7.722976684570312500e01, - 7.516242980957031250e01, - 7.531009674072265625e01, - 7.619609832763671875e01, - 7.457176971435546875e01, - 7.442410278320312500e01, - 7.398110198974609375e01, - 7.324276733398437500e01, - 7.368576812744140625e01, - 7.398110198974609375e01, - 7.708209991455078125e01, - 7.841110229492187500e01, - 7.885410308837890625e01, - 7.988777160644531250e01, - 7.929710388183593750e01, - 7.900177001953125000e01, - 8.033077239990234375e01, - 7.900177001953125000e01, - 8.062610626220703125e01, - 7.885410308837890625e01, - 7.841110229492187500e01, - 7.988777160644531250e01, - 8.003543853759765625e01, - 7.914943695068359375e01, - 8.018310546875000000e01, - 7.914943695068359375e01, - 7.870643615722656250e01, - 7.560543060302734375e01, - 7.457176971435546875e01, - 7.442410278320312500e01, - 7.442410278320312500e01, - 7.368576812744140625e01, - 7.471943664550781250e01, - 7.412876892089843750e01, - 7.634376525878906250e01, - 8.018310546875000000e01, - 7.855876922607421875e01, - 7.959243774414062500e01, - 7.959243774414062500e01, - 7.855876922607421875e01, - 8.047843933105468750e01, - 7.944477081298828125e01, - 7.900177001953125000e01, - 7.796810150146484375e01, - 7.811576843261718750e01, - 7.974010467529296875e01, - 7.914943695068359375e01, - 7.988777160644531250e01, - 7.974010467529296875e01, - ], - ]) + [ + 3.392694091796875000e02, + 3.198630065917968750e02, + 3.210045776367187500e02, + 3.175798950195312500e02, + 3.130137023925781250e02, + 4.615182800292968750e02, + 5.757534179687500000e02, + 5.871689453125000000e02, + 6.362899780273437500e02, + 6.785616455078125000e02, + 7.677511596679687500e02, + 9.501369628906250000e02, + 1.041723754882812500e03, + 1.016529663085937500e03, + 9.112100219726562500e02, + 1.019965759277343750e03, + 1.019977172851562500e03, + 1.055456665039062500e03, + 1.033710083007812500e03, + 7.803538818359375000e02, + 7.322830810546875000e02, + 4.934931640625000000e02, + 3.872488708496093750e02, + 3.655593566894531250e02, + 3.747031860351562500e02, + 3.518378906250000000e02, + 3.621347045898437500e02, + 3.427054748535156250e02, + 3.461187133789062500e02, + 4.992009277343750000e02, + 6.488584594726562500e02, + 6.145890502929687500e02, + 6.888356323242187500e02, + 7.624086914062500000e02, + 7.517237548828125000e02, + 9.593036499023437500e02, + 9.810502319335937500e02, + 1.038287719726562500e03, + 9.249543457031250000e02, + 1.033710083007812500e03, + 1.079497680664062500e03, + 1.050879028320312500e03, + 9.673059082031250000e02, + 7.490639038085937500e02, + 7.319063720703125000e02, + 4.854908752441406250e02, + 3.826940612792968750e02, + 3.644178161621093750e02, + 3.667009277343750000e02, + 3.506963500976562500e02, + 3.438470458984375000e02, + 3.267123413085937500e02, + 3.347031860351562500e02, + 4.923515930175781250e02, + 6.385730590820312500e02, + 6.180136718750000000e02, + 6.766895141601562500e02, + 6.880822143554687500e02, + 7.643150634765625000e02, + 1.010810485839843750e03, + 1.053173461914062500e03, + 1.042876708984375000e03, + 9.570091552734375000e02, + 1.054303710937500000e03, + 1.118424682617187500e03, + 1.089805908203125000e03, + 9.661757812500000000e02, + 7.852968139648437500e02, + 7.608903808593750000e02, + 4.912100524902343750e02, + 3.803995361328125000e02, + 3.586986389160156250e02, + 3.667009277343750000e02, + 3.461187133789062500e02, + 3.609817199707031250e02, + 3.449771728515625000e02, + 3.404109497070312500e02, + 4.843721313476562500e02, + 6.560958862304687500e02, + 6.214497680664062500e02, + 6.937899780273437500e02, + 7.357191772460937500e02, + 7.421917724609375000e02, + 9.673173217773437500e02, + 9.375456542968750000e02, + 8.986301269531250000e02, + 9.787671508789062500e02, + 1.026849365234375000e03, + 9.684703369140625000e02, + 9.066552734375000000e02, + 8.253767089843750000e02, + 7.277055053710937500e02, + 7.002625732421875000e02, + 4.809360656738281250e02, + 3.712785339355468750e02, + 3.518378906250000000e02, + 3.552625427246093750e02, + 3.426940612792968750e02, + 3.438470458984375000e02, + 3.324201049804687500e02, + 3.278538818359375000e02, + 4.923401794433593750e02, + 6.351483764648437500e02, + 6.054566040039062500e02, + 6.796917724609375000e02, + 6.983789672851562500e02, + 7.273287963867187500e02, + 9.192237548828125000e02, + 9.512785644531250000e02, + 9.547260131835937500e02, + 9.169406127929687500e02, + 1.023401855468750000e03, + 1.011940612792968750e03, + 1.018824218750000000e03, + 9.146575317382812500e02, + 8.070548095703125000e02, + 7.444863281250000000e02, + 5.072031860351562500e02, + 3.998287658691406250e02, + 3.575570678710937500e02, + 3.598401794433593750e02, + 3.438356018066406250e02, + 3.472716979980468750e02, + 3.381278686523437500e02, + 3.426940612792968750e02, + 6.206963500976562500e02, + 5.837442626953125000e02, + 6.111643676757812500e02, + 7.471689453125000000e02, + 7.593721313476562500e02, + 8.070548095703125000e02, + 9.604451904296875000e02, + 1.025707763671875000e03, + 9.741781005859375000e02, + 9.364155273437500000e02, + 1.019965759277343750e03, + 1.085216918945312500e03, + 9.970662231445312500e02, + 8.929109497070312500e02, + 7.597374267578125000e02, + 7.505822143554687500e02, + 5.083447570800781250e02, + 3.986872253417968750e02, + 3.689954223632812500e02, + 3.769862976074218750e02, + 3.506849365234375000e02, + 3.552739868164062500e02, + 3.404223632812500000e02, + 3.312785339355468750e02, + 5.711757812500000000e02, + 5.403310546875000000e02, + 6.317237548828125000e02, + 7.059703369140625000e02, + 7.517579956054687500e02, + 8.024657592773437500e02, + 1.046301391601562500e03, + 1.109269409179687500e03, + 1.121860717773437500e03, + 1.015388122558593750e03, + 1.094383544921875000e03, + 1.151620971679687500e03, + 1.062317382812500000e03, + 1.077191772460937500e03, + 8.036187133789062500e02, + 7.216096191406250000e02, + 5.003424682617187500e02, + 3.964041137695312500e02, + 3.735616455078125000e02, + ], + [ + 2.263374519348144531e01, + 2.932098770141601562e01, + 2.932098770141601562e01, + 2.949245452880859375e01, + 2.880658531188964844e01, + 2.657750320434570312e01, + 2.263374519348144531e01, + 1.886145401000976562e01, + 2.434842300415039062e01, + 3.137860107421875000e01, + 2.400548744201660156e01, + 2.897805213928222656e01, + 3.377914810180664062e01, + 3.360768127441406250e01, + 3.240740585327148438e01, + 3.275034332275390625e01, + 2.880658531188964844e01, + 2.743484306335449219e01, + 3.755144119262695312e01, + 3.635116577148437500e01, + 3.703703689575195312e01, + 4.098079681396484375e01, + 3.275034332275390625e01, + 2.897805213928222656e01, + 3.086419677734375000e01, + 2.726337432861328125e01, + 1.920438957214355469e01, + 2.777777862548828125e01, + 2.846364974975585938e01, + 2.331961631774902344e01, + 1.406035709381103516e01, + 2.640603637695312500e01, + 2.537722969055175781e01, + 2.572016525268554688e01, + 3.240740585327148438e01, + 2.743484306335449219e01, + 2.897805213928222656e01, + 2.726337432861328125e01, + 2.006172752380371094e01, + 3.069272994995117188e01, + 3.000685882568359375e01, + 2.503429412841796875e01, + 3.000685882568359375e01, + 3.858024597167968750e01, + 3.789437484741210938e01, + 3.240740585327148438e01, + 3.480795669555664062e01, + 2.589163208007812500e01, + 2.263374519348144531e01, + 1.954732513427734375e01, + 1.920438957214355469e01, + 1.817558288574218750e01, + 2.006172752380371094e01, + 1.543209838867187500e01, + 1.406035709381103516e01, + 2.280521202087402344e01, + 3.275034332275390625e01, + 3.412208557128906250e01, + 3.360768127441406250e01, + 3.446501922607421875e01, + 3.515089035034179688e01, + 3.395061874389648438e01, + 2.846364974975585938e01, + 2.949245452880859375e01, + 3.017832565307617188e01, + 2.777777862548828125e01, + 3.446501922607421875e01, + 4.526749038696289062e01, + 3.943758392333984375e01, + 4.818244171142578125e01, + 3.737997436523437500e01, + 3.480795669555664062e01, + 2.400548744201660156e01, + 2.623456764221191406e01, + 2.760630989074707031e01, + 2.846364974975585938e01, + 2.932098770141601562e01, + 1.851851844787597656e01, + 2.383401870727539062e01, + 2.914951896667480469e01, + 2.897805213928222656e01, + 2.709190750122070312e01, + 3.343621444702148438e01, + 3.292181015014648438e01, + 3.772290802001953125e01, + 3.858024597167968750e01, + 3.326474761962890625e01, + 2.589163208007812500e01, + 3.172153663635253906e01, + 3.377914810180664062e01, + 4.406721496582031250e01, + 4.749657058715820312e01, + 4.406721496582031250e01, + 3.446501922607421875e01, + 3.960905456542968750e01, + 3.463648986816406250e01, + 3.189300346374511719e01, + 3.034979438781738281e01, + 3.069272994995117188e01, + 2.880658531188964844e01, + 2.897805213928222656e01, + 2.143346977233886719e01, + 2.331961631774902344e01, + 2.520576095581054688e01, + 2.109053421020507812e01, + 2.777777862548828125e01, + 3.000685882568359375e01, + 3.017832565307617188e01, + 3.343621444702148438e01, + 3.223593902587890625e01, + 2.503429412841796875e01, + 3.446501922607421875e01, + 3.343621444702148438e01, + 2.812071418762207031e01, + 2.897805213928222656e01, + 3.995198822021484375e01, + 3.377914810180664062e01, + 3.515089035034179688e01, + 3.703703689575195312e01, + 2.897805213928222656e01, + 2.366255187988281250e01, + 2.812071418762207031e01, + 2.760630989074707031e01, + 2.143346977233886719e01, + 2.743484306335449219e01, + 2.589163208007812500e01, + 2.177640533447265625e01, + 2.057613182067871094e01, + 3.292181015014648438e01, + 2.486282539367675781e01, + 2.589163208007812500e01, + 4.200960159301757812e01, + 3.086419677734375000e01, + 3.412208557128906250e01, + 3.840877914428710938e01, + 3.155006790161132812e01, + 3.600823211669921875e01, + 3.669410324096679688e01, + 3.446501922607421875e01, + 3.720850372314453125e01, + 4.406721496582031250e01, + 3.326474761962890625e01, + 3.532236099243164062e01, + 3.858024597167968750e01, + 3.275034332275390625e01, + 2.194787406921386719e01, + 2.897805213928222656e01, + 2.863511657714843750e01, + 2.469135856628417969e01, + 2.331961631774902344e01, + 2.554869651794433594e01, + 2.143346977233886719e01, + 2.486282539367675781e01, + 2.743484306335449219e01, + 2.383401870727539062e01, + 3.069272994995117188e01, + 2.794924545288085938e01, + 2.812071418762207031e01, + 2.434842300415039062e01, + 1.989026069641113281e01, + 1.971879196166992188e01, + 2.194787406921386719e01, + 2.486282539367675781e01, + 3.360768127441406250e01, + 3.343621444702148438e01, + 3.206447219848632812e01, + 3.034979438781738281e01, + 2.349108314514160156e01, + ], + [ + 1.214824981689453125e02, + 1.192518844604492188e02, + 1.103294448852539062e02, + 1.139327392578125000e02, + 1.137611541748046875e02, + 1.456760406494140625e02, + 1.597460479736328125e02, + 1.443033599853515625e02, + 1.691832580566406250e02, + 1.798215484619140625e02, + 1.918325347900390625e02, + 1.990391235351562500e02, + 1.932052154541015625e02, + 2.007549743652343750e02, + 2.069320526123046875e02, + 2.084763183593750000e02, + 2.134523010253906250e02, + 2.148249816894531250e02, + 2.132807159423828125e02, + 2.143102264404296875e02, + 2.086479034423828125e02, + 1.935483856201171875e02, + 1.743308105468750000e02, + 1.717570343017578125e02, + 1.662662963867187500e02, + 1.657515411376953125e02, + 1.580301971435546875e02, + 1.626630096435546875e02, + 1.640356903076171875e02, + 1.841111907958984375e02, + 2.000686340332031250e02, + 1.853122863769531250e02, + 2.057309570312500000e02, + 2.074468078613281250e02, + 2.151681518554687500e02, + 2.095058288574218750e02, + 2.088194885253906250e02, + 2.156829071044921875e02, + 2.189430389404296875e02, + 2.215168151855468750e02, + 2.287234039306640625e02, + 2.292381591796875000e02, + 2.230610809326171875e02, + 2.177419281005859375e02, + 2.115648651123046875e02, + 1.921757049560546875e02, + 1.741592254638671875e02, + 1.535689697265625000e02, + 1.590597076416015625e02, + 1.606039733886718750e02, + 1.568291015625000000e02, + 1.571722717285156250e02, + 1.575154418945312500e02, + 1.911461944580078125e02, + 1.866849670410156250e02, + 1.885724029541015625e02, + 2.033287506103515625e02, + 2.115648651123046875e02, + 2.316403503417968750e02, + 2.280370635986328125e02, + 2.302676696777343750e02, + 2.395332946777343750e02, + 2.343857269287109375e02, + 2.278654785156250000e02, + 2.285518188476562500e02, + 2.268359680175781250e02, + 2.184282836914062500e02, + 2.112216949462890625e02, + 2.155113220214843750e02, + 2.148249816894531250e02, + 1.988675384521484375e02, + 1.842827758789062500e02, + 1.726149597167968750e02, + 1.806794738769531250e02, + 1.762182617187500000e02, + 1.774193572998046875e02, + 1.835964355468750000e02, + 1.964653472900390625e02, + 2.240905914306640625e02, + 1.781056976318359375e02, + 1.878860626220703125e02, + 1.969801025390625000e02, + 2.091626586914062500e02, + 2.098489990234375000e02, + 2.163692474365234375e02, + 2.210020599365234375e02, + 2.251201171875000000e02, + 2.242621765136718750e02, + 2.196293792724609375e02, + 2.276938934326171875e02, + 2.225463256835937500e02, + 2.210020599365234375e02, + 2.047014465332031250e02, + 1.944063110351562500e02, + 1.722717895507812500e02, + 1.523678741455078125e02, + 1.559711761474609375e02, + 1.551132507324218750e02, + 1.513383636474609375e02, + 1.544269104003906250e02, + 1.530542144775390625e02, + 1.871997222900390625e02, + 1.799931335449218750e02, + 1.717570343017578125e02, + 1.935483856201171875e02, + 1.969801025390625000e02, + 2.067604675292968750e02, + 2.228894958496093750e02, + 2.105353393554687500e02, + 2.149965667724609375e02, + 2.024708251953125000e02, + 2.409059753417968750e02, + 2.491420745849609375e02, + 2.283802337646484375e02, + 2.213452301025390625e02, + 2.198009643554687500e02, + 2.007549743652343750e02, + 1.859986267089843750e02, + 1.739876403808593750e02, + 1.583733673095703125e02, + 1.376115264892578125e02, + 1.377831115722656250e02, + 1.362388458251953125e02, + 1.353809204101562500e02, + 1.431022644042968750e02, + 1.695264282226562500e02, + 1.822237548828125000e02, + 1.703843536376953125e02, + 1.896019287109375000e02, + 1.983527832031250000e02, + 2.048730316162109375e02, + 2.052162017822265625e02, + 2.062457122802734375e02, + 2.081331481933593750e02, + 2.047014465332031250e02, + 2.149965667724609375e02, + 2.047014465332031250e02, + 2.055593719482421875e02, + 2.227179107666015625e02, + 2.113932800292968750e02, + 1.998970489501953125e02, + 1.817089843750000000e02, + 1.657515411376953125e02, + 1.499656829833984375e02, + 1.364104309082031250e02, + 1.312628631591796875e02, + 1.283459167480468750e02, + 1.250857925415039062e02, + 1.264584732055664062e02, + 1.654083709716796875e02, + 1.798215484619140625e02, + 1.643788604736328125e02, + 1.904598541259765625e02, + 1.921757049560546875e02, + 2.069320526123046875e02, + 2.215168151855468750e02, + 2.185998687744140625e02, + 2.246053466796875000e02, + 2.259780426025390625e02, + 2.304392547607421875e02, + 2.283802337646484375e02, + 2.264927978515625000e02, + 2.222031555175781250e02, + 2.115648651123046875e02, + 2.010981445312500000e02, + 1.944063110351562500e02, + 1.676389770507812500e02, + 1.460192108154296875e02, + ], + [ + 6.946183013916015625e01, + 6.226533126831054688e01, + 6.007509231567382812e01, + 5.913642120361328125e01, + 5.788485717773437500e01, + 6.289111328125000000e01, + 4.787234115600585938e01, + 5.287859725952148438e01, + 5.287859725952148438e01, + 6.038798522949218750e01, + 6.570713043212890625e01, + 7.196495819091796875e01, + 7.352941131591796875e01, + 6.007509231567382812e01, + 5.851063919067382812e01, + 6.695870208740234375e01, + 6.289111328125000000e01, + 7.259073638916015625e01, + 9.355444335937500000e01, + 8.573216247558593750e01, + 9.793492126464843750e01, + 1.013767242431640625e02, + 1.004380493164062500e02, + 7.884856414794921875e01, + 6.727159118652343750e01, + 6.351689529418945312e01, + 6.351689529418945312e01, + 6.226533126831054688e01, + 6.289111328125000000e01, + 6.821026611328125000e01, + 4.630788421630859375e01, + 5.162703323364257812e01, + 5.037546920776367188e01, + 6.070087432861328125e01, + 6.821026611328125000e01, + 7.446808624267578125e01, + 7.415519714355468750e01, + 6.758448028564453125e01, + 6.539424133300781250e01, + 6.602002716064453125e01, + 6.226533126831054688e01, + 7.133917236328125000e01, + 8.854818725585937500e01, + 8.823529052734375000e01, + 1.041927413940429688e02, + 1.029411773681640625e02, + 9.699624633789062500e01, + 7.634542846679687500e01, + 6.758448028564453125e01, + 6.445556640625000000e01, + 6.351689529418945312e01, + 6.101376724243164062e01, + 5.976220321655273438e01, + 6.789736938476562500e01, + 5.225281524658203125e01, + 4.724655914306640625e01, + 5.131414413452148438e01, + 6.508135223388671875e01, + 7.478097534179687500e01, + 7.384230041503906250e01, + 6.163954925537109375e01, + 6.476846313476562500e01, + 6.414267730712890625e01, + 5.694618225097656250e01, + 6.508135223388671875e01, + 6.914893341064453125e01, + 7.947434234619140625e01, + 8.604505920410156250e01, + 1.001251602172851562e02, + 9.793492126464843750e01, + 9.167709350585937500e01, + 7.790988922119140625e01, + 6.789736938476562500e01, + 6.633291625976562500e01, + 6.351689529418945312e01, + 6.226533126831054688e01, + 6.195244216918945312e01, + 6.476846313476562500e01, + 5.225281524658203125e01, + 5.068836212158203125e01, + 5.413016128540039062e01, + 5.757196426391601562e01, + 7.352941131591796875e01, + 7.321652221679687500e01, + 7.196495819091796875e01, + 6.163954925537109375e01, + 5.882352828979492188e01, + 6.414267730712890625e01, + 6.476846313476562500e01, + 7.790988922119140625e01, + 9.824781036376953125e01, + 8.604505920410156250e01, + 1.026282882690429688e02, + 1.091989974975585938e02, + 1.073216552734375000e02, + 8.541927337646484375e01, + 7.415519714355468750e01, + 6.852315521240234375e01, + 6.289111328125000000e01, + 6.070087432861328125e01, + 6.101376724243164062e01, + 6.414267730712890625e01, + 5.256570816040039062e01, + 4.943679428100585938e01, + 6.070087432861328125e01, + 7.571965026855468750e01, + 8.792240142822265625e01, + 8.698372650146484375e01, + 7.822277832031250000e01, + 7.133917236328125000e01, + 7.227784729003906250e01, + 6.602002716064453125e01, + 6.883604431152343750e01, + 6.977471923828125000e01, + 7.978723144531250000e01, + 9.042552947998046875e01, + 1.057571945190429688e02, + 1.007509384155273438e02, + 9.762202453613281250e01, + 8.479349517822265625e01, + 7.165206146240234375e01, + 6.821026611328125000e01, + 6.226533126831054688e01, + 6.195244216918945312e01, + 7.478097534179687500e01, + 5.882352828979492188e01, + 4.974968719482421875e01, + 5.381727218627929688e01, + 6.007509231567382812e01, + 7.259073638916015625e01, + 7.603253936767578125e01, + 8.573216247558593750e01, + 8.698372650146484375e01, + 8.041301727294921875e01, + 7.853566741943359375e01, + 8.197747039794921875e01, + 7.603253936767578125e01, + 8.948686218261718750e01, + 8.948686218261718750e01, + 9.230287933349609375e01, + 9.230287933349609375e01, + 9.511889648437500000e01, + 9.543179321289062500e01, + 8.823529052734375000e01, + 8.166458129882812500e01, + 7.133917236328125000e01, + 6.195244216918945312e01, + 6.038798522949218750e01, + 5.976220321655273438e01, + 6.883604431152343750e01, + 4.849812316894531250e01, + 5.319149017333984375e01, + 6.382978820800781250e01, + 7.634542846679687500e01, + 9.261576843261718750e01, + 1.001251602172851562e02, + 8.573216247558593750e01, + 6.539424133300781250e01, + 5.600751113891601562e01, + 5.694618225097656250e01, + 5.538172531127929688e01, + 6.382978820800781250e01, + 7.790988922119140625e01, + 8.197747039794921875e01, + 1.029411773681640625e02, + 1.135794754028320312e02, + 1.004380493164062500e02, + 8.479349517822265625e01, + ], + [ + 4.069791030883789062e01, + 4.158940505981445312e01, + 3.891365432739257812e01, + 3.942435073852539062e01, + 4.057055664062500000e01, + 6.334054946899414062e01, + 5.866785430908203125e01, + 6.113219451904296875e01, + 7.438232421875000000e01, + 8.779418945312500000e01, + 9.533239746093750000e01, + 9.737519073486328125e01, + 9.609780883789062500e01, + 9.724783325195312500e01, + 9.839658355712890625e01, + 1.000573120117187500e02, + 1.004406509399414062e02, + 1.027394256591796875e02, + 9.839658355712890625e01, + 1.037621002197265625e02, + 9.520503997802734375e01, + 7.690779113769531250e01, + 4.783494567871093750e01, + 4.464849853515625000e01, + 4.490448379516601562e01, + 4.630667495727539062e01, + 4.362837600708007812e01, + 4.528655242919921875e01, + 4.452114105224609375e01, + 6.675624084472656250e01, + 7.067881011962890625e01, + 6.830870819091796875e01, + 7.859780883789062500e01, + 9.124427032470703125e01, + 9.673586273193359375e01, + 9.367167663574218750e01, + 9.852394104003906250e01, + 9.699057769775390625e01, + 9.647988128662109375e01, + 1.008239974975585938e02, + 1.017180328369140625e02, + 1.015906753540039062e02, + 1.017193069458007812e02, + 9.545848083496093750e01, + 9.341441345214843750e01, + 7.792536926269531250e01, + 5.165945053100585938e01, + 4.707080841064453125e01, + 4.719816589355468750e01, + 4.643275451660156250e01, + 4.579597473144531250e01, + 4.656138610839843750e01, + 4.732552337646484375e01, + 7.882450103759765625e01, + 7.118950653076171875e01, + 7.361691284179687500e01, + 8.600611114501953125e01, + 1.018466644287109375e02, + 9.929190063476562500e01, + 1.051681137084960938e02, + 1.092536926269531250e02, + 1.017180328369140625e02, + 1.055514526367187500e02, + 1.087442703247070312e02, + 1.073382568359375000e02, + 1.014620513916015625e02, + 1.061895065307617188e02, + 9.954534149169921875e01, + 9.418109893798828125e01, + 7.562786865234375000e01, + 4.923586273193359375e01, + 4.630540084838867188e01, + 4.643275451660156250e01, + 4.643275451660156250e01, + 4.477585220336914062e01, + 4.541263198852539062e01, + 4.388181304931640625e01, + 6.921293640136718750e01, + 7.013372039794921875e01, + 6.910977935791015625e01, + 8.690015411376953125e01, + 1.012073364257812500e02, + 1.023560867309570312e02, + 1.006953659057617188e02, + 9.967396545410156250e01, + 9.980132293701171875e01, + 1.022287292480468750e02, + 1.095096817016601562e02, + 1.096370315551757812e02, + 1.024847183227539062e02, + 1.031240463256835938e02, + 1.026120758056640625e02, + 9.277508544921875000e01, + 7.511716461181640625e01, + 4.872771453857421875e01, + 4.273560714721679688e01, + 4.311894989013671875e01, + 4.107997894287109375e01, + 3.980641937255859375e01, + 4.006113052368164062e01, + 3.853158569335937500e01, + 6.669383239746093750e01, + 6.282093811035156250e01, + 6.205807495117187500e01, + 8.000127410888671875e01, + 9.022160339355468750e01, + 9.839531707763671875e01, + 9.941798400878906250e01, + 1.023573608398437500e02, + 1.008239974975585938e02, + 1.035061111450195312e02, + 1.044001541137695312e02, + 1.088716278076171875e02, + 1.124477844238281250e02, + 1.075929718017578125e02, + 1.008239974975585938e02, + 9.622644042968750000e01, + 9.073229980468750000e01, + 6.257386779785156250e01, + 4.311894989013671875e01, + 4.299032211303710938e01, + 4.273815536499023438e01, + 4.095262527465820312e01, + 4.082526779174804688e01, + 4.146204757690429688e01, + 7.409832000732421875e01, + 6.358507537841796875e01, + 6.754330444335937500e01, + 8.396331787109375000e01, + 9.405374145507812500e01, + 1.013346939086914062e02, + 1.063155899047851562e02, + 1.123204269409179688e02, + 1.139798812866210938e02, + 1.114251174926757812e02, + 1.038894577026367188e02, + 1.175560379028320312e02, + 1.093823242187500000e02, + 1.097656631469726562e02, + 1.035061111450195312e02, + 9.890728759765625000e01, + 9.341568756103515625e01, + 6.959373474121093750e01, + 4.503183746337890625e01, + 4.515919494628906250e01, + 4.388308715820312500e01, + 4.069791030883789062e01, + 4.082526779174804688e01, + 4.069791030883789062e01, + 6.576541137695312500e01, + 6.329724884033203125e01, + 6.716250610351562500e01, + 8.204534149169921875e01, + 9.405374145507812500e01, + 1.003120193481445312e02, + 1.045287857055664062e02, + 1.082335739135742188e02, + 1.067002029418945312e02, + 1.054215469360351562e02, + 1.070835418701171875e02, + 1.100216522216796875e02, + 1.037608261108398438e02, + 9.865257263183593750e01, + 9.775852966308593750e01, + 9.711793518066406250e01, + 8.137290191650390625e01, + 4.974656295776367188e01, + 4.464849853515625000e01, + ], + [ + 1.140194625854492188e02, + 1.181511993408203125e02, + 1.222866744995117188e02, + 1.147717056274414062e02, + 1.102619781494140625e02, + 1.170284423828125000e02, + 1.421856231689453125e02, + 1.713136291503906250e02, + 1.703555450439453125e02, + 1.906886291503906250e02, + 2.291916198730468750e02, + 2.488061370849609375e02, + 2.491841278076171875e02, + 2.295696105957031250e02, + 2.412574920654296875e02, + 2.488061370849609375e02, + 2.601235046386718750e02, + 2.408869781494140625e02, + 2.397492523193359375e02, + 2.078667602539062500e02, + 1.358046417236328125e02, + 1.249139251708984375e02, + 1.219124221801757812e02, + 1.155202102661132812e02, + 1.162761993408203125e02, + 1.196519470214843750e02, + 1.245359268188476562e02, + 1.189034423828125000e02, + 1.162761993408203125e02, + 1.410591278076171875e02, + 1.515793457031250000e02, + 2.093974609375000000e02, + 2.220247039794921875e02, + 2.397567291259765625e02, + 2.518263397216796875e02, + 2.555950622558593750e02, + 2.435254516601562500e02, + 2.242926635742187500e02, + 1.980613708496093750e02, + 2.208907165527343750e02, + 2.337163238525390625e02, + 2.310778503417968750e02, + 1.957934112548828125e02, + 1.893824920654296875e02, + 1.403068847656250000e02, + 1.264146728515625000e02, + 1.207821884155273438e02, + 1.155239486694335938e02, + 1.166504516601562500e02, + 1.189071884155273438e02, + 1.252881698608398438e02, + 1.192814407348632812e02, + 1.158982009887695312e02, + 1.414371185302734375e02, + 1.583345794677734375e02, + 2.112799377441406250e02, + 2.442814331054687500e02, + 2.499401245117187500e02, + 2.529528503417968750e02, + 2.495583801269531250e02, + 2.552208099365234375e02, + 2.488061370849609375e02, + 2.465456542968750000e02, + 2.503181152343750000e02, + 2.552208099365234375e02, + 2.559693145751953125e02, + 2.484281463623046875e02, + 2.325860748291015625e02, + 1.423989562988281250e02, + 1.241654205322265625e02, + 1.192739486694335938e02, + 1.155202102661132812e02, + 1.158944625854492188e02, + 1.185254516601562500e02, + 1.256624221801757812e02, + 1.189034423828125000e02, + 1.155239486694335938e02, + 1.403106231689453125e02, + 1.538323364257812500e02, + 2.082634735107421875e02, + 2.152357788085937500e02, + 2.340943145751953125e02, + 2.457896728515625000e02, + 2.461676635742187500e02, + 2.337163238525390625e02, + 2.363622741699218750e02, + 2.261751556396484375e02, + 2.431474609375000000e02, + 2.427694549560546875e02, + 2.205127258300781250e02, + 1.759880218505859375e02, + 1.863398132324218750e02, + 1.354266510009765625e02, + 1.200299377441406250e02, + 1.158982009887695312e02, + 1.110142211914062500e02, + 1.098877258300781250e02, + 1.125187149047851562e02, + 1.181549377441406250e02, + 1.113884735107421875e02, + 1.065119781494140625e02, + 1.335516510009765625e02, + 1.425636291503906250e02, + 1.694386291503906250e02, + 1.718562927246093750e02, + 1.916429595947265625e02, + 2.276871185302734375e02, + 2.446556854248046875e02, + 2.371145172119140625e02, + 2.333420715332031250e02, + 2.186302337646484375e02, + 2.431474609375000000e02, + 2.593675231933593750e02, + 2.529565887451171875e02, + 2.537088317871093750e02, + 2.352245483398437500e02, + 1.369311370849609375e02, + 1.219049377441406250e02, + 1.173989486694335938e02, + 1.132709579467773438e02, + 1.125187149047851562e02, + 1.151459579467773438e02, + 1.204079360961914062e02, + 1.140194625854492188e02, + 1.102619781494140625e02, + 1.327994079589843750e02, + 1.433121185302734375e02, + 1.943188629150390625e02, + 1.980651245117187500e02, + 2.193824920654296875e02, + 2.363622741699218750e02, + 2.540868225097656250e02, + 2.544648132324218750e02, + 2.593712463378906250e02, + 2.537125701904296875e02, + 2.457896728515625000e02, + 2.578592834472656250e02, + 2.601235046386718750e02, + 2.533308410644531250e02, + 2.446594238281250000e02, + 1.469086761474609375e02, + 1.222829360961914062e02, + 1.185291900634765625e02, + 1.147717056274414062e02, + 1.136414642333984375e02, + 1.151459579467773438e02, + 1.215344314575195312e02, + 1.136414642333984375e02, + 1.102619781494140625e02, + 1.331736602783203125e02, + 1.436901245117187500e02, + 1.984767913818359375e02, + 1.948652648925781250e02, + 2.276796417236328125e02, + 2.597492370605468750e02, + 2.631399841308593750e02, + 2.593675231933593750e02, + 2.604977416992187500e02, + 2.805651245117187500e02, + 2.882223205566406250e02, + 2.730014953613281250e02, + 2.786601867675781250e02, + 2.714932556152343750e02, + 2.461676635742187500e02, + 1.472866821289062500e02, + 1.256661682128906250e02, + 1.222829360961914062e02, + 1.166504516601562500e02, + ], + [ + 1.792779235839843750e02, + 1.680313415527343750e02, + 1.694005432128906250e02, + 1.666689300537109375e02, + 1.663317413330078125e02, + 1.772343292236328125e02, + 2.585524597167968750e02, + 2.466008148193359375e02, + 2.790497131347656250e02, + 2.776839294433593750e02, + 2.957867736816406250e02, + 3.036410217285156250e02, + 2.995436096191406250e02, + 3.166212463378906250e02, + 3.152520446777343750e02, + 3.149114379882812500e02, + 3.258412780761718750e02, + 3.009093933105468750e02, + 3.087636108398437500e02, + 3.152520446777343750e02, + 3.179870605468750000e02, + 3.080824279785156250e02, + 2.722173156738281250e02, + 2.232629394531250000e02, + 2.017779235839843750e02, + 1.908719329833984375e02, + 1.860967254638671875e02, + 1.802997283935546875e02, + 1.782561340332031250e02, + 1.860933227539062500e02, + 2.691450805664062500e02, + 2.848569335937500000e02, + 3.053474121093750000e02, + 3.067132263183593750e02, + 3.442847290039062500e02, + 3.296015014648437500e02, + 3.268699035644531250e02, + 3.381403198242187500e02, + 3.203746643066406250e02, + 3.319856872558593750e02, + 3.405279235839843750e02, + 3.094482421875000000e02, + 3.073978271484375000e02, + 3.220844726562500000e02, + 3.265258789062500000e02, + 3.121764221191406250e02, + 2.705075073242187500e02, + 2.285047760009765625e02, + 2.080279235839843750e02, + 1.952997283935546875e02, + 1.932561340332031250e02, + 1.891689300537109375e02, + 1.867813415527343750e02, + 1.939373321533203125e02, + 2.598126831054687500e02, + 2.913453674316406250e02, + 3.084230346679687500e02, + 3.128610229492187500e02, + 3.299421081542968750e02, + 3.436069335937500000e02, + 3.336954956054687500e02, + 3.504325561523437500e02, + 3.330143127441406250e02, + 3.220810546875000000e02, + 3.391621398925781250e02, + 3.183310546875000000e02, + 3.050068054199218750e02, + 3.261818847656250000e02, + 3.299421081542968750e02, + 3.125204467773437500e02, + 2.691450805664062500e02, + 2.244005432128906250e02, + 2.059877319335937500e02, + 1.949625396728515625e02, + 1.925749359130859375e02, + 1.888283386230468750e02, + 1.884877319335937500e02, + 2.018937377929687500e02, + 2.729019165039062500e02, + 2.940769653320312500e02, + 3.244788818359375000e02, + 3.173024597167968750e02, + 3.381403198242187500e02, + 3.272104797363281250e02, + 3.466791687011718750e02, + 3.354053039550781250e02, + 3.289169006347656250e02, + 3.401873168945312500e02, + 3.425749206542968750e02, + 3.149114379882812500e02, + 3.039850158691406250e02, + 3.227656555175781250e02, + 3.237908630371093750e02, + 3.128610229492187500e02, + 2.708480834960937500e02, + 2.315735626220703125e02, + 1.884809265136718750e02, + 1.803031311035156250e02, + 1.813249359130859375e02, + 1.809809265136718750e02, + 1.799591217041015625e02, + 2.236273803710937500e02, + 2.558242492675781250e02, + 2.780245361328125000e02, + 3.012465820312500000e02, + 3.026158142089843750e02, + 3.309673156738281250e02, + 3.227690734863281250e02, + 3.282322998046875000e02, + 3.265258789062500000e02, + 3.357459106445312500e02, + 3.347207031250000000e02, + 3.381403198242187500e02, + 3.149148559570312500e02, + 3.162806396484375000e02, + 3.237908630371093750e02, + 3.097922363281250000e02, + 3.036410217285156250e02, + 2.612874755859375000e02, + 2.131403198242187500e02, + 1.905313415527343750e02, + 1.799625396728515625e02, + 1.782561340332031250e02, + 1.799591217041015625e02, + 1.779155273437500000e02, + 2.352384185791015625e02, + 2.548024597167968750e02, + 2.619686584472656250e02, + 2.992030029296875000e02, + 2.961273803710937500e02, + 3.138862304687500000e02, + 3.343800964355468750e02, + 3.330143127441406250e02, + 3.227690734863281250e02, + 3.152520446777343750e02, + 3.234502868652343750e02, + 3.323330993652343750e02, + 3.145708312988281250e02, + 3.002247924804687500e02, + 3.142302551269531250e02, + 3.193528747558593750e02, + 3.036376037597656250e02, + 2.551464538574218750e02, + 2.176907348632812500e02, + 1.946185302734375000e02, + 1.799625396728515625e02, + 1.802997283935546875e02, + 1.772343292236328125e02, + 1.809809265136718750e02, + 2.389918212890625000e02, + 2.565054626464843750e02, + 2.534298400878906250e02, + 2.937363891601562500e02, + 2.971525878906250000e02, + 3.091076354980468750e02, + 3.309638977050781250e02, + 3.336954956054687500e02, + 3.323297119140625000e02, + 3.251566772460937500e02, + 3.504359741210937500e02, + 3.449693603515625000e02, + 3.149114379882812500e02, + 3.104734191894531250e02, + 3.268630676269531250e02, + 3.268664855957031250e02, + 3.094482421875000000e02, + 2.691450805664062500e02, + 2.232697601318359375e02, + ], + [ + 3.947381896972656250e02, + 3.778378295898437500e02, + 3.702280273437500000e02, + 3.660050659179687500e02, + 5.198479614257812500e02, + 5.274493408203125000e02, + 5.730996704101562500e02, + 5.959121704101562500e02, + 6.664611206054687500e02, + 7.358361206054687500e02, + 7.853547363281250000e02, + 8.268750000000000000e02, + 8.395861206054687500e02, + 8.141723022460937500e02, + 7.938428955078125000e02, + 8.760134887695312500e02, + 8.827871704101562500e02, + 8.395861206054687500e02, + 7.845185546875000000e02, + 7.489273681640625000e02, + 7.396115112304687500e02, + 5.734966430664062500e02, + 4.420692443847656250e02, + 4.124830932617187500e02, + 4.260135192871093750e02, + 4.116385192871093750e02, + 4.048817443847656250e02, + 4.031925659179687500e02, + 5.739357910156250000e02, + 5.967567749023437500e02, + 6.246536865234375000e02, + 6.195861206054687500e02, + 6.803969726562500000e02, + 7.751942749023437500e02, + 8.345017089843750000e02, + 7.743496704101562500e02, + 8.057009887695312500e02, + 8.124746704101562500e02, + 8.209459228515625000e02, + 8.539949340820312500e02, + 8.624661865234375000e02, + 8.192482910156250000e02, + 7.955321044921875000e02, + 7.692736206054687500e02, + 7.235134887695312500e02, + 5.726942749023437500e02, + 4.370101318359375000e02, + 4.133446044921875000e02, + 4.243243103027343750e02, + 4.133361511230468750e02, + 4.175675659179687500e02, + 4.226351318359375000e02, + 6.288598022460937500e02, + 6.711232910156250000e02, + 6.356503295898437500e02, + 6.880321044921875000e02, + 7.989189453125000000e02, + 8.972044067382812500e02, + 9.039780273437500000e02, + 8.709290771484375000e02, + 8.870354614257812500e02, + 8.734797363281250000e02, + 8.522973022460937500e02, + 8.921199340820312500e02, + 8.904138793945312500e02, + 8.607685546875000000e02, + 8.175591430664062500e02, + 7.299240112304687500e02, + 7.065709228515625000e02, + 5.743750000000000000e02, + 4.446114807128906250e02, + 3.998057556152343750e02, + 4.184121704101562500e02, + 3.998057556152343750e02, + 4.048733215332031250e02, + 4.065709533691406250e02, + 5.190033569335937500e02, + 5.959121704101562500e02, + 6.178969726562500000e02, + 6.588767089843750000e02, + 7.070354614257812500e02, + 7.896115112304687500e02, + 7.921536865234375000e02, + 7.870607910156250000e02, + 8.260303955078125000e02, + 7.921452636718750000e02, + 8.031588134765625000e02, + 8.683868408203125000e02, + 8.311148681640625000e02, + 7.896030273437500000e02, + 7.565625000000000000e02, + 7.112838134765625000e02, + 6.664696044921875000e02, + 5.337584228515625000e02, + 3.998057556152343750e02, + 3.803716125488281250e02, + 3.854391784667968750e02, + 3.676858215332031250e02, + 3.634628295898437500e02, + 3.685473022460937500e02, + 4.919510192871093750e02, + 5.536571044921875000e02, + 5.578800659179687500e02, + 5.790033569335937500e02, + 6.474830932617187500e02, + 7.463935546875000000e02, + 7.879053955078125000e02, + 8.090878295898437500e02, + 8.073901977539062500e02, + 7.692651977539062500e02, + 8.277280273437500000e02, + 8.700844726562500000e02, + 8.794088134765625000e02, + 8.345101318359375000e02, + 8.667060546875000000e02, + 7.523226318359375000e02, + 7.353800659179687500e02, + 6.554982910156250000e02, + 5.232178955078125000e02, + 3.938935852050781250e02, + 4.116469726562500000e02, + 3.896790466308593750e02, + 3.820608215332031250e02, + 3.854391784667968750e02, + 5.274408569335937500e02, + 6.001351318359375000e02, + 6.212753295898437500e02, + 6.322634887695312500e02, + 7.045017089843750000e02, + 8.387500000000000000e02, + 8.997381591796875000e02, + 8.912669067382812500e02, + 8.988851318359375000e02, + 8.548310546875000000e02, + 8.912669067382812500e02, + 9.107601318359375000e02, + 9.132939453125000000e02, + 8.497550659179687500e02, + 9.056672363281250000e02, + 8.285726318359375000e02, + 7.514780273437500000e02, + 6.580405273437500000e02, + 5.257601318359375000e02, + 3.837500000000000000e02, + 3.981250000000000000e02, + 3.964358215332031250e02, + 3.719172363281250000e02, + 3.829053955078125000e02, + 5.206756591796875000e02, + 5.705574340820312500e02, + 5.553463134765625000e02, + 6.233530273437500000e02, + 6.994172363281250000e02, + 8.336571044921875000e02, + 9.226182250976562500e02, + 9.056672363281250000e02, + 9.276942749023437500e02, + 8.683952636718750000e02, + 8.929560546875000000e02, + 1.004788879394531250e03, + 9.675169067382812500e02, + 9.285473022460937500e02, + 8.827955932617187500e02, + 8.014611206054687500e02, + 7.599493408203125000e02, + 5.938006591796875000e02, + 4.463006896972656250e02, + 3.896790466308593750e02, + ], + [ + 4.006647109985351562e01, + 3.545051574707031250e01, + 2.289512634277343750e01, + 1.772525787353515625e01, + 1.070901012420654297e01, + 9.231905937194824219e00, + 7.016248226165771484e00, + 1.070901012420654297e01, + 1.920236396789550781e01, + 3.766617584228515625e01, + 4.523633575439453125e01, + 5.409896469116210938e01, + 6.776218414306640625e01, + 6.720827484130859375e01, + 5.816100311279296875e01, + 5.243722152709960938e01, + 5.649925994873046875e01, + 5.243722152709960938e01, + 6.037666320800781250e01, + 6.573117065429687500e01, + 6.831610107421875000e01, + 6.185376739501953125e01, + 5.391432952880859375e01, + 4.911373519897460938e01, + 4.689807891845703125e01, + 2.861890602111816406e01, + 2.437223052978515625e01, + 1.901772499084472656e01, + 1.052437210083007812e01, + 8.124076843261718750e00, + 7.385524272918701172e00, + 1.107828617095947266e01, + 2.621861076354980469e01, + 4.671343994140625000e01, + 5.539143371582031250e01, + 5.889955520629882812e01, + 7.200886535644531250e01, + 6.776218414306640625e01, + 6.093057632446289062e01, + 5.742245101928710938e01, + 5.391432952880859375e01, + 5.612998580932617188e01, + 5.723781204223632812e01, + 6.794682312011718750e01, + 6.443869781494140625e01, + 6.277695846557617188e01, + 6.517725372314453125e01, + 5.225258636474609375e01, + 5.040620422363281250e01, + 3.489660263061523438e01, + 2.381831550598144531e01, + 1.790989685058593750e01, + 1.181683921813964844e01, + 8.493352890014648438e00, + 8.493352890014648438e00, + 1.827917289733886719e01, + 3.175775527954101562e01, + 5.003692626953125000e01, + 5.649925994873046875e01, + 6.739290618896484375e01, + 7.090103149414062500e01, + 6.517725372314453125e01, + 5.889955520629882812e01, + 5.760708999633789062e01, + 5.760708999633789062e01, + 5.594534683227539062e01, + 5.889955520629882812e01, + 7.348596954345703125e01, + 7.256277465820312500e01, + 6.517725372314453125e01, + 5.982274627685546875e01, + 5.631462478637695312e01, + 4.689807891845703125e01, + 3.101920318603515625e01, + 2.511078262329101562e01, + 2.344903945922851562e01, + 1.144756317138671875e01, + 7.570162296295166016e00, + 7.200886249542236328e00, + 1.347858238220214844e01, + 2.566469764709472656e01, + 4.209748840332031250e01, + 4.745199584960937500e01, + 6.425405883789062500e01, + 7.920974731445312500e01, + 6.905464935302734375e01, + 5.539143371582031250e01, + 5.280649948120117188e01, + 5.059084320068359375e01, + 5.077547836303710938e01, + 5.631462478637695312e01, + 6.462333679199218750e01, + 7.163958740234375000e01, + 5.631462478637695312e01, + 5.668389892578125000e01, + 4.412850952148437500e01, + 4.080502319335937500e01, + 3.157311630249023438e01, + 2.234121131896972656e01, + 2.012555313110351562e01, + 1.532496261596679688e01, + 9.047266960144042969e00, + 9.970458030700683594e00, + 1.366322040557861328e01, + 2.677252578735351562e01, + 4.375923156738281250e01, + 5.483751678466796875e01, + 6.351551055908203125e01, + 7.994830322265625000e01, + 7.330133056640625000e01, + 6.240768051147460938e01, + 6.351551055908203125e01, + 5.834564208984375000e01, + 5.631462478637695312e01, + 7.200886535644531250e01, + 7.477843475341796875e01, + 7.071639251708984375e01, + 6.517725372314453125e01, + 5.409896469116210938e01, + 4.966765213012695312e01, + 4.560561370849609375e01, + 3.526588058471679688e01, + 2.307976341247558594e01, + 2.455686759948730469e01, + 1.366322040557861328e01, + 8.862628936767578125e00, + 7.016248226165771484e00, + 1.403249645233154297e01, + 2.437223052978515625e01, + 4.394387054443359375e01, + 5.132939529418945312e01, + 6.277695846557617188e01, + 7.293205261230468750e01, + 6.591580200195312500e01, + 6.351551055908203125e01, + 6.296159362792968750e01, + 6.333087158203125000e01, + 6.628507995605468750e01, + 7.182422637939453125e01, + 7.754800415039062500e01, + 7.644017791748046875e01, + 6.259231948852539062e01, + 5.280649948120117188e01, + 5.409896469116210938e01, + 4.911373519897460938e01, + 3.120384025573730469e01, + 2.474150657653808594e01, + 2.049483108520507812e01, + 1.310930538177490234e01, + 9.601181983947753906e00, + 8.493352890014648438e00, + 1.643279266357421875e01, + 5.096011734008789062e01, + 4.837518310546875000e01, + 4.966765213012695312e01, + 6.333087158203125000e01, + 7.440915679931640625e01, + 7.274741363525390625e01, + 6.683899688720703125e01, + 6.517725372314453125e01, + 6.831610107421875000e01, + 6.702363586425781250e01, + 7.607089996337890625e01, + 7.662481689453125000e01, + 7.828656005859375000e01, + 6.499261474609375000e01, + 6.093057632446289062e01, + 5.040620422363281250e01, + ], + [ + 7.309510040283203125e01, + 7.324276733398437500e01, + 7.250443267822265625e01, + 7.250443267822265625e01, + 7.132309722900390625e01, + 7.265209960937500000e01, + 7.279976654052734375e01, + 7.545776367187500000e01, + 7.649143218994140625e01, + 7.870643615722656250e01, + 8.165977478027343750e01, + 8.254577636718750000e01, + 8.210277557373046875e01, + 8.313644409179687500e01, + 8.313644409179687500e01, + 8.298877716064453125e01, + 8.254577636718750000e01, + 8.225044250488281250e01, + 8.062610626220703125e01, + 8.106910705566406250e01, + 7.959243774414062500e01, + 7.974010467529296875e01, + 7.841110229492187500e01, + 8.033077239990234375e01, + 8.062610626220703125e01, + 8.254577636718750000e01, + 8.136444091796875000e01, + 8.092144012451171875e01, + 8.092144012451171875e01, + 8.033077239990234375e01, + 8.151210784912109375e01, + 8.136444091796875000e01, + 8.121677398681640625e01, + 8.225044250488281250e01, + 8.343177795410156250e01, + 8.313644409179687500e01, + 8.254577636718750000e01, + 8.210277557373046875e01, + 8.343177795410156250e01, + 8.254577636718750000e01, + 8.225044250488281250e01, + 8.151210784912109375e01, + 8.121677398681640625e01, + 8.062610626220703125e01, + 7.974010467529296875e01, + 7.885410308837890625e01, + 7.900177001953125000e01, + 8.018310546875000000e01, + 8.225044250488281250e01, + 8.225044250488281250e01, + 8.269344329833984375e01, + 8.239810943603515625e01, + 8.033077239990234375e01, + 8.136444091796875000e01, + 8.225044250488281250e01, + 8.284111022949218750e01, + 8.269344329833984375e01, + 8.239810943603515625e01, + 8.225044250488281250e01, + 8.328411102294921875e01, + 8.357944488525390625e01, + 8.343177795410156250e01, + 8.594211578369140625e01, + 8.417011260986328125e01, + 8.357944488525390625e01, + 8.269344329833984375e01, + 8.018310546875000000e01, + 8.106910705566406250e01, + 7.841110229492187500e01, + 7.988777160644531250e01, + 7.914943695068359375e01, + 7.988777160644531250e01, + 7.959243774414062500e01, + 8.136444091796875000e01, + 8.195510864257812500e01, + 8.239810943603515625e01, + 8.165977478027343750e01, + 8.180744171142578125e01, + 8.047843933105468750e01, + 8.121677398681640625e01, + 8.151210784912109375e01, + 8.298877716064453125e01, + 8.180744171142578125e01, + 8.313644409179687500e01, + 8.180744171142578125e01, + 8.328411102294921875e01, + 8.328411102294921875e01, + 8.313644409179687500e01, + 8.165977478027343750e01, + 8.151210784912109375e01, + 7.929710388183593750e01, + 7.900177001953125000e01, + 7.826343536376953125e01, + 7.796810150146484375e01, + 7.826343536376953125e01, + 7.708209991455078125e01, + 7.309510040283203125e01, + 7.368576812744140625e01, + 7.324276733398437500e01, + 7.250443267822265625e01, + 7.294743347167968750e01, + 7.206143188476562500e01, + 7.265209960937500000e01, + 7.442410278320312500e01, + 7.575309753417968750e01, + 7.870643615722656250e01, + 8.062610626220703125e01, + 8.106910705566406250e01, + 8.180744171142578125e01, + 8.313644409179687500e01, + 8.254577636718750000e01, + 8.180744171142578125e01, + 8.328411102294921875e01, + 8.254577636718750000e01, + 7.959243774414062500e01, + 7.782043457031250000e01, + 7.841110229492187500e01, + 7.767276763916015625e01, + 7.767276763916015625e01, + 7.722976684570312500e01, + 7.516242980957031250e01, + 7.531009674072265625e01, + 7.619609832763671875e01, + 7.457176971435546875e01, + 7.442410278320312500e01, + 7.398110198974609375e01, + 7.324276733398437500e01, + 7.368576812744140625e01, + 7.398110198974609375e01, + 7.708209991455078125e01, + 7.841110229492187500e01, + 7.885410308837890625e01, + 7.988777160644531250e01, + 7.929710388183593750e01, + 7.900177001953125000e01, + 8.033077239990234375e01, + 7.900177001953125000e01, + 8.062610626220703125e01, + 7.885410308837890625e01, + 7.841110229492187500e01, + 7.988777160644531250e01, + 8.003543853759765625e01, + 7.914943695068359375e01, + 8.018310546875000000e01, + 7.914943695068359375e01, + 7.870643615722656250e01, + 7.560543060302734375e01, + 7.457176971435546875e01, + 7.442410278320312500e01, + 7.442410278320312500e01, + 7.368576812744140625e01, + 7.471943664550781250e01, + 7.412876892089843750e01, + 7.634376525878906250e01, + 8.018310546875000000e01, + 7.855876922607421875e01, + 7.959243774414062500e01, + 7.959243774414062500e01, + 7.855876922607421875e01, + 8.047843933105468750e01, + 7.944477081298828125e01, + 7.900177001953125000e01, + 7.796810150146484375e01, + 7.811576843261718750e01, + 7.974010467529296875e01, + 7.914943695068359375e01, + 7.988777160644531250e01, + 7.974010467529296875e01, + ], + ] + ) diff --git a/test/mx/model/renewal/test_predictor.py b/test/mx/model/renewal/test_predictor.py index f10600c7af..e05be630dd 100644 --- a/test/mx/model/renewal/test_predictor.py +++ b/test/mx/model/renewal/test_predictor.py @@ -43,10 +43,12 @@ [[[0, 0, 0, 0, 0, 0, 0]]], ), ( - [[ - [[3, 1, 2, 3, 1, 1, 1], [3, 5, 4, 1, 1, 1, 1]], - [[3, 1, 2, 3, 1, 1, 1], [3, 5, 4, 1, 1, 1, 1]], - ]], + [ + [ + [[3, 1, 2, 3, 1, 1, 1], [3, 5, 4, 1, 1, 1, 1]], + [[3, 1, 2, 3, 1, 1, 1], [3, 5, 4, 1, 1, 1, 1]], + ] + ], [[[0, 0, 3, 5, 0, 4, 0], [0, 0, 3, 5, 0, 4, 0]]], ), ( @@ -69,38 +71,42 @@ def test_output_transform(input, expected): def test_predictor_smoke_test(): train_ds = ListDataset( - [{ - "target": [ - 100.0, - 63.0, - 83.0, - 126.0, - 115.0, - 92.0, - 57.0, - 95.0, - 94.0, - 92.0, - 142.0, - 35.0, - 116.0, - 78.0, - 64.0, - 141.0, - ], - "start": "2018-01-07 00:00:00", - "feat_static_cat": [0], - }], + [ + { + "target": [ + 100.0, + 63.0, + 83.0, + 126.0, + 115.0, + 92.0, + 57.0, + 95.0, + 94.0, + 92.0, + 142.0, + 35.0, + 116.0, + 78.0, + 64.0, + 141.0, + ], + "start": "2018-01-07 00:00:00", + "feat_static_cat": [0], + } + ], freq="1m", ) test_ds = ListDataset( - [{ - "target": [100.0, 63.0, 83.0, 126.0, 115.0, 92.0, 57.0, 95.0] - + [0] * 15, - "start": "2018-01-07 00:00:00", - "feat_static_cat": [1], - }], + [ + { + "target": [100.0, 63.0, 83.0, 126.0, 115.0, 92.0, 57.0, 95.0] + + [0] * 15, + "start": "2018-01-07 00:00:00", + "feat_static_cat": [1], + } + ], freq="1m", ) diff --git a/test/mx/model/seq2seq/test_forking_sequence_splitter.py b/test/mx/model/seq2seq/test_forking_sequence_splitter.py index 66d60b51db..00d1df560f 100644 --- a/test/mx/model/seq2seq/test_forking_sequence_splitter.py +++ b/test/mx/model/seq2seq/test_forking_sequence_splitter.py @@ -44,30 +44,34 @@ def test_forking_sequence_splitter() -> None: enc_len = 5 dec_len = 3 - trans = transform.Chain([ - transform.AddAgeFeature( - target_field=FieldName.TARGET, - output_field="age", - pred_length=dec_len, - ), - ForkingSequenceSplitter( - instance_sampler=ValidationSplitSampler(min_future=dec_len), - enc_len=enc_len, - dec_len=dec_len, - encoder_series_fields=["age"], - ), - ]) + trans = transform.Chain( + [ + transform.AddAgeFeature( + target_field=FieldName.TARGET, + output_field="age", + pred_length=dec_len, + ), + ForkingSequenceSplitter( + instance_sampler=ValidationSplitSampler(min_future=dec_len), + enc_len=enc_len, + dec_len=dec_len, + encoder_series_fields=["age"], + ), + ] + ) out = trans(ds, is_train=True) transformed_data = next(iter(out)) - future_target = np.array([ - [13.0, 14.0, 15.0], - [14.0, 15.0, 16.0], - [15.0, 16.0, 17.0], - [16.0, 17.0, 18.0], - [17.0, 18.0, 19.0], - ]) + future_target = np.array( + [ + [13.0, 14.0, 15.0], + [14.0, 15.0, 16.0], + [15.0, 16.0, 17.0], + [16.0, 17.0, 18.0], + [17.0, 18.0, 19.0], + ] + ) assert ( np.linalg.norm(future_target - transformed_data["future_target"]) < 1e-5 @@ -106,35 +110,37 @@ def make_dataset(N, train_length): num_time_feat_daily_freq = 3 num_age_feat = 1 - trans = transform.Chain([ - transform.AddAgeFeature( - target_field=FieldName.TARGET, - output_field=FieldName.FEAT_AGE, - pred_length=10, - ), - transform.AddTimeFeatures( - start_field=FieldName.START, - target_field=FieldName.TARGET, - output_field=FieldName.FEAT_TIME, - time_features=time_features_from_frequency_str("D"), - pred_length=10, - ), - ForkingSequenceSplitter( - instance_sampler=( - ValidationSplitSampler(min_future=dec_len) - if is_train - else TSplitSampler() + trans = transform.Chain( + [ + transform.AddAgeFeature( + target_field=FieldName.TARGET, + output_field=FieldName.FEAT_AGE, + pred_length=10, ), - enc_len=enc_len, - dec_len=dec_len, - num_forking=num_forking, - encoder_series_fields=[ - FieldName.FEAT_AGE, - FieldName.FEAT_TIME, - ], - decoder_series_fields=[FieldName.FEAT_TIME], - ), - ]) + transform.AddTimeFeatures( + start_field=FieldName.START, + target_field=FieldName.TARGET, + output_field=FieldName.FEAT_TIME, + time_features=time_features_from_frequency_str("D"), + pred_length=10, + ), + ForkingSequenceSplitter( + instance_sampler=( + ValidationSplitSampler(min_future=dec_len) + if is_train + else TSplitSampler() + ), + enc_len=enc_len, + dec_len=dec_len, + num_forking=num_forking, + encoder_series_fields=[ + FieldName.FEAT_AGE, + FieldName.FEAT_TIME, + ], + decoder_series_fields=[FieldName.FEAT_TIME], + ), + ] + ) out = trans(iter(ds), is_train=is_train) transformed_data = next(iter(out)) diff --git a/test/mx/model/simple_feedforward/test_serde.py b/test/mx/model/simple_feedforward/test_serde.py index 7f4bd03c61..6ebdc5d403 100644 --- a/test/mx/model/simple_feedforward/test_serde.py +++ b/test/mx/model/simple_feedforward/test_serde.py @@ -26,10 +26,12 @@ def test_simplefeedforward_symbol_block_serde(): with tempfile.TemporaryDirectory( prefix="gluonts-predictor-temp-" ) as temp_dir: - dataset = [{ - "start": pd.Period("2022-01-01", freq="D"), - "target": np.random.normal(size=(200)), - }] + dataset = [ + { + "start": pd.Period("2022-01-01", freq="D"), + "target": np.random.normal(size=(200)), + } + ] estimator = SimpleFeedForwardEstimator( prediction_length=10, diff --git a/test/mx/model/tpp/common.py b/test/mx/model/tpp/common.py index 182ee5cd5b..6253eabf51 100644 --- a/test/mx/model/tpp/common.py +++ b/test/mx/model/tpp/common.py @@ -22,11 +22,13 @@ def point_process_dataset(): marks = np.array([0, 1, 2, 0, 1, 2, 2, 2]) lds = ListDataset( - [{ - "target": np.c_[ia_times, marks].T, - "start": pd.Timestamp("2011-01-01 00:00:00"), - "end": pd.Timestamp("2011-01-01 03:00:00"), - }], + [ + { + "target": np.c_[ia_times, marks].T, + "start": pd.Timestamp("2011-01-01 00:00:00"), + "end": pd.Timestamp("2011-01-01 03:00:00"), + } + ], freq="H", one_dim_target=False, use_timestamp=True, diff --git a/test/mx/representation/test_bin.py b/test/mx/representation/test_bin.py index 6b3b609d5e..8558b93e2d 100644 --- a/test/mx/representation/test_bin.py +++ b/test/mx/representation/test_bin.py @@ -20,201 +20,215 @@ binning_cases = [ ( CustomBinning(bin_centers=np.linspace(-1, 10, 5)), - mx.nd.array([ - [-1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], - [3.0] * 10, - [0.0] * 5 + [3.0] * 5, - [2.0] * 8 + [1.5] * 2, - [0.0] * 10, - [1.0] * 10, - ]), - mx.nd.array([ - [1.0] * 10, - [1.0] * 10, - [0.0] * 5 + [1.0] * 5, - [1.0] * 9 + [1.0] * 1, - [0.0] * 10, - [1.0] * 10, - ]), - mx.nd.array([-np.inf, 0.375, 3.125, 5.875, 8.625, np.inf]), - mx.nd.array([ - [ - 1.0, - 2.0, - 2.0, - 3.0, - 3.0, - 4.0, - 4.0, - 4.0, - 5.0, - 5.0, - ], - [ - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - ], - [ - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - ], + mx.nd.array( [ - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - ], + [-1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], + [3.0] * 10, + [0.0] * 5 + [3.0] * 5, + [2.0] * 8 + [1.5] * 2, + [0.0] * 10, + [1.0] * 10, + ] + ), + mx.nd.array( [ - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - ], + [1.0] * 10, + [1.0] * 10, + [0.0] * 5 + [1.0] * 5, + [1.0] * 9 + [1.0] * 1, + [0.0] * 10, + [1.0] * 10, + ] + ), + mx.nd.array([-np.inf, 0.375, 3.125, 5.875, 8.625, np.inf]), + mx.nd.array( [ - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - ], - ]), + [ + 1.0, + 2.0, + 2.0, + 3.0, + 3.0, + 4.0, + 4.0, + 4.0, + 5.0, + 5.0, + ], + [ + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + ], + [ + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + ], + [ + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + ], + [ + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + ], + [ + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + ], + ] + ), ), ( CustomBinning(bin_centers=np.linspace(-10, 10, 8)), - mx.nd.array([ - [-1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], - [3.0] * 10, - [0.0] * 5 + [3.0] * 5, - [2.0] * 8 + [1.5] * 2, - [0.0] * 10, - [1.0] * 10, - ]), - mx.nd.array([ - [1.0] * 10, - [1.0] * 10, - [0.0] * 5 + [1.0] * 5, - [1.0] * 9 + [1.0] * 1, - [0.0] * 10, - [1.0] * 10, - ]), - mx.nd.array([ - -np.inf, - -8.57142857, - -5.71428571, - -2.85714286, - 0.0, - 2.85714286, - 5.71428571, - 8.57142857, - np.inf, - ]), - mx.nd.array([ - [ - 4.0, - 5.0, - 6.0, - 6.0, - 6.0, - 7.0, - 7.0, - 7.0, - 8.0, - 8.0, - ], - [ - 6.0, - 6.0, - 6.0, - 6.0, - 6.0, - 6.0, - 6.0, - 6.0, - 6.0, - 6.0, - ], + mx.nd.array( [ - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 6.0, - 6.0, - 6.0, - 6.0, - 6.0, - ], + [-1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], + [3.0] * 10, + [0.0] * 5 + [3.0] * 5, + [2.0] * 8 + [1.5] * 2, + [0.0] * 10, + [1.0] * 10, + ] + ), + mx.nd.array( [ - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - ], + [1.0] * 10, + [1.0] * 10, + [0.0] * 5 + [1.0] * 5, + [1.0] * 9 + [1.0] * 1, + [0.0] * 10, + [1.0] * 10, + ] + ), + mx.nd.array( [ - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - ], + -np.inf, + -8.57142857, + -5.71428571, + -2.85714286, + 0.0, + 2.85714286, + 5.71428571, + 8.57142857, + np.inf, + ] + ), + mx.nd.array( [ - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - ], - ]), + [ + 4.0, + 5.0, + 6.0, + 6.0, + 6.0, + 7.0, + 7.0, + 7.0, + 8.0, + 8.0, + ], + [ + 6.0, + 6.0, + 6.0, + 6.0, + 6.0, + 6.0, + 6.0, + 6.0, + 6.0, + 6.0, + ], + [ + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 6.0, + 6.0, + 6.0, + 6.0, + 6.0, + ], + [ + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + ], + [ + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + ], + [ + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + ], + ] + ), ), ] diff --git a/test/mx/representation/test_grb.py b/test/mx/representation/test_grb.py index 5b8f5e511b..0f8b32f845 100644 --- a/test/mx/representation/test_grb.py +++ b/test/mx/representation/test_grb.py @@ -24,126 +24,136 @@ is_quantile=True, quantile_scaling_limit=1.0, ), - mx.nd.array([ - [ - -0.188679, - 0.377358, - 0.566038, - 0.754717, - 0.943396, - 1.13208, - 1.32075, - 1.50943, - 1.69811, - 1.88679, - ], - [1.0] * 10, - [0.857143] * 5 + [1.14286] * 5, - [1.05263] * 8 + [0.789474] * 2, - [1.0] * 10, - ]), - mx.nd.array([ - [-1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], - [3.0] * 10, - [0.0] * 5 + [3.0] * 5, - [2.0] * 8 + [1.5] * 2, - [0.0] * 10, - [1.0] * 10, - ]), - mx.nd.array([ - [1.0] * 10, - [1.0] * 10, - [0.0] * 5 + [1.0] * 5, - [1.0] * 9 + [1.0] * 1, - [0.0] * 10, - [1.0] * 10, - ]), - mx.nd.array([ - -np.inf, - 0.334232, - 0.92857149, - 1.0, - 1.03425997, - 1.47765499, - np.inf, - ]), mx.nd.array( - [-0.18867899, 0.85714298, 1.0, 1.0, 1.06851995, 1.88679004] - ), - mx.nd.array([ [ - 1, - 2, - 2, - 2, - 3, - 5, - 5, - 6, - 6, - 6, - ], - [ - 4, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - ], + [ + -0.188679, + 0.377358, + 0.566038, + 0.754717, + 0.943396, + 1.13208, + 1.32075, + 1.50943, + 1.69811, + 1.88679, + ], + [1.0] * 10, + [0.857143] * 5 + [1.14286] * 5, + [1.05263] * 8 + [0.789474] * 2, + [1.0] * 10, + ] + ), + mx.nd.array( [ - 1, - 1, - 1, - 1, - 1, - 4, - 4, - 4, - 4, - 4, - ], + [-1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], + [3.0] * 10, + [0.0] * 5 + [3.0] * 5, + [2.0] * 8 + [1.5] * 2, + [0.0] * 10, + [1.0] * 10, + ] + ), + mx.nd.array( [ - 5, - 5, - 5, - 5, - 5, - 5, - 5, - 5, - 2, - 2, - ], + [1.0] * 10, + [1.0] * 10, + [0.0] * 5 + [1.0] * 5, + [1.0] * 9 + [1.0] * 1, + [0.0] * 10, + [1.0] * 10, + ] + ), + mx.nd.array( [ - 7, - 7, - 7, - 7, - 7, - 7, - 7, - 7, - 7, - 7, - ], + -np.inf, + 0.334232, + 0.92857149, + 1.0, + 1.03425997, + 1.47765499, + np.inf, + ] + ), + mx.nd.array( + [-0.18867899, 0.85714298, 1.0, 1.0, 1.06851995, 1.88679004] + ), + mx.nd.array( [ - 4, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - ], - ]), + [ + 1, + 2, + 2, + 2, + 3, + 5, + 5, + 6, + 6, + 6, + ], + [ + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + ], + [ + 1, + 1, + 1, + 1, + 1, + 4, + 4, + 4, + 4, + 4, + ], + [ + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 2, + 2, + ], + [ + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + ], + [ + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + ], + ] + ), ), ( GlobalRelativeBinning( @@ -151,135 +161,147 @@ is_quantile=True, quantile_scaling_limit=1.0, ), - mx.nd.array([ - [ - -0.188679, - 0.377358, - 0.566038, - 0.754717, - 0.943396, - 1.13208, - 1.32075, - 1.50943, - 1.69811, - 1.88679, - ], - [1.0] * 10, - [0.857143] * 5 + [1.14286] * 5, - [1.05263] * 8 + [0.789474] * 2, - [1.0] * 10, - ]), - mx.nd.array([ - [-1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], - [3.0] * 10, - [0.0] * 5 + [3.0] * 5, - [2.0] * 8 + [1.5] * 2, - [0.0] * 10, - [1.0] * 10, - ]), - mx.nd.array([ - [1.0] * 10, - [1.0] * 10, - [0.0] * 5 + [1.0] * 5, - [1.0] * 9 + [1.0] * 1, - [0.0] * 10, - [1.0] * 10, - ]), - mx.nd.array([ - -np.inf, - 0.334232, - 0.92857149, - 1.0, - 1.0, - 1.02631497, - 1.097745, - 1.51482505, - np.inf, - ]), - mx.nd.array([ - -0.18867899, - 0.85714298, - 1.0, - 1.0, - 1.0, - 1.05262995, - 1.14286005, - 1.88679004, - ]), - mx.nd.array([ + mx.nd.array( [ - 1, - 2, - 2, - 2, - 3, - 7, - 7, - 7, - 8, - 8, - ], + [ + -0.188679, + 0.377358, + 0.566038, + 0.754717, + 0.943396, + 1.13208, + 1.32075, + 1.50943, + 1.69811, + 1.88679, + ], + [1.0] * 10, + [0.857143] * 5 + [1.14286] * 5, + [1.05263] * 8 + [0.789474] * 2, + [1.0] * 10, + ] + ), + mx.nd.array( [ - 5, - 5, - 5, - 5, - 5, - 5, - 5, - 5, - 5, - 5, - ], + [-1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], + [3.0] * 10, + [0.0] * 5 + [3.0] * 5, + [2.0] * 8 + [1.5] * 2, + [0.0] * 10, + [1.0] * 10, + ] + ), + mx.nd.array( [ - 1, - 1, - 1, - 1, - 1, - 5, - 5, - 5, - 5, - 5, - ], + [1.0] * 10, + [1.0] * 10, + [0.0] * 5 + [1.0] * 5, + [1.0] * 9 + [1.0] * 1, + [0.0] * 10, + [1.0] * 10, + ] + ), + mx.nd.array( [ - 6, - 6, - 6, - 6, - 6, - 6, - 6, - 6, - 2, - 2, - ], + -np.inf, + 0.334232, + 0.92857149, + 1.0, + 1.0, + 1.02631497, + 1.097745, + 1.51482505, + np.inf, + ] + ), + mx.nd.array( [ - 9, - 9, - 9, - 9, - 9, - 9, - 9, - 9, - 9, - 9, - ], + -0.18867899, + 0.85714298, + 1.0, + 1.0, + 1.0, + 1.05262995, + 1.14286005, + 1.88679004, + ] + ), + mx.nd.array( [ - 5, - 5, - 5, - 5, - 5, - 5, - 5, - 5, - 5, - 5, - ], - ]), + [ + 1, + 2, + 2, + 2, + 3, + 7, + 7, + 7, + 8, + 8, + ], + [ + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + ], + [ + 1, + 1, + 1, + 1, + 1, + 5, + 5, + 5, + 5, + 5, + ], + [ + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + 2, + 2, + ], + [ + 9, + 9, + 9, + 9, + 9, + 9, + 9, + 9, + 9, + 9, + ], + [ + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + ], + ] + ), ), ( GlobalRelativeBinning( @@ -287,116 +309,124 @@ is_quantile=False, quantile_scaling_limit=1.0, ), - mx.nd.array([ - [ - -0.188679, - 0.377358, - 0.566038, - 0.754717, - 0.943396, - 1.13208, - 1.32075, - 1.50943, - 1.69811, - 1.88679, - ], - [1.0] * 10, - [0.857143] * 5 + [1.14286] * 5, - [1.05263] * 8 + [0.789474] * 2, - [1.0] * 10, - ]), - mx.nd.array([ - [-1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], - [3.0] * 10, - [0.0] * 5 + [3.0] * 5, - [2.0] * 8 + [1.5] * 2, - [0.0] * 10, - [1.0] * 10, - ]), - mx.nd.array([ - [1.0] * 10, - [1.0] * 10, - [0.0] * 5 + [1.0] * 5, - [1.0] * 9 + [1.0] * 1, - [0.0] * 10, - [1.0] * 10, - ]), - mx.nd.array([-np.inf, -8.0, -4.0, 0.0, 4.0, 8.0, np.inf]), - mx.nd.array([-10.0, -6.0, -2.0, 2.0, 6.0, 10.0]), - mx.nd.array([ - [ - 3, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - ], - [ - 4, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - ], + mx.nd.array( [ - 4, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - ], + [ + -0.188679, + 0.377358, + 0.566038, + 0.754717, + 0.943396, + 1.13208, + 1.32075, + 1.50943, + 1.69811, + 1.88679, + ], + [1.0] * 10, + [0.857143] * 5 + [1.14286] * 5, + [1.05263] * 8 + [0.789474] * 2, + [1.0] * 10, + ] + ), + mx.nd.array( [ - 4, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - ], + [-1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], + [3.0] * 10, + [0.0] * 5 + [3.0] * 5, + [2.0] * 8 + [1.5] * 2, + [0.0] * 10, + [1.0] * 10, + ] + ), + mx.nd.array( [ - 7, - 7, - 7, - 7, - 7, - 7, - 7, - 7, - 7, - 7, - ], + [1.0] * 10, + [1.0] * 10, + [0.0] * 5 + [1.0] * 5, + [1.0] * 9 + [1.0] * 1, + [0.0] * 10, + [1.0] * 10, + ] + ), + mx.nd.array([-np.inf, -8.0, -4.0, 0.0, 4.0, 8.0, np.inf]), + mx.nd.array([-10.0, -6.0, -2.0, 2.0, 6.0, 10.0]), + mx.nd.array( [ - 4, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - ], - ]), + [ + 3, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + ], + [ + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + ], + [ + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + ], + [ + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + ], + [ + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + 7, + ], + [ + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + ], + ] + ), ), ] diff --git a/test/mx/representation/test_hyb.py b/test/mx/representation/test_hyb.py index 223262cc44..466311b9da 100644 --- a/test/mx/representation/test_hyb.py +++ b/test/mx/representation/test_hyb.py @@ -41,171 +41,179 @@ ), ] ), - mx.nd.array([ - [-1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], - [3.0] * 10, - [0.0] * 5 + [3.0] * 5, - [2.0] * 8 + [1.5] * 2, - [0.0] * 10, - [1.0] * 10, - ]), - mx.nd.array([ - [1.0] * 10, - [1.0] * 10, - [0.0] * 5 + [1.0] * 5, - [1.0] * 9 + [1.0] * 1, - [0.0] * 10, - [1.0] * 10, - ]), + mx.nd.array( + [ + [-1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], + [3.0] * 10, + [0.0] * 5 + [3.0] * 5, + [2.0] * 8 + [1.5] * 2, + [0.0] * 10, + [1.0] * 10, + ] + ), + mx.nd.array( + [ + [1.0] * 10, + [1.0] * 10, + [0.0] * 5 + [1.0] * 5, + [1.0] * 9 + [1.0] * 1, + [0.0] * 10, + [1.0] * 10, + ] + ), [ - mx.nd.array([ - [ - 1.0, - 2.0, - 2.0, - 3.0, - 3.0, - 4.0, - 4.0, - 4.0, - 5.0, - 5.0, - ], - [ - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - ], - [ - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - ], - [ - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - ], - [ - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - ], - [ - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - ], - ]), - mx.nd.array([ - [ - 4.0, - 5.0, - 6.0, - 6.0, - 6.0, - 7.0, - 7.0, - 7.0, - 8.0, - 8.0, - ], - [ - 6.0, - 6.0, - 6.0, - 6.0, - 6.0, - 6.0, - 6.0, - 6.0, - 6.0, - 6.0, - ], - [ - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 6.0, - 6.0, - 6.0, - 6.0, - 6.0, - ], - [ - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - ], + mx.nd.array( [ - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - ], + [ + 1.0, + 2.0, + 2.0, + 3.0, + 3.0, + 4.0, + 4.0, + 4.0, + 5.0, + 5.0, + ], + [ + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + ], + [ + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + ], + [ + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + ], + [ + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + ], + [ + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + ], + ] + ), + mx.nd.array( [ - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - 5.0, - ], - ]), + [ + 4.0, + 5.0, + 6.0, + 6.0, + 6.0, + 7.0, + 7.0, + 7.0, + 8.0, + 8.0, + ], + [ + 6.0, + 6.0, + 6.0, + 6.0, + 6.0, + 6.0, + 6.0, + 6.0, + 6.0, + 6.0, + ], + [ + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 6.0, + 6.0, + 6.0, + 6.0, + 6.0, + ], + [ + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + ], + [ + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + ], + [ + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + 5.0, + ], + ] + ), ], ), ( @@ -225,105 +233,113 @@ ), ] ), - mx.nd.array([ - [-1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], - [3.0] * 10, - [0.0] * 5 + [3.0] * 5, - [2.0] * 8 + [1.5] * 2, - [0.0] * 10, - [1.0] * 10, - ]), - mx.nd.array([ - [1.0] * 10, - [1.0] * 10, - [0.0] * 5 + [1.0] * 5, - [1.0] * 9 + [1.0] * 1, - [0.0] * 10, - [1.0] * 10, - ]), + mx.nd.array( + [ + [-1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], + [3.0] * 10, + [0.0] * 5 + [3.0] * 5, + [2.0] * 8 + [1.5] * 2, + [0.0] * 10, + [1.0] * 10, + ] + ), + mx.nd.array( + [ + [1.0] * 10, + [1.0] * 10, + [0.0] * 5 + [1.0] * 5, + [1.0] * 9 + [1.0] * 1, + [0.0] * 10, + [1.0] * 10, + ] + ), [ - mx.nd.array([ - [ - 1.0, - 2.0, - 2.0, - 3.0, - 3.0, - 4.0, - 4.0, - 4.0, - 5.0, - 5.0, - ], - [ - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - ], - [ - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - ], + mx.nd.array( [ - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - ], + [ + 1.0, + 2.0, + 2.0, + 3.0, + 3.0, + 4.0, + 4.0, + 4.0, + 5.0, + 5.0, + ], + [ + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + ], + [ + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + ], + [ + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + ], + [ + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + ], + [ + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + 2.0, + ], + ] + ), + mx.nd.array( [ - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - ], - [ - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - 2.0, - ], - ]), - mx.nd.array([ - [1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0, 5.0, 6.0], - [6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 6.0, 6.0, 6.0, 6.0, 6.0], - [6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 1.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0], - ]), + [1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0, 5.0, 6.0], + [6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 6.0, 6.0, 6.0, 6.0, 6.0], + [6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0], + ] + ), ], ), ( @@ -337,31 +353,37 @@ ), ] ), - mx.nd.array([ - [-1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], - [3.0] * 10, - [0.0] * 5 + [3.0] * 5, - [2.0] * 8 + [1.5] * 2, - [0.0] * 10, - [1.0] * 10, - ]), - mx.nd.array([ - [1.0] * 10, - [1.0] * 10, - [0.0] * 5 + [1.0] * 5, - [1.0] * 9 + [1.0] * 1, - [0.0] * 10, - [1.0] * 10, - ]), + mx.nd.array( + [ + [-1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], + [3.0] * 10, + [0.0] * 5 + [3.0] * 5, + [2.0] * 8 + [1.5] * 2, + [0.0] * 10, + [1.0] * 10, + ] + ), + mx.nd.array( + [ + [1.0] * 10, + [1.0] * 10, + [0.0] * 5 + [1.0] * 5, + [1.0] * 9 + [1.0] * 1, + [0.0] * 10, + [1.0] * 10, + ] + ), [ - mx.nd.array([ - [1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0, 5.0, 6.0], - [6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 6.0, 6.0, 6.0, 6.0, 6.0], - [6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 1.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0], - ]), + mx.nd.array( + [ + [1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0, 5.0, 6.0], + [6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 6.0, 6.0, 6.0, 6.0, 6.0], + [6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0], + ] + ), ], ), ] diff --git a/test/mx/representation/test_lab.py b/test/mx/representation/test_lab.py index 76cdc3e25b..98f83fa763 100644 --- a/test/mx/representation/test_lab.py +++ b/test/mx/representation/test_lab.py @@ -20,269 +20,289 @@ la_binning_cases = [ ( LocalAbsoluteBinning(num_bins=6, is_quantile=True), - mx.nd.array([ - [-1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], - [3.0] * 10, - [0.0] * 5 + [3.0] * 5, - [2.0] * 8 + [1.5] * 2, - [0.0] * 10, - [1.0] * 10, - ]), - mx.nd.array([ - [1.0] * 10, - [1.0] * 10, - [0.0] * 5 + [1.0] * 5, - [1.0] * 9 + [1.0] * 1, - [0.0] * 10, - [1.0] * 10, - ]), - mx.nd.array([ + mx.nd.array( [ - -np.inf, - 0.9, - 3.7, - 5.5, - 7.3, - 9.1, - np.inf, - ], + [-1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], + [3.0] * 10, + [0.0] * 5 + [3.0] * 5, + [2.0] * 8 + [1.5] * 2, + [0.0] * 10, + [1.0] * 10, + ] + ), + mx.nd.array( [ - -np.inf, - 3.0, - 3.0, - 3.0, - 3.0, - 3.0, - np.inf, - ], + [1.0] * 10, + [1.0] * 10, + [0.0] * 5 + [1.0] * 5, + [1.0] * 9 + [1.0] * 1, + [0.0] * 10, + [1.0] * 10, + ] + ), + mx.nd.array( [ - -np.inf, - 3.0, - 3.0, - 3.0, - 3.0, - 3.0, - np.inf, - ], + [ + -np.inf, + 0.9, + 3.7, + 5.5, + 7.3, + 9.1, + np.inf, + ], + [ + -np.inf, + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + np.inf, + ], + [ + -np.inf, + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + np.inf, + ], + [ + -np.inf, + 1.7, + 1.95, + 2.0, + 2.0, + 2.0, + np.inf, + ], + [ + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + ], + [ + -np.inf, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + np.inf, + ], + ] + ), + mx.nd.array( [ - -np.inf, - 1.7, - 1.95, - 2.0, - 2.0, - 2.0, - np.inf, - ], + [ + -1.0, + 2.8, + 4.6, + 6.4, + 8.2, + 10.0, + ], + [ + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + ], + [ + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + ], + [ + 1.5, + 1.9, + 2.0, + 2.0, + 2.0, + 2.0, + ], + [ + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + ], + [ + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + ], + ] + ), + mx.nd.array( [ - -1.0, - -1.0, - -1.0, - -1.0, - -1.0, - -1.0, - -1.0, - ], - [ - -np.inf, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - np.inf, - ], - ]), - mx.nd.array([ - [ - -1.0, - 2.8, - 4.6, - 6.4, - 8.2, - 10.0, - ], - [ - 3.0, - 3.0, - 3.0, - 3.0, - 3.0, - 3.0, - ], - [ - 3.0, - 3.0, - 3.0, - 3.0, - 3.0, - 3.0, - ], - [ - 1.5, - 1.9, - 2.0, - 2.0, - 2.0, - 2.0, - ], - [ - -1.0, - -1.0, - -1.0, - -1.0, - -1.0, - -1.0, - ], - [ - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - ], - ]), - mx.nd.array([ - [1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0, 5.0, 6.0], - [6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 6.0, 6.0, 6.0, 6.0, 6.0], - [6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 1.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0], - ]), + [1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0, 5.0, 6.0], + [6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 6.0, 6.0, 6.0, 6.0, 6.0], + [6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0], + ] + ), ), ( LocalAbsoluteBinning(num_bins=6, is_quantile=False), - mx.nd.array([ - [-1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], - [3.0] * 10, - [0.0] * 5 + [3.0] * 5, - [2.0] * 8 + [1.5] * 2, - [0.0] * 10, - [1.0] * 10, - ]), - mx.nd.array([ - [1.0] * 10, - [1.0] * 10, - [0.0] * 5 + [1.0] * 5, - [1.0] * 9 + [1.0] * 1, - [0.0] * 10, - [1.0] * 10, - ]), - mx.nd.array([ - [ - -np.inf, - 0.1, - 2.3, - 4.5, - 6.7, - 8.9, - np.inf, - ], - [ - -np.inf, - 3.0, - 3.0, - 3.0, - 3.0, - 3.0, - np.inf, - ], - [ - -np.inf, - 3.0, - 3.0, - 3.0, - 3.0, - 3.0, - np.inf, - ], - [ - -np.inf, - 1.55, - 1.65, - 1.75, - 1.85, - 1.95, - np.inf, - ], - [ - -1.0, - -1.0, - -1.0, - -1.0, - -1.0, - -1.0, - -1.0, - ], - [ - -np.inf, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - np.inf, - ], - ]), - mx.nd.array([ - [ - -1.0, - 1.2, - 3.4, - 5.6, - 7.8, - 10.0, - ], + mx.nd.array( [ - 3.0, - 3.0, - 3.0, - 3.0, - 3.0, - 3.0, - ], + [-1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], + [3.0] * 10, + [0.0] * 5 + [3.0] * 5, + [2.0] * 8 + [1.5] * 2, + [0.0] * 10, + [1.0] * 10, + ] + ), + mx.nd.array( [ - 3.0, - 3.0, - 3.0, - 3.0, - 3.0, - 3.0, - ], + [1.0] * 10, + [1.0] * 10, + [0.0] * 5 + [1.0] * 5, + [1.0] * 9 + [1.0] * 1, + [0.0] * 10, + [1.0] * 10, + ] + ), + mx.nd.array( [ - 1.5, - 1.6, - 1.7, - 1.8, - 1.9, - 2.0, - ], + [ + -np.inf, + 0.1, + 2.3, + 4.5, + 6.7, + 8.9, + np.inf, + ], + [ + -np.inf, + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + np.inf, + ], + [ + -np.inf, + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + np.inf, + ], + [ + -np.inf, + 1.55, + 1.65, + 1.75, + 1.85, + 1.95, + np.inf, + ], + [ + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + ], + [ + -np.inf, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + np.inf, + ], + ] + ), + mx.nd.array( [ - -1.0, - -1.0, - -1.0, - -1.0, - -1.0, - -1.0, - ], + [ + -1.0, + 1.2, + 3.4, + 5.6, + 7.8, + 10.0, + ], + [ + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + ], + [ + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + 3.0, + ], + [ + 1.5, + 1.6, + 1.7, + 1.8, + 1.9, + 2.0, + ], + [ + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + ], + [ + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + ], + ] + ), + mx.nd.array( [ - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - 1.0, - ], - ]), - mx.nd.array([ - [1.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0, 5.0, 6.0, 6.0], - [6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 6.0, 6.0, 6.0, 6.0, 6.0], - [6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 1.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0], - ]), + [1.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0, 5.0, 6.0, 6.0], + [6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 6.0, 6.0, 6.0, 6.0, 6.0], + [6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0], + ] + ), ), ] diff --git a/test/mx/representation/test_mean.py b/test/mx/representation/test_mean.py index a114eec4df..677232c873 100644 --- a/test/mx/representation/test_mean.py +++ b/test/mx/representation/test_mean.py @@ -20,36 +20,44 @@ mean_cases = [ ( MeanScaling(), - mx.nd.array([ - [1.0] * 50, - [0.0] * 25 + [3.0] * 25, - [2.0] * 49 + [1.5] * 1, - [0.0] * 50, - [1.0] * 50, - ]), - mx.nd.array([ - [1.0] * 50, - [0.0] * 25 + [1.0] * 25, - [0.0] * 49 + [1.0] * 1, - [1.0] * 50, - [0.0] * 50, - ]), + mx.nd.array( + [ + [1.0] * 50, + [0.0] * 25 + [3.0] * 25, + [2.0] * 49 + [1.5] * 1, + [0.0] * 50, + [1.0] * 50, + ] + ), + mx.nd.array( + [ + [1.0] * 50, + [0.0] * 25 + [1.0] * 25, + [0.0] * 49 + [1.0] * 1, + [1.0] * 50, + [0.0] * 50, + ] + ), mx.nd.array([1.0, 3.0, 1.5, 1.00396824, 1.00396824]), ), ( MeanScaling(), - mx.nd.array([ - [120.0] * 25 + [150.0] * 25, - [0.0] * 10 + [3.0] * 20 + [61.0] * 20, - [0.0] * 50, - [2e-2] * 10 + [0.0] * 30 + [3e-2] * 10, - ]), - mx.nd.array([ - [1.0] * 25 + [1.0] * 25, - [0.0] * 10 + [1.0] * 20 + [1.0] * 20, - [0.0] * 50, - [1.0] * 10 + [0.0] * 30 + [1.0] * 10, - ]), + mx.nd.array( + [ + [120.0] * 25 + [150.0] * 25, + [0.0] * 10 + [3.0] * 20 + [61.0] * 20, + [0.0] * 50, + [2e-2] * 10 + [0.0] * 30 + [3e-2] * 10, + ] + ), + mx.nd.array( + [ + [1.0] * 25 + [1.0] * 25, + [0.0] * 10 + [1.0] * 20 + [1.0] * 20, + [0.0] * 50, + [1.0] * 10 + [0.0] * 30 + [1.0] * 10, + ] + ), mx.nd.array([135.0, 32.0, 73.00454712, 2.5e-2]), ), ( diff --git a/test/mx/representation/test_rep.py b/test/mx/representation/test_rep.py index 9807c8cfdd..622d337c1f 100644 --- a/test/mx/representation/test_rep.py +++ b/test/mx/representation/test_rep.py @@ -18,34 +18,42 @@ cases = [ ( - mx.nd.array([ - [1.0] * 50, - [0.0] * 25 + [3.0] * 25, - [2.0] * 49 + [1.5] * 1, - [0.0] * 50, - [1.0] * 50, - ]), - mx.nd.array([ - [1.0] * 50, - [0.0] * 25 + [1.0] * 25, - [0.0] * 49 + [1.0] * 1, - [1.0] * 50, - [0.0] * 50, - ]), + mx.nd.array( + [ + [1.0] * 50, + [0.0] * 25 + [3.0] * 25, + [2.0] * 49 + [1.5] * 1, + [0.0] * 50, + [1.0] * 50, + ] + ), + mx.nd.array( + [ + [1.0] * 50, + [0.0] * 25 + [1.0] * 25, + [0.0] * 49 + [1.0] * 1, + [1.0] * 50, + [0.0] * 50, + ] + ), ), ( - mx.nd.array([ - [120.0] * 25 + [150.0] * 25, - [0.0] * 10 + [3.0] * 20 + [61.0] * 20, - [0.0] * 50, - [2e-2] * 10 + [0.0] * 30 + [3e-2] * 10, - ]), - mx.nd.array([ - [1.0] * 25 + [1.0] * 25, - [0.0] * 10 + [1.0] * 20 + [1.0] * 20, - [0.0] * 50, - [1.0] * 10 + [0.0] * 30 + [1.0] * 10, - ]), + mx.nd.array( + [ + [120.0] * 25 + [150.0] * 25, + [0.0] * 10 + [3.0] * 20 + [61.0] * 20, + [0.0] * 50, + [2e-2] * 10 + [0.0] * 30 + [3e-2] * 10, + ] + ), + mx.nd.array( + [ + [1.0] * 25 + [1.0] * 25, + [0.0] * 10 + [1.0] * 20 + [1.0] * 20, + [0.0] * 50, + [1.0] * 10 + [0.0] * 30 + [1.0] * 10, + ] + ), ), ( mx.nd.random.normal(shape=(5, 30)), diff --git a/test/mx/test_transform_equals.py b/test/mx/test_transform_equals.py index 0cd7dfefa6..23356cd966 100644 --- a/test/mx/test_transform_equals.py +++ b/test/mx/test_transform_equals.py @@ -131,54 +131,58 @@ def test_continuous_time_splitter(): def test_chain(): - chain = transform.Chain([ - transform.AddTimeFeatures( - start_field=FieldName.START, - target_field=FieldName.TARGET, - output_field="time_feat", - time_features=[ - time_feature.day_of_week, - time_feature.day_of_month, - time_feature.month_of_year, - ], - pred_length=10, - ), - transform.AddAgeFeature( - target_field=FieldName.TARGET, - output_field="age", - pred_length=10, - log_scale=True, - ), - transform.AddObservedValuesIndicator( - target_field=FieldName.TARGET, output_field="observed_values" - ), - ]) + chain = transform.Chain( + [ + transform.AddTimeFeatures( + start_field=FieldName.START, + target_field=FieldName.TARGET, + output_field="time_feat", + time_features=[ + time_feature.day_of_week, + time_feature.day_of_month, + time_feature.month_of_year, + ], + pred_length=10, + ), + transform.AddAgeFeature( + target_field=FieldName.TARGET, + output_field="age", + pred_length=10, + log_scale=True, + ), + transform.AddObservedValuesIndicator( + target_field=FieldName.TARGET, output_field="observed_values" + ), + ] + ) assert equals(chain, clone(chain)) assert not equals(chain, clone(chain, {"transformations": []})) - another_chain = transform.Chain([ - transform.AddTimeFeatures( - start_field=FieldName.START, - target_field=FieldName.TARGET, - output_field="time_feat", - time_features=[ - time_feature.day_of_week, - time_feature.day_of_month, - time_feature.month_of_year, - ], - pred_length=10, - ), - transform.AddAgeFeature( - target_field=FieldName.TARGET, - output_field="age", - pred_length=10, - log_scale=False, - ), - transform.AddObservedValuesIndicator( - target_field=FieldName.TARGET, output_field="observed_values" - ), - ]) + another_chain = transform.Chain( + [ + transform.AddTimeFeatures( + start_field=FieldName.START, + target_field=FieldName.TARGET, + output_field="time_feat", + time_features=[ + time_feature.day_of_week, + time_feature.day_of_month, + time_feature.month_of_year, + ], + pred_length=10, + ), + transform.AddAgeFeature( + target_field=FieldName.TARGET, + output_field="age", + pred_length=10, + log_scale=False, + ), + transform.AddObservedValuesIndicator( + target_field=FieldName.TARGET, output_field="observed_values" + ), + ] + ) assert not equals(chain, another_chain) diff --git a/test/nursery/anomaly_detection/supervised_metrics/test_precision_recall.py b/test/nursery/anomaly_detection/supervised_metrics/test_precision_recall.py index 109c34f349..ce87ecf94e 100644 --- a/test/nursery/anomaly_detection/supervised_metrics/test_precision_recall.py +++ b/test/nursery/anomaly_detection/supervised_metrics/test_precision_recall.py @@ -373,24 +373,26 @@ def test_buffered_precision_recall(test_case): @pytest.fixture def labels_and_scores() -> List[Tuple[np.array, np.array]]: label1 = np.array([0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0]) - scores1 = np.array([ - 0.2, - 0.3, - 0.5, - 0.7, - 4, - 2.5, - 0.3, - 0.2, - 0.7, - 0.3, - 0.2, - 4, - 3, - 8, - 0.2, - 0.1, - ]) + scores1 = np.array( + [ + 0.2, + 0.3, + 0.5, + 0.7, + 4, + 2.5, + 0.3, + 0.2, + 0.7, + 0.3, + 0.2, + 4, + 3, + 8, + 0.2, + 0.1, + ] + ) label2 = np.array([0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0]) scores2 = np.array( diff --git a/test/nursery/autogluon_tabular/test_autogluon_tabular.py b/test/nursery/autogluon_tabular/test_autogluon_tabular.py index 73fb52925c..f1cc10cae0 100644 --- a/test/nursery/autogluon_tabular/test_autogluon_tabular.py +++ b/test/nursery/autogluon_tabular/test_autogluon_tabular.py @@ -134,41 +134,43 @@ def test_get_features_dataframe( @pytest.mark.parametrize( "dataset, freq, prediction_length", - [( - ListDataset( - [ - { - "start": "1750-01-07 00:00:00", - "target": np.array( - [ - 1089.2, - 1078.91, - 1099.88, - 35790.55, - 34096.95, - 34906.95, - ], - ), - }, - { - "start": "1750-01-07 00:00:00", - "target": np.array( - [ - 1099.2, - 1098.91, - 1069.88, - 35990.55, - 34076.95, - 34766.95, - ], - ), - }, - ], - freq="W-TUE", - ), - "W-TUE", - 2, - )], + [ + ( + ListDataset( + [ + { + "start": "1750-01-07 00:00:00", + "target": np.array( + [ + 1089.2, + 1078.91, + 1099.88, + 35790.55, + 34096.95, + 34906.95, + ], + ), + }, + { + "start": "1750-01-07 00:00:00", + "target": np.array( + [ + 1099.2, + 1098.91, + 1069.88, + 35990.55, + 34076.95, + 34766.95, + ], + ), + }, + ], + freq="W-TUE", + ), + "W-TUE", + 2, + ) + ], ) @pytest.mark.parametrize("lag_indices", [[], [1, 2, 5]]) @pytest.mark.parametrize("disable_auto_regression", [False, True]) diff --git a/test/shell/test_nested_params.py b/test/shell/test_nested_params.py index f00feb3072..797c2374df 100644 --- a/test/shell/test_nested_params.py +++ b/test/shell/test_nested_params.py @@ -15,11 +15,13 @@ def test_nested_params(): - data = decode_nested_parameters({ - "$env.num_workers": "4", - "$evaluation.quantiles": [0.1, 0.5, 0.9], - "prediction_length": 14, - }) + data = decode_nested_parameters( + { + "$env.num_workers": "4", + "$evaluation.quantiles": [0.1, 0.5, 0.9], + "prediction_length": 14, + } + ) hps = data.pop("") assert hps["prediction_length"] == 14 diff --git a/test/time_feature/test_agg_lags.py b/test/time_feature/test_agg_lags.py index 19b69fa2a9..dd3b2f2d9b 100644 --- a/test/time_feature/test_agg_lags.py +++ b/test/time_feature/test_agg_lags.py @@ -22,36 +22,44 @@ expected_lags_rolling = { "prediction_length_2": { - "train": np.array([ - [0, 0, 0, 1, 1, 1.5, 2, 2.5, 3, 3.5], - [0, 0, 0, 0, 0, 1, 1, 1.5, 2, 2.5], - [0, 0, 0, 0, 0, 0, 0, 1, 1, 1.5], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 1], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - ]), - "test": np.array([ - [0, 0, 0, 1, 1, 1.5, 2, 2.5, 3, 3.5, 4.5, 5.5], - [0, 0, 0, 0, 0, 1, 1, 1.5, 2, 2.5, 3, 3.5], - [0, 0, 0, 0, 0, 0, 0, 1, 1, 1.5, 2, 2.5], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1.5], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - ]), + "train": np.array( + [ + [0, 0, 0, 1, 1, 1.5, 2, 2.5, 3, 3.5], + [0, 0, 0, 0, 0, 1, 1, 1.5, 2, 2.5], + [0, 0, 0, 0, 0, 0, 0, 1, 1, 1.5], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ), + "test": np.array( + [ + [0, 0, 0, 1, 1, 1.5, 2, 2.5, 3, 3.5, 4.5, 5.5], + [0, 0, 0, 0, 0, 1, 1, 1.5, 2, 2.5, 3, 3.5], + [0, 0, 0, 0, 0, 0, 0, 1, 1, 1.5, 2, 2.5], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1.5], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ), }, "prediction_length_1": { - "train": np.array([ - [0, 0, 0, 1, 1, 1.5, 2, 2.5, 3, 3.5], - [0, 0, 0, 0, 0, 1, 1, 1.5, 2, 2.5], - [0, 0, 0, 0, 0, 0, 0, 1, 1, 1.5], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 1], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - ]), - "test": np.array([ - [0, 0, 0, 1, 1, 1.5, 2, 2.5, 3, 3.5, 4.5], - [0, 0, 0, 0, 0, 1, 1, 1.5, 2, 2.5, 3], - [0, 0, 0, 0, 0, 0, 0, 1, 1, 1.5, 2], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1], - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - ]), + "train": np.array( + [ + [0, 0, 0, 1, 1, 1.5, 2, 2.5, 3, 3.5], + [0, 0, 0, 0, 0, 1, 1, 1.5, 2, 2.5], + [0, 0, 0, 0, 0, 0, 0, 1, 1, 1.5], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ), + "test": np.array( + [ + [0, 0, 0, 1, 1, 1.5, 2, 2.5, 3, 3.5, 4.5], + [0, 0, 0, 0, 0, 1, 1, 1.5, 2, 2.5, 3], + [0, 0, 0, 0, 0, 0, 0, 1, 1, 1.5, 2], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ), }, } diff --git a/test/time_feature/test_holiday.py b/test/time_feature/test_holiday.py index 03b444c64f..2d55894885 100644 --- a/test/time_feature/test_holiday.py +++ b/test/time_feature/test_holiday.py @@ -128,11 +128,13 @@ def test_holidays(holiday): test_cases = [ ( pd.date_range(start="2016-12-24", end="2016-12-31", freq="D"), - np.array([ - [1, 0, 0, 0, 0, 0, 0, 0], - [0, 1, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 1], - ]), + np.array( + [ + [1, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1], + ] + ), [CHRISTMAS_EVE, CHRISTMAS_DAY, NEW_YEARS_EVE], ), ( @@ -161,89 +163,91 @@ def test_special_date_feature_set_hourly(): start="2016-12-24", end="2016-12-25", freq="H" ) - reference_features = np.array([ - [ - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 1, - 0, - ], - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 1, - ], + reference_features = np.array( [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - ], - ]) + [ + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 0, + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1, + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ], + ] + ) sfs = SpecialDateFeatureSet([CHRISTMAS_EVE, CHRISTMAS_DAY, NEW_YEARS_EVE]) computed_features = sfs(date_indices) diff --git a/test/torch/distribution/test_discrete_distribution.py b/test/torch/distribution/test_discrete_distribution.py index f1223a71f3..0e89d26275 100644 --- a/test/torch/distribution/test_discrete_distribution.py +++ b/test/torch/distribution/test_discrete_distribution.py @@ -62,24 +62,30 @@ def test_rps(values, probs, obs, rps): [ # Duplicate values occur (i) only in the middle (ii) at the extremes ( - torch.tensor([ - [-1.0, 0.0, 0.0, 0.0, 2.0, 2.0, 5.0], - [-1.0, -1.0, 0.0, 0.0, 2.0, 5.0, 5.0], - ]), - torch.tensor([ - [0.1, 0.12, 0.03, 0.15, 0.05, 0.15, 0.4], - [0.15, 0.05, 0.13, 0.12, 0.05, 0.27, 0.23], - ]), + torch.tensor( + [ + [-1.0, 0.0, 0.0, 0.0, 2.0, 2.0, 5.0], + [-1.0, -1.0, 0.0, 0.0, 2.0, 5.0, 5.0], + ] + ), + torch.tensor( + [ + [0.1, 0.12, 0.03, 0.15, 0.05, 0.15, 0.4], + [0.15, 0.05, 0.13, 0.12, 0.05, 0.27, 0.23], + ] + ), ) ], ) @pytest.mark.parametrize( "probs_adjusted", [ - torch.tensor([ - [0.1, 0.0, 0.0, 0.3, 0.0, 0.2, 0.4], - [0.0, 0.2, 0.0, 0.25, 0.05, 0.0, 0.5], - ]), + torch.tensor( + [ + [0.1, 0.0, 0.0, 0.3, 0.0, 0.2, 0.4], + [0.0, 0.2, 0.0, 0.25, 0.05, 0.0, 0.5], + ] + ), ], ) def test_probs_duplicate_values(values, probs, probs_adjusted): diff --git a/test/torch/distribution/test_torch_piecewise_linear.py b/test/torch/distribution/test_torch_piecewise_linear.py index ae8f804362..731596bd24 100644 --- a/test/torch/distribution/test_torch_piecewise_linear.py +++ b/test/torch/distribution/test_torch_piecewise_linear.py @@ -97,12 +97,12 @@ def test_values( expected_target_crps: List[float], ): target = torch.Tensor(target).reshape(shape=(len(target),)) - expected_target_cdf = np.array(expected_target_cdf).reshape(( - len(expected_target_cdf), - )) - expected_target_crps = np.array(expected_target_crps).reshape(( - len(expected_target_crps), - )) + expected_target_cdf = np.array(expected_target_cdf).reshape( + (len(expected_target_cdf),) + ) + expected_target_crps = np.array(expected_target_crps).reshape( + (len(expected_target_crps),) + ) assert all(np.isclose(distr.cdf(target).numpy(), expected_target_cdf)) assert all(np.isclose(distr.crps(target).numpy(), expected_target_crps)) diff --git a/test/torch/model/test_mqf2_modules.py b/test/torch/model/test_mqf2_modules.py index c5ed31ab3d..a8c2685259 100644 --- a/test/torch/model/test_mqf2_modules.py +++ b/test/torch/model/test_mqf2_modules.py @@ -41,16 +41,18 @@ def test_mqf2_modules( distr_output = MQF2DistributionOutput(prediction_length) - lightning_module = MQF2MultiHorizonLightningModule({ - "freq": "1H", - "context_length": context_length, - "prediction_length": prediction_length, - "num_feat_dynamic_real": num_feat_dynamic_real, - "num_feat_static_real": num_feat_static_real, - "num_feat_static_cat": num_feat_static_cat, - "cardinality": cardinality, - "distr_output": distr_output, - }) + lightning_module = MQF2MultiHorizonLightningModule( + { + "freq": "1H", + "context_length": context_length, + "prediction_length": prediction_length, + "num_feat_dynamic_real": num_feat_dynamic_real, + "num_feat_static_real": num_feat_static_real, + "num_feat_static_cat": num_feat_static_cat, + "cardinality": cardinality, + "distr_output": distr_output, + } + ) model = lightning_module.model feat_static_cat = torch.zeros( diff --git a/test/torch/model/test_tft.py b/test/torch/model/test_tft.py index ddcfadb7dd..5ac1ff029e 100644 --- a/test/torch/model/test_tft.py +++ b/test/torch/model/test_tft.py @@ -40,17 +40,19 @@ def test_tft_modules( prediction_length = 6 context_length = 12 - lightning_module = TemporalFusionTransformerLightningModule({ - "context_length": context_length, - "prediction_length": prediction_length, - "d_past_feat_dynamic_real": d_past_feat_dynamic_real, - "c_past_feat_dynamic_cat": c_past_feat_dynamic_cat, - "d_feat_dynamic_real": d_feat_dynamic_real, - "c_feat_dynamic_cat": c_feat_dynamic_cat, - "d_feat_static_real": d_feat_static_real, - "c_feat_static_cat": c_feat_static_cat, - "distr_output": QuantileOutput(quantiles), - }) + lightning_module = TemporalFusionTransformerLightningModule( + { + "context_length": context_length, + "prediction_length": prediction_length, + "d_past_feat_dynamic_real": d_past_feat_dynamic_real, + "c_past_feat_dynamic_cat": c_past_feat_dynamic_cat, + "d_feat_dynamic_real": d_feat_dynamic_real, + "c_feat_dynamic_cat": c_feat_dynamic_cat, + "d_feat_static_real": d_feat_static_real, + "c_feat_static_cat": c_feat_static_cat, + "distr_output": QuantileOutput(quantiles), + } + ) model = lightning_module.model feat_static_cat = torch.zeros( diff --git a/test/torch/test_scaler.py b/test/torch/test_scaler.py index 02fe33e132..48ffb11330 100644 --- a/test/torch/test_scaler.py +++ b/test/torch/test_scaler.py @@ -19,72 +19,88 @@ test_cases = [ ( scaler.MeanScaler(), - torch.Tensor([ - [1.0] * 50, - [0.0] * 25 + [3.0] * 25, - [2.0] * 49 + [1.5] * 1, - [0.0] * 50, - [1.0] * 50, - ]), - torch.Tensor([ - [1.0] * 50, - [0.0] * 25 + [1.0] * 25, - [0.0] * 49 + [1.0] * 1, - [1.0] * 50, - [0.0] * 50, - ]), + torch.Tensor( + [ + [1.0] * 50, + [0.0] * 25 + [3.0] * 25, + [2.0] * 49 + [1.5] * 1, + [0.0] * 50, + [1.0] * 50, + ] + ), + torch.Tensor( + [ + [1.0] * 50, + [0.0] * 25 + [1.0] * 25, + [0.0] * 49 + [1.0] * 1, + [1.0] * 50, + [0.0] * 50, + ] + ), torch.Tensor([1.0, 3.0, 1.5, 1e-10, 1.00396824]), ), ( scaler.MeanScaler(default_scale=0.5), - torch.Tensor([ - [1.0] * 50, - [0.0] * 25 + [3.0] * 25, - [2.0] * 49 + [1.5] * 1, - [0.0] * 50, - [1.0] * 50, - ]), - torch.Tensor([ - [0.0] * 50, - [0.0] * 25 + [1.0] * 25, - [0.0] * 49 + [1.0] * 1, - [1.0] * 50, - [0.0] * 50, - ]), + torch.Tensor( + [ + [1.0] * 50, + [0.0] * 25 + [3.0] * 25, + [2.0] * 49 + [1.5] * 1, + [0.0] * 50, + [1.0] * 50, + ] + ), + torch.Tensor( + [ + [0.0] * 50, + [0.0] * 25 + [1.0] * 25, + [0.0] * 49 + [1.0] * 1, + [1.0] * 50, + [0.0] * 50, + ] + ), torch.Tensor([0.5, 3.0, 1.5, 1e-10, 0.5]), ), ( scaler.MeanScaler(keepdim=True), - torch.Tensor([ - [1.0] * 50, - [0.0] * 25 + [3.0] * 25, - [2.0] * 49 + [1.5] * 1, - [0.0] * 50, - [1.0] * 50, - ]), - torch.Tensor([ - [1.0] * 50, - [0.0] * 25 + [1.0] * 25, - [0.0] * 49 + [1.0] * 1, - [1.0] * 50, - [0.0] * 50, - ]), + torch.Tensor( + [ + [1.0] * 50, + [0.0] * 25 + [3.0] * 25, + [2.0] * 49 + [1.5] * 1, + [0.0] * 50, + [1.0] * 50, + ] + ), + torch.Tensor( + [ + [1.0] * 50, + [0.0] * 25 + [1.0] * 25, + [0.0] * 49 + [1.0] * 1, + [1.0] * 50, + [0.0] * 50, + ] + ), torch.Tensor([1.0, 3.0, 1.5, 1e-10, 1.00396824]).unsqueeze(1), ), ( scaler.MeanScaler(), - torch.Tensor([ - [120.0] * 25 + [150.0] * 25, - [0.0] * 10 + [3.0] * 20 + [61.0] * 20, - [0.0] * 50, - [2e-2] * 10 + [0.0] * 30 + [3e-2] * 10, - ]), - torch.Tensor([ - [1.0] * 25 + [1.0] * 25, - [0.0] * 10 + [1.0] * 20 + [1.0] * 20, - [0.0] * 50, - [1.0] * 10 + [0.0] * 30 + [1.0] * 10, - ]), + torch.Tensor( + [ + [120.0] * 25 + [150.0] * 25, + [0.0] * 10 + [3.0] * 20 + [61.0] * 20, + [0.0] * 50, + [2e-2] * 10 + [0.0] * 30 + [3e-2] * 10, + ] + ), + torch.Tensor( + [ + [1.0] * 25 + [1.0] * 25, + [0.0] * 10 + [1.0] * 20 + [1.0] * 20, + [0.0] * 50, + [1.0] * 10 + [0.0] * 30 + [1.0] * 10, + ] + ), torch.Tensor([135.0, 32.0, 73.00454712, 2.5e-2]), ), ( @@ -118,22 +134,28 @@ def test_scaler(s, target, observed, expected_scale): @pytest.mark.parametrize( "target, observed", - [( - torch.Tensor([ - [1.0] * 50, - [0.0] * 25 + [3.0] * 25, - [2.0] * 49 + [1.5] * 1, - [0.0] * 50, - [1.0] * 50, - ]), - torch.Tensor([ - [1.0] * 50, - [0.0] * 25 + [1.0] * 25, - [0.0] * 49 + [1.0] * 1, - [1.0] * 50, - [0.0] * 50, - ]), - )], + [ + ( + torch.Tensor( + [ + [1.0] * 50, + [0.0] * 25 + [3.0] * 25, + [2.0] * 49 + [1.5] * 1, + [0.0] * 50, + [1.0] * 50, + ] + ), + torch.Tensor( + [ + [1.0] * 50, + [0.0] * 25 + [1.0] * 25, + [0.0] * 49 + [1.0] * 1, + [1.0] * 50, + [0.0] * 50, + ] + ), + ) + ], ) def test_nopscaler(target, observed): s = scaler.NOPScaler() diff --git a/test/transform/test_transform.py b/test/transform/test_transform.py index d783e094f3..fbcdfb9d16 100644 --- a/test/transform/test_transform.py +++ b/test/transform/test_transform.py @@ -371,45 +371,47 @@ def test_Transformation(): pred_length = 10 - t = transform.Chain([ - transform.AddTimeFeatures( - start_field=FieldName.START, - target_field=FieldName.TARGET, - output_field="time_feat", - time_features=[ - time_feature.day_of_week, - time_feature.day_of_month, - time_feature.month_of_year, - ], - pred_length=pred_length, - ), - transform.AddAgeFeature( - target_field=FieldName.TARGET, - output_field="age", - pred_length=pred_length, - log_scale=True, - ), - transform.AddObservedValuesIndicator( - target_field=FieldName.TARGET, output_field="observed_values" - ), - transform.VstackFeatures( - output_field="dynamic_feat", - input_fields=["age", "time_feat"], - drop_inputs=True, - ), - transform.InstanceSplitter( - target_field=FieldName.TARGET, - is_pad_field=FieldName.IS_PAD, - start_field=FieldName.START, - forecast_start_field=FieldName.FORECAST_START, - instance_sampler=transform.ExpectedNumInstanceSampler( - num_instances=4 + t = transform.Chain( + [ + transform.AddTimeFeatures( + start_field=FieldName.START, + target_field=FieldName.TARGET, + output_field="time_feat", + time_features=[ + time_feature.day_of_week, + time_feature.day_of_month, + time_feature.month_of_year, + ], + pred_length=pred_length, ), - past_length=train_length, - future_length=pred_length, - time_series_fields=["dynamic_feat", "observed_values"], - ), - ]) + transform.AddAgeFeature( + target_field=FieldName.TARGET, + output_field="age", + pred_length=pred_length, + log_scale=True, + ), + transform.AddObservedValuesIndicator( + target_field=FieldName.TARGET, output_field="observed_values" + ), + transform.VstackFeatures( + output_field="dynamic_feat", + input_fields=["age", "time_feat"], + drop_inputs=True, + ), + transform.InstanceSplitter( + target_field=FieldName.TARGET, + is_pad_field=FieldName.IS_PAD, + start_field=FieldName.START, + forecast_start_field=FieldName.FORECAST_START, + instance_sampler=transform.ExpectedNumInstanceSampler( + num_instances=4 + ), + past_length=train_length, + future_length=pred_length, + time_series_fields=["dynamic_feat", "observed_values"], + ), + ] + ) assert_serializable(t) @@ -438,52 +440,54 @@ def test_multi_dim_transformation(is_train): first_dim[-1] = np.nan second_dim[0] = np.nan - t = transform.Chain([ - transform.AddTimeFeatures( - start_field=FieldName.START, - target_field=FieldName.TARGET, - output_field="time_feat", - time_features=[ - time_feature.day_of_week, - time_feature.day_of_month, - time_feature.month_of_year, - ], - pred_length=pred_length, - ), - transform.AddAgeFeature( - target_field=FieldName.TARGET, - output_field="age", - pred_length=pred_length, - log_scale=True, - ), - transform.AddObservedValuesIndicator( - target_field=FieldName.TARGET, - output_field="observed_values", - imputation_method=None, - ), - transform.VstackFeatures( - output_field="dynamic_feat", - input_fields=["age", "time_feat"], - drop_inputs=True, - ), - transform.InstanceSplitter( - target_field=FieldName.TARGET, - is_pad_field=FieldName.IS_PAD, - start_field=FieldName.START, - forecast_start_field=FieldName.FORECAST_START, - instance_sampler=( - transform.ExpectedNumInstanceSampler( - num_instances=4, min_future=pred_length - ) - if is_train - else transform.TestSplitSampler() + t = transform.Chain( + [ + transform.AddTimeFeatures( + start_field=FieldName.START, + target_field=FieldName.TARGET, + output_field="time_feat", + time_features=[ + time_feature.day_of_week, + time_feature.day_of_month, + time_feature.month_of_year, + ], + pred_length=pred_length, ), - past_length=train_length, - future_length=pred_length, - time_series_fields=["dynamic_feat", "observed_values"], - output_NTC=False, - ), - ]) + transform.AddAgeFeature( + target_field=FieldName.TARGET, + output_field="age", + pred_length=pred_length, + log_scale=True, + ), + transform.AddObservedValuesIndicator( + target_field=FieldName.TARGET, + output_field="observed_values", + imputation_method=None, + ), + transform.VstackFeatures( + output_field="dynamic_feat", + input_fields=["age", "time_feat"], + drop_inputs=True, + ), + transform.InstanceSplitter( + target_field=FieldName.TARGET, + is_pad_field=FieldName.IS_PAD, + start_field=FieldName.START, + forecast_start_field=FieldName.FORECAST_START, + instance_sampler=( + transform.ExpectedNumInstanceSampler( + num_instances=4, min_future=pred_length + ) + if is_train + else transform.TestSplitSampler() + ), + past_length=train_length, + future_length=pred_length, + time_series_fields=["dynamic_feat", "observed_values"], + output_NTC=False, + ), + ] + ) assert_serializable(t) @@ -633,14 +637,16 @@ def make_test_data(): ds = gluonts.dataset.common.ListDataset( # Mimic output from InstanceSplitter - [{ - "start": "2012-01-01", - "target": multi_dim_target, - "past_target": multi_dim_target, - "future_target": multi_dim_target, - "past_is_pad": past_is_pad, - f"past_{FieldName.OBSERVED_VALUES}": past_observed_target, - }], + [ + { + "start": "2012-01-01", + "target": multi_dim_target, + "past_target": multi_dim_target, + "future_target": multi_dim_target, + "past_is_pad": past_is_pad, + f"past_{FieldName.OBSERVED_VALUES}": past_observed_target, + } + ], freq="1D", one_dim_target=False, ) @@ -735,11 +741,13 @@ def point_process_dataset(): marks = np.array([0, 1, 2, 0, 1, 2, 2, 2]) return ListDataset( - [{ - "target": np.c_[ia_times, marks].T, - "start": pd.Timestamp("2011-01-01 00:00:00"), - "end": pd.Timestamp("2011-01-01 03:00:00"), - }], + [ + { + "target": np.c_[ia_times, marks].T, + "start": pd.Timestamp("2011-01-01 00:00:00"), + "end": pd.Timestamp("2011-01-01 03:00:00"), + } + ], freq="H", one_dim_target=False, use_timestamp=True, @@ -877,14 +885,16 @@ def test_ctsplitter_train_samples_correct_times(point_process_dataset): iter_de = splitter(point_process_dataset, is_train=True) - assert all([ - ( - pd.Timestamp("2011-01-01 01:15:00") - <= d["forecast_start"] - <= pd.Timestamp("2011-01-01 01:45:00") - ) - for d in iter_de - ]) + assert all( + [ + ( + pd.Timestamp("2011-01-01 01:15:00") + <= d["forecast_start"] + <= pd.Timestamp("2011-01-01 01:45:00") + ) + for d in iter_de + ] + ) def test_ctsplitter_train_short_intervals(point_process_dataset):