-
Notifications
You must be signed in to change notification settings - Fork 750
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix: prevent accumulation of SelectFields
in PyTorchPredictor
#2951
Conversation
@Cameronwood611 thanks for this; it seems like this fix is breaking the behaviour of models, so I'm not sure it's the right way to go. Is there any way you could share the model you're using for #2947, for debugging purposes? Or even a simpler, smaller dummy model that just runs into the issue. What kind of model are we talking about? |
Yes! I updated the issue linked but I'll also post the example here --sorry about that. I noticed I had to raise the for loop above 1000 iterations (around the default recursion depth) for it to break; not sure if that's related to different model params, not loading in the model, etc. I'm not sure of the side effects this change has but maybe we could look into a certain filter or adding SelectFields into a set? Would be willing to help with this effort. import random
import numpy as np
import pandas as pd
from gluonts.torch import PyTorchPredictor
from gluonts.dataset.common import ListDataset, DataEntry
from gluonts.torch import TemporalFusionTransformerEstimator
estimator = TemporalFusionTransformerEstimator(
freq='S',
context_length=10,
prediction_length=5,
trainer_kwargs={'max_epochs': 5}
)
train = ListDataset([{'target': np.array([random.uniform(-1, 1) for _ in range(50)]), 'start': pd.Period('01-01-2023', freq='S')} for _ in range(20)], freq='S')
pred = ListDataset([{'target': np.array([random.uniform(-1, 1) for _ in range(50)]), 'start': pd.Period('01-01-2023', freq='S')} for _ in range(20)], freq='S')
model = estimator.train(train, pred)
x, y, z = np.array([x for x in range(10)]), np.array([y for y in range (10)]), np.array([z for z in range(10)]) #for example
start = pd.Period('01-01-2023', freq='S')
_input = ListDataset([{'target': x, 'start': start},
{'target': y, 'start': start},
{'target': z, 'start': start}], freq='S')
for i in range(5000): #used to be 1000 for it to break..
pred = model.predict(_input)
output = [sig.mean for sig in pred] #this calls on a generator which causes the recursion error |
src/gluonts/torch/model/predictor.py
Outdated
self.input_transform += SelectFields( | ||
self.input_names + self.required_fields, allow_missing=True | ||
) | ||
self.provided_fields = True | ||
inference_data_loader = InferenceDataLoader( | ||
dataset, | ||
transform=self.input_transform, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@Cameronwood611 the problem is with the +=
which updates the predictor state every time .predict
is invoked. Solution should be to either do here
input_transform = self.input_transform + SelectFields(...)
inference_data_loader = InferenceDataLoader(
...,
transform=input_transform,
....,
)
or directly in the constructor
self.input_transform = input_transform + SelectFields(...)
and remove lines 76-78 from here.
@Cameronwood611 thanks for the bug hunting and the fix, if you update it as suggested then it should work fine, and we'll backport it and release it soon |
SelectFields
in PyTorchPredictor
…slabs#2951) * Prevent redundant accumulation of fields * update fix --------- Co-authored-by: Cameronwood611 <cwood611@uab.edu> Co-authored-by: Lorenzo Stella <stellalo@amazon.com>
Thank you! |
* Fix JsonLinesFile slicing. (#2925) * Zebras: Fix index handling of SplitFrame.resize. (#2938) * Docs: fix missing values use-case in `PandasDataset` docs (#2941) * Ignore F403 errors in preludes. (#2948) * Fix: prevent accumulation of `SelectFields` in `PyTorchPredictor` (#2951) * Prevent redundant accumulation of fields * update fix --------- Co-authored-by: Cameronwood611 <cwood611@uab.edu> Co-authored-by: Lorenzo Stella <stellalo@amazon.com> * [Docs] fix link to NPTS implementation (#2953) * Revert "Fix JsonLinesFile slicing. (#2925)" This reverts commit fa7f9a0. --------- Co-authored-by: Jasper <schjaspe@amazon.de> Co-authored-by: cneely33 <cneely33@gmail.com> Co-authored-by: cameronwood611 <cameron.wood611@gmail.com> Co-authored-by: Cameronwood611 <cwood611@uab.edu>
Issue #, if available: #2947
Description of changes: Have simple check to not accumulate the same field names, resulting in a RecursionError.
By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.
Please tag this pr with at least one of these labels to make our release process faster: bug fix
@abdulfatir @lostella Please review as this relates to changes you (co-)authored