Skip to content

Commit

Permalink
Fix: prevent accumulation of SelectFields in PyTorchPredictor (#2951
Browse files Browse the repository at this point in the history
)

* Prevent redundant accumulation of fields

* update fix

---------

Co-authored-by: Cameronwood611 <cwood611@uab.edu>
Co-authored-by: Lorenzo Stella <stellalo@amazon.com>
  • Loading branch information
3 people authored Aug 7, 2023
1 parent 13297b0 commit a944581
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/gluonts/torch/model/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,12 @@ def network(self) -> nn.Module:
def predict(
self, dataset: Dataset, num_samples: Optional[int] = None
) -> Iterator[Forecast]:
self.input_transform += SelectFields(
self.input_names + self.required_fields, allow_missing=True
)
inference_data_loader = InferenceDataLoader(
dataset,
transform=self.input_transform,
transform=self.input_transform
+ SelectFields(
self.input_names + self.required_fields, allow_missing=True
),
batch_size=self.batch_size,
stack_fn=lambda data: batchify(data, self.device),
)
Expand Down

0 comments on commit a944581

Please sign in to comment.