diff --git a/src/gluonts/torch/model/predictor.py b/src/gluonts/torch/model/predictor.py index aa3ba592c1..fdb5ac3e7d 100644 --- a/src/gluonts/torch/model/predictor.py +++ b/src/gluonts/torch/model/predictor.py @@ -73,12 +73,12 @@ def network(self) -> nn.Module: def predict( self, dataset: Dataset, num_samples: Optional[int] = None ) -> Iterator[Forecast]: - self.input_transform += SelectFields( - self.input_names + self.required_fields, allow_missing=True - ) inference_data_loader = InferenceDataLoader( dataset, - transform=self.input_transform, + transform=self.input_transform + + SelectFields( + self.input_names + self.required_fields, allow_missing=True + ), batch_size=self.batch_size, stack_fn=lambda data: batchify(data, self.device), )