Skip to content

Commit

Permalink
Add TFT model (#962)
Browse files Browse the repository at this point in the history
* remove logging and add new tft impl

* fix multiple bugs in tft

* fix import conflict

* avoid nan in cold-start

* avoid nan in cold-start

* add default dummy static feature

* add license headers

* remove redundant chars in assertions and arg list

* remove customized trainer by adding dummy features

* update __init__.py

* fix import conflict

Co-authored-by: Xiaoyong Jin <jxiaoyon@amazon.com>
Co-authored-by: Danielle Robinson <dcmaddix@gmail.com>
  • Loading branch information
3 people authored Sep 30, 2020
1 parent 60f7d39 commit 4af09c7
Show file tree
Hide file tree
Showing 5 changed files with 1,452 additions and 0 deletions.
23 changes: 23 additions & 0 deletions src/gluonts/model/tft/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from ._estimator import TemporalFusionTransformerEstimator

# Public API of this package.
__all__ = ["TemporalFusionTransformerEstimator"]

# fix Sphinx issues, see https://bit.ly/2K2eptM
# NOTE(review): `__all__` holds the exported *names* (str), not the exported
# objects, so the `hasattr` check below is applied to strings. Confirm whether
# the intent was to patch `__module__` on the objects themselves.
for exported in __all__:
    if not hasattr(exported, "__module__"):
        continue
    setattr(exported, "__module__", __name__)
341 changes: 341 additions & 0 deletions src/gluonts/model/tft/_estimator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,341 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from typing import Optional, Dict, List
from itertools import chain

import numpy as np
import mxnet as mx
from mxnet.gluon import HybridBlock

from gluonts.core.component import DType, validated
from gluonts.dataset.field_names import FieldName
from gluonts.model.estimator import GluonEstimator
from gluonts.model.predictor import RepresentableBlockPredictor
from gluonts.support.util import copy_parameters
from gluonts.time_feature import (
TimeFeature,
time_features_from_frequency_str,
)
from gluonts.model.forecast_generator import QuantileForecastGenerator
from gluonts.mx.trainer import Trainer
from gluonts.transform import (
AddObservedValuesIndicator,
AddTimeFeatures,
AsNumpyArray,
Chain,
ExpectedNumInstanceSampler,
Transformation,
VstackFeatures,
SetField,
)

from ._transform import TFTInstanceSplitter
from ._network import (
TemporalFusionTransformerTrainingNetwork,
TemporalFusionTransformerPredictionNetwork,
)


class TemporalFusionTransformerEstimator(GluonEstimator):
    """
    Estimator for the Temporal Fusion Transformer (TFT) model.

    It assembles the input transformation chain, builds the training network
    and turns a trained network into a quantile-forecast predictor.

    Parameters
    ----------
    freq
        Frequency string of the time series. Used to derive default time
        features and passed on to the predictor.
    context_length
        Number of past time steps given to the network (``past_length`` of
        the instance splitter). A falsy value falls back to
        ``prediction_length``.
    prediction_length
        Number of future time steps to predict. Must be a positive integer;
        the signature default of ``None`` is rejected.
    trainer
        Trainer used by the base estimator; its ``batch_size`` and ``ctx``
        are reused for the predictor.
    hidden_dim
        Forwarded to the networks as ``d_hidden``.
    variable_dim
        Forwarded to the networks as ``d_var``. Defaults to ``hidden_dim``.
    num_heads
        Forwarded to the networks as ``n_head``.
    num_outputs
        Forwarded to the networks as ``n_output``.
    num_instance_per_series
        ``num_instances`` of the ``ExpectedNumInstanceSampler`` used for
        training.
    dropout_rate
        Forwarded to the networks as ``dropout``. Must be >= 0.
    time_features
        Time features to add. When empty, they are derived from ``freq``
        via ``time_features_from_frequency_str``.
    static_cardinalities
        Maps the field name of each static categorical feature to the value
        forwarded in ``c_feat_static_cat``.
    dynamic_cardinalities
        Maps the field name of each dynamic categorical feature to the value
        forwarded in ``c_feat_dynamic_cat`` (or ``c_past_feat_dynamic_cat``
        for names listed in ``past_dynamic_features``).
    static_feature_dims
        Maps the field name of each static real feature to the value
        forwarded in ``d_feat_static_real``.
    dynamic_feature_dims
        Maps the field name of each dynamic real feature to the value
        forwarded in ``d_feat_dynamic_real`` (or ``d_past_feat_dynamic_real``
        for names listed in ``past_dynamic_features``).
    past_dynamic_features
        Names of dynamic features that are handled as past-only features.
        Every name must be a key of ``dynamic_cardinalities`` or
        ``dynamic_feature_dims``.

    Raises
    ------
    AssertionError
        If ``prediction_length``, ``context_length`` or ``dropout_rate`` is
        out of range.
    ValueError
        If a name in ``past_dynamic_features`` is found in neither dynamic
        feature dict.
    """

    @validated()
    def __init__(
        self,
        freq: str,
        context_length: int,
        prediction_length: Optional[int] = None,
        trainer: Trainer = Trainer(),
        hidden_dim: int = 32,
        variable_dim: Optional[int] = None,
        num_heads: int = 4,
        num_outputs: int = 3,
        num_instance_per_series: int = 100,
        dropout_rate: float = 0.1,
        time_features: List[TimeFeature] = [],
        static_cardinalities: Dict[str, int] = {},
        dynamic_cardinalities: Dict[str, int] = {},
        static_feature_dims: Dict[str, int] = {},
        dynamic_feature_dims: Dict[str, int] = {},
        past_dynamic_features: List[str] = [],
    ) -> None:
        super().__init__(trainer=trainer)

        # `prediction_length` defaults to None in the signature; reject it
        # here with the intended message instead of an opaque TypeError from
        # comparing None with an int.
        assert (
            prediction_length is not None and prediction_length > 0
        ), "The value of `prediction_length` should be > 0"
        assert (
            context_length is None or context_length > 0
        ), "The value of `context_length` should be > 0"
        assert dropout_rate >= 0, "The value of `dropout_rate` should be >= 0"

        self.freq = freq
        self.prediction_length = prediction_length
        self.context_length = context_length or prediction_length
        self.dropout_rate = dropout_rate
        self.hidden_dim = hidden_dim
        self.variable_dim = variable_dim or hidden_dim
        self.num_heads = num_heads
        self.num_outputs = num_outputs
        self.num_instance_per_series = num_instance_per_series

        if time_features:
            self.time_features = list(time_features)
        else:
            self.time_features = time_features_from_frequency_str(self.freq)

        # Work on private copies: the dynamic dicts are partitioned in place
        # below, and the arguments may be the shared mutable defaults or
        # objects the caller still uses.
        self.static_cardinalities = dict(static_cardinalities)
        self.dynamic_cardinalities = dict(dynamic_cardinalities)
        self.static_feature_dims = dict(static_feature_dims)
        self.dynamic_feature_dims = dict(dynamic_feature_dims)
        self.past_dynamic_features = list(past_dynamic_features)

        # Move every past-only feature out of the "known in the future"
        # dicts into the corresponding past-only dict.
        self.past_dynamic_cardinalities = {}
        self.past_dynamic_feature_dims = {}
        for name in self.past_dynamic_features:
            if name in self.dynamic_cardinalities:
                self.past_dynamic_cardinalities[
                    name
                ] = self.dynamic_cardinalities.pop(name)
            elif name in self.dynamic_feature_dims:
                self.past_dynamic_feature_dims[
                    name
                ] = self.dynamic_feature_dims.pop(name)
            else:
                raise ValueError(
                    f"Feature name {name} is not provided in feature dicts"
                )

    @staticmethod
    def _vstack_or_dummy(
        input_fields: List[str],
        output_field: str,
        dummy_value: list,
        dummy_ndim: int,
    ) -> List[Transformation]:
        """
        Stack ``input_fields`` into ``output_field``, or emit a dummy feature.

        When no input fields are given, ``output_field`` is set to
        ``dummy_value`` and converted to a numpy array of ``dummy_ndim``
        dimensions, so the field is always present.
        """
        if input_fields:
            return [
                VstackFeatures(
                    output_field=output_field, input_fields=input_fields,
                )
            ]
        return [
            SetField(output_field=output_field, value=dummy_value),
            AsNumpyArray(field=output_field, expected_ndim=dummy_ndim),
        ]

    def create_transformation(self) -> Transformation:
        """
        Build the chain of transformations applied to every data entry.

        The chain converts the target and the configured feature fields to
        numpy arrays, adds the observed-values indicator and time features,
        stacks features per category (inserting dummy features for empty
        categories) and finally splits instances with
        ``TFTInstanceSplitter``.
        """
        transforms = [AsNumpyArray(field=FieldName.TARGET, expected_ndim=1)]
        transforms.extend(
            AsNumpyArray(field=name, expected_ndim=1)
            for name in chain(
                self.static_cardinalities.keys(),
                self.static_feature_dims.keys(),
                self.dynamic_cardinalities.keys(),
            )
        )
        transforms.extend(
            AsNumpyArray(field=name, expected_ndim=2)
            for name in self.dynamic_feature_dims.keys()
        )
        transforms.extend(
            [
                AddObservedValuesIndicator(
                    target_field=FieldName.TARGET,
                    output_field=FieldName.OBSERVED_VALUES,
                ),
                AddTimeFeatures(
                    start_field=FieldName.START,
                    target_field=FieldName.TARGET,
                    output_field=FieldName.FEAT_TIME,
                    time_features=self.time_features,
                    pred_length=self.prediction_length,
                ),
            ]
        )

        # Fields handed to the splitter as time series. Dummy features are
        # deliberately not registered here.
        ts_fields = []
        past_ts_fields = []

        transforms.extend(
            self._vstack_or_dummy(
                list(self.static_cardinalities.keys()),
                FieldName.FEAT_STATIC_CAT,
                dummy_value=[0.0],
                dummy_ndim=1,
            )
        )
        transforms.extend(
            self._vstack_or_dummy(
                list(self.static_feature_dims.keys()),
                FieldName.FEAT_STATIC_REAL,
                dummy_value=[0.0],
                dummy_ndim=1,
            )
        )

        transforms.extend(
            self._vstack_or_dummy(
                list(self.dynamic_cardinalities.keys()),
                FieldName.FEAT_DYNAMIC_CAT,
                dummy_value=[0.0],
                dummy_ndim=1,
            )
        )
        if self.dynamic_cardinalities:
            ts_fields.append(FieldName.FEAT_DYNAMIC_CAT)

        # Time features are always present, so the dynamic real features are
        # stacked unconditionally and never need a dummy.
        input_fields = [FieldName.FEAT_TIME]
        if self.dynamic_feature_dims:
            input_fields += list(self.dynamic_feature_dims.keys())
        transforms.append(
            VstackFeatures(
                input_fields=input_fields,
                output_field=FieldName.FEAT_DYNAMIC_REAL,
            )
        )
        ts_fields.append(FieldName.FEAT_DYNAMIC_REAL)

        past_cat_field = FieldName.PAST_FEAT_DYNAMIC + "_cat"
        transforms.extend(
            self._vstack_or_dummy(
                list(self.past_dynamic_cardinalities.keys()),
                past_cat_field,
                dummy_value=[0.0],
                dummy_ndim=1,
            )
        )
        if self.past_dynamic_cardinalities:
            past_ts_fields.append(past_cat_field)

        transforms.extend(
            self._vstack_or_dummy(
                list(self.past_dynamic_feature_dims.keys()),
                FieldName.PAST_FEAT_DYNAMIC_REAL,
                dummy_value=[[0.0]],
                dummy_ndim=2,
            )
        )
        if self.past_dynamic_feature_dims:
            past_ts_fields.append(FieldName.PAST_FEAT_DYNAMIC_REAL)

        transforms.append(
            TFTInstanceSplitter(
                train_sampler=ExpectedNumInstanceSampler(
                    num_instances=self.num_instance_per_series,
                ),
                past_length=self.context_length,
                future_length=self.prediction_length,
                time_series_fields=ts_fields,
                past_time_series_fields=past_ts_fields,
            )
        )

        return Chain(transforms)

    def _network_kwargs(self) -> dict:
        """
        Return the constructor arguments shared by both TFT networks.

        Keeping them in one place guarantees that the training and the
        prediction network are built with the same configuration.
        """
        return dict(
            context_length=self.context_length,
            prediction_length=self.prediction_length,
            d_var=self.variable_dim,
            d_hidden=self.hidden_dim,
            n_head=self.num_heads,
            n_output=self.num_outputs,
            d_past_feat_dynamic_real=list(
                self.past_dynamic_feature_dims.values()
            ),
            c_past_feat_dynamic_cat=list(
                self.past_dynamic_cardinalities.values()
            ),
            # One entry of size 1 per time feature, followed by the
            # user-provided dynamic real features.
            d_feat_dynamic_real=[1] * len(self.time_features)
            + list(self.dynamic_feature_dims.values()),
            c_feat_dynamic_cat=list(self.dynamic_cardinalities.values()),
            d_feat_static_real=list(self.static_feature_dims.values()),
            c_feat_static_cat=list(self.static_cardinalities.values()),
            dropout=self.dropout_rate,
        )

    def create_training_network(
        self,
    ) -> TemporalFusionTransformerTrainingNetwork:
        """Create the network that is optimized during training."""
        return TemporalFusionTransformerTrainingNetwork(
            **self._network_kwargs()
        )

    def create_predictor(
        self, transformation: Transformation, trained_network: HybridBlock
    ) -> RepresentableBlockPredictor:
        """
        Wrap a trained network into a predictor producing quantile forecasts.

        Parameters
        ----------
        transformation
            Input transformation used by the predictor.
        trained_network
            Trained network whose parameters are copied into a freshly
            created prediction network.

        Returns
        -------
        RepresentableBlockPredictor
            Predictor using the trainer's batch size and context, and a
            ``QuantileForecastGenerator`` built from the prediction
            network's ``quantiles``.
        """
        prediction_network = TemporalFusionTransformerPredictionNetwork(
            **self._network_kwargs()
        )
        copy_parameters(trained_network, prediction_network)
        return RepresentableBlockPredictor(
            input_transform=transformation,
            prediction_net=prediction_network,
            batch_size=self.trainer.batch_size,
            freq=self.freq,
            prediction_length=self.prediction_length,
            ctx=self.trainer.ctx,
            forecast_generator=QuantileForecastGenerator(
                quantiles=[str(q) for q in prediction_network.quantiles],
            ),
        )
Loading

0 comments on commit 4af09c7

Please sign in to comment.