-
Notifications
You must be signed in to change notification settings - Fork 750
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implemented moving average #926
Conversation
Could we add some basic tests for this? Something like sanity checks, shape checks, trivial input/output and maybe edge cases (NaNs, empty sequence, length 1 sequence, etc.)? Maybe under |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's actually too hard to understand what the code does, since in my view too many things are factored out into separate functions.
|
||
ds = ListDataset([{"target": target, "start": start}], freq=freq) | ||
item = next(iter(ds)) | ||
predictions = mp.predict_item(item).samples[0] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
predictions = mp.predict_item(item).samples[0] | |
predictions = mp.predict_item(item).mean |
I like this a bit better.
src/gluonts/model/trivial/mean.py
Outdated
# assert ( | ||
# context_length >= 1 | ||
# ), "The value of `context_length` should be >= 1" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should still have a check when context_length
is an integer.
src/gluonts/model/trivial/mean.py
Outdated
start_date = forecast_start(item) | ||
return SampleForecast( | ||
samples=np.array([target[-self.prediction_length :]]), | ||
start_date=start_date, | ||
freq=self.freq, | ||
item_id=item.get(FieldName.ITEM_ID), | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
start_date = forecast_start(item) | |
return SampleForecast( | |
samples=np.array([target[-self.prediction_length :]]), | |
start_date=start_date, | |
freq=self.freq, | |
item_id=item.get(FieldName.ITEM_ID), | |
) | |
return SampleForecast( | |
samples=np.array([target[-self.prediction_length :]]), | |
start_date=forecast_start(item), | |
freq=self.freq, | |
item_id=item.get(FieldName.ITEM_ID), | |
) |
* Add moving average * Add moving average * Add moving average - added tests * context_length now optional::forecast_start is now used::updated tests * refinments added * added assert + forecast_start updates * forecast_start updates * forecast_start updates + assert Co-authored-by: Pedro Eduardo Mercado Lopez <pedroml@amazon.com>
Issue #, if available:
Description of changes:
-- This is a deterministic task (no distribution of observations is assumed)
-- For the case where prediction_length = 1, the output is the standard moving average of the previous context_length observations
-- For the case where prediction_length > 1, the input target is appended with previous moving averages, further computing moving averages until prediction_length of them are computed.
By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.