Skip to content

Commit

Permalink
Introduce PrecisionPlugin.forward_context() (#9988)
Browse files Browse the repository at this point in the history
Co-authored-by: thomas chaton <thomas@grid.ai>
  • Loading branch information
awaelchli and tchaton authored Oct 18, 2021
1 parent 3f355d0 commit 10d0b41
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 58 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `strategy` argument to Trainer ([#8597](https://github.com/PyTorchLightning/pytorch-lightning/pull/8597))


- LightningLite:
* Added `PrecisionPlugin.forward_context`, making it the default implementation for all `{train,val,test,predict}_step_context()` methods ([#9988](https://github.com/PyTorchLightning/pytorch-lightning/pull/9988))


### Changed

- Setting `Trainer(accelerator="ddp_cpu")` now does not spawn a subprocess if `num_processes` is kept `1` along with `num_nodes > 1` ([#9603](https://github.com/PyTorchLightning/pytorch-lightning/pull/9603)).
Expand Down
32 changes: 1 addition & 31 deletions pytorch_lightning/plugins/precision/double.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,37 +92,7 @@ def connect(
return super().connect(model, optimizers, lr_schedulers)

@contextmanager
def train_step_context(self) -> Generator[None, None, None]:
"""A context manager to change the default tensor type.
See: :meth:`torch.set_default_tensor_type`
"""
torch.set_default_tensor_type(torch.DoubleTensor)
yield
torch.set_default_tensor_type(torch.FloatTensor)

@contextmanager
def val_step_context(self) -> Generator[None, None, None]:
"""A context manager to change the default tensor type.
See: :meth:`torch.set_default_tensor_type`
"""
torch.set_default_tensor_type(torch.DoubleTensor)
yield
torch.set_default_tensor_type(torch.FloatTensor)

@contextmanager
def test_step_context(self) -> Generator[None, None, None]:
"""A context manager to change the default tensor type.
See: :meth:`torch.set_default_tensor_type`
"""
torch.set_default_tensor_type(torch.DoubleTensor)
yield
torch.set_default_tensor_type(torch.FloatTensor)

@contextmanager
def predict_step_context(self) -> Generator[None, None, None]:
def forward_context(self) -> Generator[None, None, None]:
"""A context manager to change the default tensor type.
See: :meth:`torch.set_default_tensor_type`
Expand Down
20 changes: 1 addition & 19 deletions pytorch_lightning/plugins/precision/native_amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,25 +102,7 @@ def autocast_context_manager(self) -> torch.cuda.amp.autocast:
return torch.cuda.amp.autocast()

@contextmanager
def train_step_context(self) -> Generator[None, None, None]:
"""Enable autocast context."""
with self.autocast_context_manager():
yield

@contextmanager
def val_step_context(self) -> Generator[None, None, None]:
"""Enable autocast context."""
with self.autocast_context_manager():
yield

@contextmanager
def test_step_context(self) -> Generator[None, None, None]:
"""Enable autocast context."""
with self.autocast_context_manager():
yield

@contextmanager
def predict_step_context(self) -> Generator[None, None, None]:
def forward_context(self) -> Generator[None, None, None]:
"""Enable autocast context."""
with self.autocast_context_manager():
yield
Expand Down
25 changes: 17 additions & 8 deletions pytorch_lightning/plugins/precision/precision_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,21 +143,30 @@ def post_dispatch(self) -> None:
"""Hook to do something after the training/evaluation/prediction finishes."""

@contextlib.contextmanager
def train_step_context(self) -> Generator:
"""A contextmanager for the training step."""
def forward_context(self) -> Generator[None, None, None]:
"""A contextmanager for managing model forward/training_step/evaluation_step/predict_step."""
yield

@contextlib.contextmanager
def val_step_context(self) -> Generator:
def train_step_context(self) -> Generator[None, None, None]:
"""A contextmanager for the training step."""
with self.forward_context():
yield

@contextlib.contextmanager
def val_step_context(self) -> Generator[None, None, None]:
"""A contextmanager for the validation step."""
yield
with self.forward_context():
yield

@contextlib.contextmanager
def test_step_context(self) -> Generator:
def test_step_context(self) -> Generator[None, None, None]:
"""A contextmanager for the test step."""
yield
with self.forward_context():
yield

@contextlib.contextmanager
def predict_step_context(self) -> Generator:
def predict_step_context(self) -> Generator[None, None, None]:
"""A contextmanager for the predict step."""
yield
with self.forward_context():
yield

0 comments on commit 10d0b41

Please sign in to comment.