Skip to content

Commit 2ffc0de

Browse files
Support predict_dataset in LightningDataModule.from_datasets (#12942)
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
1 parent 88c202e commit 2ffc0de

File tree

3 files changed

+25
-7
lines changed

3 files changed

+25
-7
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
3333
- Include a version suffix for new "last" checkpoints of later runs in the same directory ([#12902](https://github.com/PyTorchLightning/pytorch-lightning/pull/12902))
3434

3535

36+
- Added the missing `predict_dataset` argument to `LightningDataModule.from_datasets` for creating predict dataloaders ([#12942](https://github.com/PyTorchLightning/pytorch-lightning/pull/12942))
37+
38+
3639
- Added class name prefix to metrics logged by `DeviceStatsMonitor` ([#12228](https://github.com/PyTorchLightning/pytorch-lightning/pull/12228))
3740

41+
42+
3843
### Changed
3944

4045
- Enable validation during overfitting ([#12527](https://github.com/PyTorchLightning/pytorch-lightning/pull/12527))

pytorch_lightning/core/datamodule.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ def from_datasets(
102102
train_dataset: Optional[Union[Dataset, Sequence[Dataset], Mapping[str, Dataset]]] = None,
103103
val_dataset: Optional[Union[Dataset, Sequence[Dataset]]] = None,
104104
test_dataset: Optional[Union[Dataset, Sequence[Dataset]]] = None,
105+
predict_dataset: Optional[Union[Dataset, Sequence[Dataset]]] = None,
105106
batch_size: int = 1,
106107
num_workers: int = 0,
107108
):
@@ -112,6 +113,7 @@ def from_datasets(
112113
train_dataset: (optional) Dataset to be used for train_dataloader()
113114
val_dataset: (optional) Dataset or list of Dataset to be used for val_dataloader()
114115
test_dataset: (optional) Dataset or list of Dataset to be used for test_dataloader()
116+
predict_dataset: (optional) Dataset or list of Dataset to be used for predict_dataloader()
115117
batch_size: Batch size to use for each dataloader. Default is 1.
116118
num_workers: Number of subprocesses to use for data loading. 0 means that the
117119
data will be loaded in the main process. A typical upper bound is the number of CPUs available.
@@ -139,13 +141,20 @@ def test_dataloader():
139141
return [dataloader(ds) for ds in test_dataset]
140142
return dataloader(test_dataset)
141143

144+
def predict_dataloader():
145+
if isinstance(predict_dataset, Sequence):
146+
return [dataloader(ds) for ds in predict_dataset]
147+
return dataloader(predict_dataset)
148+
142149
datamodule = cls()
143150
if train_dataset is not None:
144151
datamodule.train_dataloader = train_dataloader
145152
if val_dataset is not None:
146153
datamodule.val_dataloader = val_dataloader
147154
if test_dataset is not None:
148155
datamodule.test_dataloader = test_dataloader
156+
if predict_dataset is not None:
157+
datamodule.predict_dataloader = predict_dataloader
149158
return datamodule
150159

151160
def state_dict(self) -> Dict[str, Any]:
@@ -154,7 +163,7 @@ def state_dict(self) -> Dict[str, Any]:
154163
Returns:
155164
A dictionary containing datamodule state.
156165
"""
157-
return {}
166+
return dict()
158167

159168
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
160169
"""Called when loading a checkpoint, implement to reload datamodule state given datamodule state_dict.

tests/core/test_datamodules.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -377,9 +377,9 @@ def test_dm_init_from_datasets_dataloaders(iterable):
377377
with mock.patch("pytorch_lightning.core.datamodule.DataLoader") as dl_mock:
378378
dm.train_dataloader()
379379
dl_mock.assert_called_once_with(train_ds, batch_size=4, shuffle=not iterable, num_workers=0, pin_memory=True)
380-
with pytest.raises(MisconfigurationException):
380+
with pytest.raises(MisconfigurationException, match="`val_dataloader` must be implemented"):
381381
_ = dm.val_dataloader()
382-
with pytest.raises(MisconfigurationException):
382+
with pytest.raises(MisconfigurationException, match="`test_dataloader` must be implemented"):
383383
_ = dm.test_dataloader()
384384

385385
train_ds_sequence = [ds(), ds()]
@@ -392,9 +392,9 @@ def test_dm_init_from_datasets_dataloaders(iterable):
392392
call(train_ds_sequence[1], batch_size=4, shuffle=not iterable, num_workers=0, pin_memory=True),
393393
]
394394
)
395-
with pytest.raises(MisconfigurationException):
395+
with pytest.raises(MisconfigurationException, match="`val_dataloader` must be implemented"):
396396
_ = dm.val_dataloader()
397-
with pytest.raises(MisconfigurationException):
397+
with pytest.raises(MisconfigurationException, match="`test_dataloader` must be implemented"):
398398
_ = dm.test_dataloader()
399399

400400
valid_ds = ds()
@@ -405,21 +405,25 @@ def test_dm_init_from_datasets_dataloaders(iterable):
405405
dl_mock.assert_called_with(valid_ds, batch_size=2, shuffle=False, num_workers=0, pin_memory=True)
406406
dm.test_dataloader()
407407
dl_mock.assert_called_with(test_ds, batch_size=2, shuffle=False, num_workers=0, pin_memory=True)
408-
with pytest.raises(MisconfigurationException):
408+
with pytest.raises(MisconfigurationException, match="`train_dataloader` must be implemented"):
409409
_ = dm.train_dataloader()
410410

411411
valid_dss = [ds(), ds()]
412412
test_dss = [ds(), ds()]
413-
dm = LightningDataModule.from_datasets(train_ds, valid_dss, test_dss, batch_size=4, num_workers=0)
413+
predict_dss = [ds(), ds()]
414+
dm = LightningDataModule.from_datasets(train_ds, valid_dss, test_dss, predict_dss, batch_size=4, num_workers=0)
414415
with mock.patch("pytorch_lightning.core.datamodule.DataLoader") as dl_mock:
415416
dm.val_dataloader()
416417
dm.test_dataloader()
418+
dm.predict_dataloader()
417419
dl_mock.assert_has_calls(
418420
[
419421
call(valid_dss[0], batch_size=4, shuffle=False, num_workers=0, pin_memory=True),
420422
call(valid_dss[1], batch_size=4, shuffle=False, num_workers=0, pin_memory=True),
421423
call(test_dss[0], batch_size=4, shuffle=False, num_workers=0, pin_memory=True),
422424
call(test_dss[1], batch_size=4, shuffle=False, num_workers=0, pin_memory=True),
425+
call(predict_dss[0], batch_size=4, shuffle=False, num_workers=0, pin_memory=True),
426+
call(predict_dss[1], batch_size=4, shuffle=False, num_workers=0, pin_memory=True),
423427
]
424428
)
425429

0 commit comments

Comments
 (0)