Add TFT model #962

Merged: 12 commits merged into awslabs:master on Sep 30, 2020
Conversation

@Gandor26 (Contributor) commented on Jul 31, 2020

Issue #, if available:

Description of changes:
Add TFT model [1]; benchmark results:

| Datasets | Context Length | Prediction Length | Epochs / Batches per Epoch | Methods | wP10QL | wP50QL | wP90QL | Running Time (s) |
|---|---|---|---|---|---|---|---|---|
| Electricity | 168 | 24 | 10/1000 | DeepAR | 0.0292 | 0.0687 | 0.0369 | 900 |
| Electricity | 168 | 24 | 10/1000 | Transformer | 0.0317 | 0.0761 | 0.0474 | 500 |
| Electricity | 168 | 24 | 10/1000 | TFT | 0.0677 | 0.1509 | 0.0974 | 1200 |
| Traffic | 168 | 24 | 10/1000 | DeepAR | 0.0618 | 0.1585 | 0.1361 | 1140 |
| Traffic | 168 | 24 | 10/1000 | Transformer | 0.0803 | 0.1656 | 0.1062 | 600 |
| Traffic | 168 | 24 | 10/1000 | TFT | 0.0646 | 0.1531 | 0.1077 | 1500 |
| Parts | 24 | 12 | 5/2000 | DeepAR | 0.2221 | 1.0792 | 1.6332 | 360 |
| Parts | 24 | 12 | 5/2000 | Transformer | 0.2068 | 1.002 | 1.7712 | 480 |
| Parts | 24 | 12 | 5/2000 | TFT | 0.2135 | 1.0312 | 0.9918 | 1060 |
| Wiki-10k | 28 | 7 | 5/2000 | DeepAR | 0.0744 | 0.2185 | 0.2147 | 1300 |
| Wiki-10k | 28 | 7 | 5/2000 | Transformer | NaN | NaN | NaN | 720 |
| Wiki-10k | 28 | 7 | 5/2000 | TFT | 0.0646 | 0.2042 | 0.1539 | 1150 |
| M4-Daily | 28 | 7 | 10/1000 | DeepAR | 0.0174 | 0.0274 | 0.0151 | 400 |
| M4-Daily | 28 | 7 | 10/1000 | Transformer | 0.0455 | 0.0649 | 0.039 | 240 |
| M4-Daily | 28 | 7 | 10/1000 | TFT | 0.0105 | 0.0189 | 0.009 | 1200 |

[1] Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.

@jaheba (Contributor) left a comment:

Thanks for the PR!

I've just quickly skimmed over the PR and have some minor comments.

Comment on lines 94 to 98
self.static_cardinalities = static_cardinalities or {}
self.static_feature_dims = static_feature_dims or {}
self.dynamic_cardinalities = dynamic_cardinalities or {}
self.dynamic_feature_dims = dynamic_feature_dims or {}
self.past_dynamic_features = past_dynamic_features or []

Contributor:

I think pydantic is able to handle this.

E.g.:

time_features: List[TimeFeature] = [],

Normally you wouldn't use a mutable default like this in Python, but pydantic returns a copy of the empty list, so this should work.
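
A minimal sketch of that pattern (the class below is hypothetical, not the TFT estimator from this PR), assuming the constructor is wrapped with GluonTS's pydantic-backed `@validated()` decorator:

```python
from typing import List

from gluonts.core.component import validated
from gluonts.time_feature import TimeFeature


class ToyEstimator:
    # Hypothetical example: with @validated(), the [] default is validated and
    # copied per instance, so instances do not share one mutable list.
    @validated()
    def __init__(self, time_features: List[TimeFeature] = []) -> None:
        self.time_features = time_features
```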

Contributor Author (@Gandor26):

fixed in abdc711

embedding_dims
), "Length of `embedding_dims` and `embedding_dims` should match"
assert all(
[c > 0 for c in feature_dims]

Contributor:

The [] are not needed here.
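
For example (illustrative only; the message string here is made up), `all()` accepts a generator expression directly:

```python
# No intermediate list is built; all() consumes the generator lazily.
assert all(c > 0 for c in feature_dims), "All feature dimensions must be positive"
```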

Contributor Author (@Gandor26):

fixed in abdc711

OutputTransform = Callable[[DataEntry, np.ndarray], np.ndarray]


class Trainer(BaseTrainer):

Contributor:

Why can't you use the Trainer defined in gluonts.mx.trainer? Or, to ask it differently: what would you need in order to be able to use it?

Contributor:

+1. In the ConvModel PR, he had also overridden it, saying that the base trainer does not accept default argument values in hybrid_forward, and overloaded the trainer to pass a default None for unused feature types. But I think we can use the Trainer directly, either by having the estimator pass booleans indicating whether each feature type is used, or by extending the Trainer class; see the sketch below.
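
As a rough sketch (not code from this PR; the class and field names are hypothetical), the dummy-feature route eventually taken here can be a transformation that injects a constant placeholder, so hybrid_forward needs no default arguments and the stock gluonts.mx.trainer.Trainer works unchanged:

```python
import numpy as np

from gluonts.dataset.common import DataEntry
from gluonts.transform import SimpleTransformation


class AddDummyStaticFeature(SimpleTransformation):
    # Hypothetical transformation: guarantee that `feat_static_real` exists in
    # every data entry, so the network always receives the argument and the
    # default Trainer / forecast generator can bind inputs by name.
    def transform(self, data: DataEntry) -> DataEntry:
        data.setdefault("feat_static_real", np.zeros(1, dtype=np.float32))
        return data
```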


self.feature_dims = feature_dims
self.dtype = dtype
self.__num_features = len(feature_dims)

Contributor:

It's uncommon to use two leading underscores in Python; is this needed here?

Contributor Author (@Gandor26):

Well, I just followed the implementation in gluonts.mx.block.feature.FeatureEmbedder, which also has such a property with name mangling:
https://github.com/Gandor26/gluon-ts/blob/san/src/gluonts/mx/block/feature.py#L65
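
As a side note (an illustrative sketch, not taken from this PR or from gluonts): the double underscore triggers Python name mangling, which stores the attribute under a class-qualified name and so prevents accidental shadowing in subclasses:

```python
class FeatureEmbedder:
    def __init__(self, feature_dims):
        # Stored on the instance as _FeatureEmbedder__num_features.
        self.__num_features = len(feature_dims)

    @property
    def num_features(self):
        return self.__num_features


emb = FeatureEmbedder([2, 3, 5])
print(emb.num_features)                    # 3
print(emb._FeatureEmbedder__num_features)  # 3 -- the mangled attribute name
```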

Comment on lines 75 to 77
assert (
    len(feature_dims) > 0
), "Length of `cardinalities` list must be greater than zero"

Contributor:

Suggested change:
-assert (
-    len(feature_dims) > 0
-), "Length of `cardinalities` list must be greater than zero"
+assert (feature_dims), "Length of `cardinalities` list must be greater than zero"

Contributor Author (@Gandor26):

fixed in abdc711

@lostella (Contributor) left a comment:

Hey @Gandor26, it seems that there is substantial code duplication in this PR which could be avoided. See my inline comments, which are probably related to @jaheba's and @dcmaddix's previous comments regarding the Trainer class.

Is the rewriting needed because of optional features/None inputs to the network? If so, I think this is a separate issue.

Comment on lines 128 to 149
args = inspect.signature(
    net.hybrid_forward
).parameters
inputs = []
for n, (name, arg) in enumerate(args.items()):
    if n == 0:
        if name == "F":
            continue
        else:
            raise RuntimeError(
                f"Expected first argument of HybridBlock to be `F`, "
                f"but found `{name}`"
            )
    if name in data_entry:
        inputs.append(data_entry[name])
    elif not (arg.default is inspect._empty):
        inputs.append(arg.default)
    else:
        raise RuntimeError(
            f"The value of argument `{name}` of HybridBlock is not provided, "
            f"and no default value is set."
        )

Contributor:

Some observations:

  • The code here seems like a refined version of this-plus-this allowing for defaults;
  • This is duplicated in the forecast generator defined below;
  • Both the trainer and the forecast generator now have an obsolete input_names constructor argument.

If you really want to push for this mechanism for TFT, my suggestion is to propose it in a separate PR where you showcase it with a minimal example model. That way you can focus on getting it into the default Trainer class and ForecastGenerator types in a backward-compatible way (so that other models keep working fine with the default Trainer and ForecastGenerator classes).

You should then be able to proceed with the current PR, while avoiding a lot of code duplication.

Contributor:

This of course applies to #961 as well, so the code duplication savings double :-)

@dcmaddix (Contributor) left a comment:

Looks great, Xiaoyong! I think we can merge it in. Thanks for updating it to use the default Trainer.

@Gandor26 (Contributor Author):

@dcmaddix Thanks, Danielle. Just to let you know, I ran one of the benchmark tests and got similar metrics, so it's safe to merge.

@dcmaddix dcmaddix merged commit 4af09c7 into awslabs:master Sep 30, 2020
@lostella (Contributor):

Thanks @Gandor26 for pushing this through! If I remember correctly, you had a mechanism to allow optional arguments to hybrid_forward to work with the training loop, which I thought was pretty cool; we may want to revisit that at some point.

kashif pushed a commit to kashif/gluon-ts that referenced this pull request on Oct 10, 2020:
* remove logging and add new tft impl

* fix multiple bugs in tft

* fix import conflict

* avoid nan in cold-start

* avoid nan in cold-start

* add default dummy static feature

* add license headers

* remove redundant chars in assertions and arg list

* remove customized trainer by adding dummy features

* update __init__.py

* fix import conflict

Co-authored-by: Xiaoyong Jin <jxiaoyon@amazon.com>
Co-authored-by: Danielle Robinson <dcmaddix@gmail.com>