Skip to content

Commit 3da28fd

Browse files
tchatoncarmoccamergify[bot]
authored
[feat] 1/2 Add trainer.predict (#5579)
* start adding predict * add predict * resolve test * add predict * remove limit_predict * update * add test for predict * typo * update on comments * remove predict_step * update ddp_shareded * check ddp_sharded * resolve on comments * resolve isort * update dp * add test dp 1 gpu * made default forward * resolve path * resolve bug * update on comments * resolve doc * resolve bug * update * resolve bug * update on comments * resolve pep8 * update test doc * update on comments * solve special tests * resolve bug * resolve flake8 * Update pytorch_lightning/callbacks/progress.py Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> * Update pytorch_lightning/trainer/trainer.py Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> * add predict to LightningModule * missing predict * typo * rename is_prediction to _predicting * add * update * update * update doc Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
1 parent 221c4a0 commit 3da28fd

28 files changed

+354
-59
lines changed

CHANGELOG.md

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1111

1212
- Add support for summarized model total params size in megabytes ([#5590](https://github.com/PyTorchLightning/pytorch-lightning/pull/5590))
1313

14+
1415
- Add Support for multiple train loaders ([#1959](https://github.com/PyTorchLightning/pytorch-lightning/pull/1959))
1516

1617

@@ -68,6 +69,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
6869
- Added compositional metrics ([#5464](https://github.com/PyTorchLightning/pytorch-lightning/pull/5464))
6970

7071

72+
- Added Trainer method `predict(...)` for high performence predictions ([#5579](https://github.com/PyTorchLightning/pytorch-lightning/pull/5579))
73+
74+
7175
- Added AUC/AUROC class interface ([#5479](https://github.com/PyTorchLightning/pytorch-lightning/pull/5479))
7276

7377

@@ -120,7 +124,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
120124

121125

122126
- Removed deprecated `TrainResult` ([#5323](https://github.com/PyTorchLightning/pytorch-lightning/pull/5323))
123-
127+
124128

125129
- Removed deprecated `EvalResult` ([#5633](https://github.com/PyTorchLightning/pytorch-lightning/pull/5633))
126130

@@ -155,7 +159,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
155159

156160
- Fixed `transfer_batch_to_device` for DDP with `len(devices_ids) == 1` ([#5195](https://github.com/PyTorchLightning/pytorch-lightning/pull/5195))
157161
- Logging only on `not should_accumulate()` during training ([#5417](https://github.com/PyTorchLightning/pytorch-lightning/pull/5417))
158-
- Resolve interpolation bug with Hydra ([#5406](https://github.com/PyTorchLightning/pytorch-lightning/pull/5406))
162+
- Resolve interpolation bug with Hydra ([#5406](https://github.com/PyTorchLightning/pytorch-lightning/pull/5406))
159163
- Check environ before selecting a seed to prevent warning message ([#4743](https://github.com/PyTorchLightning/pytorch-lightning/pull/4743))
160164

161165

docs/source/starter/introduction_guide.rst

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -881,8 +881,30 @@ Or maybe we have a model that we use to do generation
881881
z = sample_noise()
882882
generated_imgs = model(z)
883883
884-
How you split up what goes in ``forward`` vs ``training_step`` depends on how you want to use this model for
884+
885+
To perform inference at scale, it is possible to use ``trainer.predict`` with LightningModule ``predict`` function
886+
By default, LightningModule ``predict`` calls forward, but it can be overriden to add any processing logic.
887+
888+
.. code-block:: python
889+
890+
class LitMNISTDreamer(LightningModule):
891+
892+
def forward(self, z):
893+
imgs = self.decoder(z)
894+
return imgs
895+
896+
def predict(self, batch, batch_idx: int , dataloader_idx: int = None):
897+
return self(batch)
898+
899+
900+
model = LitMNISTDreamer()
901+
trainer.predict(model, datamodule)
902+
903+
904+
How you split up what goes in ``forward`` vs ``training_step`` vs ``predict`` depends on how you want to use this model for
885905
prediction.
906+
However, we recommend ``forward`` to contain only tensor operation with your model, ``training_step`` to encapsulate ``forward`` logic with logging,
907+
metrics and loss computation and ``predict`` to encapsulate ``forward`` with preprocess, postprocess functions.
886908

887909
----------------
888910

pytorch_lightning/accelerators/legacy/cpu_accelerator.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@ def validation_step(self, args):
7777
def test_step(self, args):
7878
return self._step(self.trainer.model.test_step, args)
7979

80+
def predict(self, args):
81+
return self._step(self.trainer.model.predict, args)
82+
8083
def sync_tensor(self,
8184
tensor: Union[torch.Tensor],
8285
group: Optional[Any] = None,

pytorch_lightning/accelerators/legacy/ddp2_accelerator.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,9 @@ def validation_step(self, args):
6666
def test_step(self, args):
6767
return self._step(args)
6868

69+
def predict(self, args):
70+
return self._step(args)
71+
6972
def _step(self, args):
7073
args = self.ddp_plugin.on_before_forward(self.trainer.get_model(), *args)
7174
if self.trainer.amp_backend == AMPType.NATIVE:

pytorch_lightning/accelerators/legacy/ddp_accelerator.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,9 @@ def validation_step(self, args):
164164
def test_step(self, args):
165165
return self._step(args)
166166

167+
def predict(self, args):
168+
return self._step(args)
169+
167170
def _step(self, args):
168171
args = self.ddp_plugin.on_before_forward(self.trainer.get_model(), *args)
169172
if self.trainer.amp_backend == AMPType.NATIVE:

pytorch_lightning/accelerators/legacy/ddp_cpu_spawn_accelerator.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,9 @@ def validation_step(self, args):
178178
def test_step(self, args):
179179
return self._step(args)
180180

181+
def predict(self, args):
182+
return self._step(args)
183+
181184
def _step(self, args):
182185
args = self.ddp_plugin.on_before_forward(self.trainer.get_model(), *args)
183186
if self.trainer.amp_backend == AMPType.NATIVE:

pytorch_lightning/accelerators/legacy/ddp_hpc_accelerator.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,9 @@ def validation_step(self, args):
8383
def test_step(self, args):
8484
return self._step(args)
8585

86+
def predict(self, args):
87+
return self._step(args)
88+
8689
def _step(self, args):
8790
args = self.ddp_plugin.on_before_forward(self.trainer.get_model(), *args)
8891
if self.trainer.amp_backend == AMPType.NATIVE:

pytorch_lightning/accelerators/legacy/ddp_spawn_accelerator.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,9 @@ def validation_step(self, args):
212212
def test_step(self, args):
213213
return self._step(args)
214214

215+
def predict(self, args):
216+
return self._step(args)
217+
215218
def _step(self, args):
216219
args = self.ddp_plugin.on_before_forward(self.trainer.get_model(), *args)
217220
if self.trainer.amp_backend == AMPType.NATIVE:

pytorch_lightning/accelerators/legacy/dp_accelerator.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,9 @@ def validation_step(self, args):
132132
def test_step(self, args):
133133
return self._step(args)
134134

135+
def predict(self, args):
136+
return self._step(args)
137+
135138
def training_step_end(self, output):
136139
if isinstance(output, Result):
137140
output.dp_reduce()

pytorch_lightning/accelerators/legacy/gpu_accelerator.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,9 @@ def validation_step(self, args):
8585
def test_step(self, args):
8686
return self._step(self.trainer.model.test_step, args)
8787

88+
def predict(self, args):
89+
return self._step(self.trainer.model.predict, args)
90+
8891
def to_device(self, batch):
8992
gpu_id = 0
9093
if isinstance(self.trainer.data_parallel_device_ids, list):

0 commit comments

Comments
 (0)