-
Notifications
You must be signed in to change notification settings - Fork 750
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add TFT model #962
Add TFT model #962
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR!
I've just quickly skimmed over the PR and have some minor comments.
src/gluonts/model/tft/_estimator.py
Outdated
self.static_cardinalities = static_cardinalities or {} | ||
self.static_feature_dims = static_feature_dims or {} | ||
self.dynamic_cardinalities = dynamic_cardinalities or {} | ||
self.dynamic_feature_dims = dynamic_feature_dims or {} | ||
self.past_dynamic_features = past_dynamic_features or [] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think pydantic is able to handle this.
E.g.:
time_features:List[TimeFeature] = [],
Normally you wouldn't do this in Python, but pydantic will return you a copy of the empty list, so this should work.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed in abdc711
src/gluonts/model/tft/_network.py
Outdated
embedding_dims | ||
), "Length of `embedding_dims` and `embedding_dims` should match" | ||
assert all( | ||
[c > 0 for c in feature_dims] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The []
are not needed here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed in abdc711
src/gluonts/model/tft/_engine.py
Outdated
OutputTransform = Callable[[DataEntry, np.ndarray], np.ndarray] | ||
|
||
|
||
class Trainer(BaseTrainer): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why can't use use the Trainer
defined in gluonts.mx.trainer
? Or ask differently, what would you need to be able to use it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1. In the ConvModel PR, he had also overrode it and said that the base trainer do not accept default argument values in hybrid_forward and overloaded the trainer to give default None to unused feature types. But I think we can use the Trainer
directly using boolean or whether to use the features or not in the estimator or directly extend the Trainer
class.
|
||
self.feature_dims = feature_dims | ||
self.dtype = dtype | ||
self.__num_features = len(feature_dims) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's uncommon to use two leading _
in Python, is this needed here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Well, I just followed the implementation in gluonts.mx.block.feature.FeatureEmbedder, in which it also has such a property with name mangling
https://github.com/Gandor26/gluon-ts/blob/san/src/gluonts/mx/block/feature.py#L65
src/gluonts/model/tft/_network.py
Outdated
assert ( | ||
len(feature_dims) > 0 | ||
), "Length of `cardinalities` list must be greater than zero" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
assert ( | |
len(feature_dims) > 0 | |
), "Length of `cardinalities` list must be greater than zero" | |
assert (feature_dims), "Length of `cardinalities` list must be greater than zero" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed in abdc711
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey @Gandor26, it seems that there is substantial code duplication in this PR which could be avoided. See my inline comments, which are probably related to @jaheba's and @dcmaddix's previous comments regarding the Trainer class.
Is the rewriting needed because of optional features/None inputs to the network? Then I think this is a separate issue.
src/gluonts/model/tft/_engine.py
Outdated
args = inspect.signature( | ||
net.hybrid_forward | ||
).parameters | ||
inputs = [] | ||
for n, (name, arg) in enumerate(args.items()): | ||
if n == 0: | ||
if name == "F": | ||
continue | ||
else: | ||
raise RuntimeError( | ||
f"Expected first argument of HybridBlock to be `F`, " | ||
f"but found `{name}`" | ||
) | ||
if name in data_entry: | ||
inputs.append(data_entry[name]) | ||
elif not (arg.default is inspect._empty): | ||
inputs.append(arg.default) | ||
else: | ||
raise RuntimeError( | ||
f"The value of argument `{name}` of HybridBlock is not provided, " | ||
f"and no default value is set." | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some observations:
- The code here seems like a refined version of this-plus-this allowing for defaults;
- This is duplicated in the forecast generator defined below;
- Both in the trainer and forecast generator now have an obsolete
input_names
constructor argument.
If you really want to push for this mechanism for TFT, my suggestion here is to propose this new mechanism in a separate PR where you showcase it with a minimal example model. This way you can focus on getting this into the default Trainer
class and ForecastGenerator
types in a backward compatible way (so that other models keep working fine with the default Trainer
and ForecastGenerator
classes).
Then you should then be able to proceed with the current PR, but avoiding a lot of code duplication.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This of course applies to #961 as well, so the code duplication savings double :-)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great, Xiaoyong! I think we can merge it in. Thanks for updating to using the default Trainer.
@dcmaddix Thanks Danielle. Just let you know I ran one of the benchmark tests and got similar metrics, so it's safe to merge it. |
Thanks @Gandor26 for pushing this through! If I remember correctly you had a mechanism to allow for optional arguments to |
* remove logging and add new tft impl * fix multiple bugs in tft * fix import conflict * avoid nan in cold-start * avoid nan in cold-start * add default dummy static feature * add license headers * remove redundant chars in assertions and arg list * remove customized trainer by adding dummy features * update __init__.py * fix import conflict Co-authored-by: Xiaoyong Jin <jxiaoyon@amazon.com> Co-authored-by: Danielle Robinson <dcmaddix@gmail.com>
Issue #, if available:
Description of changes:
Add TFT model [1], benchmark
[1] Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting
By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.