Skip to content

Commit

Permalink
Breaking: change num_output with quantiles in TFT (mxnet) (#2879)
Browse files Browse the repository at this point in the history
  • Loading branch information
baniasbaabe authored Aug 30, 2023
1 parent 3ce5c00 commit b19238f
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 10 deletions.
55 changes: 51 additions & 4 deletions src/gluonts/mx/model/tft/_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,53 @@ def _default_feat_args(dims_or_cardinalities: List[int]):


class TemporalFusionTransformerEstimator(GluonEstimator):
"""
Parameters
----------
freq
Frequency of the data to train on and predict.
prediction_length
Length of the prediction horizon.
context_length
Number of previous time series values provided as input to the encoder.
(default: None).
trainer
Trainer object to be used (default: Trainer())
hidden_dim
Size of the LSTM & transformer hidden states.
variable_dim
Size of the feature embeddings.
num_heads
Number of attention heads in self-attention layer in the decoder.
quantiles
List of quantiles that the model will learn to predict.
Defaults to [0.1, 0.5, 0.9]
num_instances_per_series
Number of samples to generate for each time series when training.
dropout_rate
Dropout regularization parameter (default: 0.1).
time_features
List of time features, from :py:mod:`gluonts.time_feature`, to use as
dynamic real features in addition to the provided data (default: None,
in which case these are automatically determined based on freq).
static_cardinalities
Cardinalities of the categorical static features.
dynamic_cardinalities
Cardinalities of the categorical dynamic features that are known in the future.
static_feature_dims
Sizes of the real-valued static features.
dynamic_dims
Sizes of the real-valued dynamic features that are known in the future.
past_dynamic_features
List of names of the real-valued dynamic features that are only known in the past.
train_sampler
Controls the sampling of windows during training.
validation_sampler
Controls the sampling of windows during validation.
batch_size
The size of the batches to be used training and prediction.
"""

@validated()
def __init__(
self,
Expand All @@ -76,7 +123,7 @@ def __init__(
hidden_dim: int = 32,
variable_dim: Optional[int] = None,
num_heads: int = 4,
num_outputs: int = 3,
quantiles: List[float] = [0.1, 0.5, 0.9],
num_instance_per_series: int = 100,
dropout_rate: float = 0.1,
time_features: List[TimeFeature] = [],
Expand Down Expand Up @@ -104,7 +151,7 @@ def __init__(
self.hidden_dim = hidden_dim
self.variable_dim = variable_dim or hidden_dim
self.num_heads = num_heads
self.num_outputs = num_outputs
self.quantiles = quantiles
self.num_instance_per_series = num_instance_per_series

if not time_features:
Expand Down Expand Up @@ -372,7 +419,7 @@ def create_training_network(
d_var=self.variable_dim,
d_hidden=self.hidden_dim,
n_head=self.num_heads,
n_output=self.num_outputs,
quantiles=self.quantiles,
d_past_feat_dynamic_real=_default_feat_args(
list(self.past_dynamic_feature_dims.values())
),
Expand Down Expand Up @@ -406,7 +453,7 @@ def create_predictor(
d_var=self.variable_dim,
d_hidden=self.hidden_dim,
n_head=self.num_heads,
n_output=self.num_outputs,
quantiles=self.quantiles,
d_past_feat_dynamic_real=_default_feat_args(
list(self.past_dynamic_feature_dims.values())
),
Expand Down
8 changes: 2 additions & 6 deletions src/gluonts/mx/model/tft/_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def __init__(
d_var: int,
d_hidden: int,
n_head: int,
n_output: int,
quantiles: List[float],
d_past_feat_dynamic_real: List[int],
c_past_feat_dynamic_cat: List[int],
d_feat_dynamic_real: List[int],
Expand All @@ -167,11 +167,7 @@ def __init__(
self.d_var = d_var
self.d_hidden = d_hidden
self.n_head = n_head
self.n_output = n_output
self.quantiles = sum(
([i / 10, 1.0 - i / 10] for i in range(1, (n_output + 1) // 2)),
[0.5],
)
self.quantiles = quantiles
self.normalize_eps = 1e-5

self.d_past_feat_dynamic_real = d_past_feat_dynamic_real
Expand Down

0 comments on commit b19238f

Please sign in to comment.