diff --git a/pytorch_lightning/trainer/data_loading_mixin.py b/pytorch_lightning/trainer/data_loading_mixin.py index 7755bc7d5dfef..52541a86fc0cb 100644 --- a/pytorch_lightning/trainer/data_loading_mixin.py +++ b/pytorch_lightning/trainer/data_loading_mixin.py @@ -24,7 +24,7 @@ def init_train_dataloader(self, model): self.get_train_dataloader = model.train_dataloader # determine number of training batches - if isinstance(self.get_train_dataloader(), IterableDataset): + if isinstance(self.get_train_dataloader().dataset, IterableDataset): self.nb_training_batches = float('inf') else: self.nb_training_batches = len(self.get_train_dataloader()) @@ -167,7 +167,7 @@ def get_dataloaders(self, model): self.get_val_dataloaders() # support IterableDataset for train data - self.is_iterable_train_dataloader = isinstance(self.get_train_dataloader(), IterableDataset) + self.is_iterable_train_dataloader = isinstance(self.get_train_dataloader().dataset, IterableDataset) if self.is_iterable_train_dataloader and not isinstance(self.val_check_interval, int): m = ''' When using an iterableDataset for train_dataloader,