-
Notifications
You must be signed in to change notification settings - Fork 755
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Future features #1757
base: dev
Are you sure you want to change the base?
Future features #1757
Conversation
Nice, thanks @mbohlkeschneider! @lovvge Michael has extended DeepAR to also support past dynamic features, similar to MQ-CNN. |
) | ||
self.rnn.add(cell) | ||
self.rnn.cast(dtype=dtype) | ||
self.rnn = self.rnn_block( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe rename this as self.encoder
for better readability? Then self.decoder
is either a reference to it or an entirely different block depending on past only features.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you. Yes, I struggled with the naming a bit. In the canonical
DeepAR, encoder
would sound weird (although I think this would still be fine).
future_time_feat: Tensor, | ||
future_target: Tensor, | ||
future_observed_values: Tensor, | ||
past_past_feat_dynamic_real: Optional[Tensor] = None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
past two times? You mean past_only_feat_dynamic_real
or past_only_time_feat
? Also docstring needs to be updated for this new argument.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes this is weird as well. My goal was to use the build-in feature naming from FieldName
for the hybrid_forward
call: https://github.com/awslabs/gluon-ts/blob/master/src/gluonts/dataset/field_names.py#L30, which would be past_feat_dynamic_real
. This then gets turned to past_past_feat_dynamic_real
in the InstanceSplitter
https://github.com/awslabs/gluon-ts/blob/master/src/gluonts/transform/split.py#L179.
One way to circumvent is to have a different channel in the InstanceSplitter
that does not do the field name modification to that channel (past only time series channel)?
Do you have other suggestions?
@@ -1140,6 +1209,8 @@ def hybrid_forward( | |||
Tensor | |||
Predicted samples | |||
""" | |||
if past_past_feat_dynamic_real is not None: | |||
past_past_feat_dynamic_real.mean() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is this required?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for spotting this. Will remove it, this was to test whether the data gets actually passed.
@@ -127,4 +127,4 @@ def __init__(self, input_fields: List[str]) -> None: | |||
self.input_fields = input_fields | |||
|
|||
def map_transform(self, data: DataEntry, is_train: bool) -> DataEntry: | |||
return {f: data[f] for f in self.input_fields} | |||
return {f: data[f] for f in self.input_fields if f in data.keys()} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe this results in silent errors in other cases (ignoring required fields simply because they are not in the data)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point. I don't think the error would be silent in this case (at least mxnet models will complain for sure!). This change is required to have optional fields in the network forward passes, though. I think the fundamental question that we have to ask:
Do we want to have optional inputs for the networks or not?
Is there a timeline for this feature? |
Issue #, if available:
Addresses issue #1657.
Description of changes:
Modified
DeepAR
to take optionalpast_feat_dynamic_real
features that are only known in training range. This required some changes in theInstanceSplitter
to accommodate that feature as a separate field. Additionally, some changes in theForecastGenerator
andSelectFields
transformation where required to support optional fields.By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.
Please tag this pr with at least one of these labels to make our release process faster: BREAKING, new feature, bug fix, other change, dev setup