Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Apply dynamo to training_step, validation_step, test_step, predict_step #15957

Merged
merged 2 commits into from
Dec 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added support for `torch.compile` ([#15922](https://github.com/Lightning-AI/lightning/pull/15922), [15957](https://github.com/Lightning-AI/lightning/pull/15957))


- Added support for DDP with `LRFinder` ([#15304](https://github.com/Lightning-AI/lightning/pull/15304))


Expand Down
12 changes: 12 additions & 0 deletions src/pytorch_lightning/core/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -1980,9 +1980,17 @@ def from_compiled(cls, model: "torch._dynamo.OptimizedModule") -> "pl.LightningM
"compiler": "dynamo",
"dynamo_ctx": model.dynamo_ctx,
"original_forward": orig_module.forward,
"original_training_step": orig_module.training_step,
"original_validation_step": orig_module.validation_step,
"original_test_step": orig_module.test_step,
"original_predict_step": orig_module.predict_step,
}

orig_module.forward = model.dynamo_ctx(orig_module.forward) # type: ignore[assignment]
orig_module.training_step = model.dynamo_ctx(orig_module.training_step) # type: ignore[assignment]
orig_module.validation_step = model.dynamo_ctx(orig_module.validation_step) # type: ignore[assignment]
orig_module.test_step = model.dynamo_ctx(orig_module.test_step) # type: ignore[assignment]
orig_module.predict_step = model.dynamo_ctx(orig_module.predict_step) # type: ignore[assignment]
return orig_module

@classmethod
Expand Down Expand Up @@ -2011,6 +2019,10 @@ def to_uncompiled(cls, model: Union["pl.LightningModule", "torch._dynamo.Optimiz
raise ValueError("`model` must either be an instance of torch._dynamo.OptimizedModule or LightningModule")

model.forward = model._compiler_ctx["original_forward"] # type: ignore[assignment]
model.training_step = model._compiler_ctx["original_training_step"] # type: ignore[assignment]
model.validation_step = model._compiler_ctx["original_validation_step"] # type: ignore[assignment]
model.test_step = model._compiler_ctx["original_test_step"] # type: ignore[assignment]
model.predict_step = model._compiler_ctx["original_predict_step"] # type: ignore[assignment]
model._compiler_ctx = None

return model
Expand Down
21 changes: 19 additions & 2 deletions tests/tests_pytorch/core/test_lightning_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from torch.optim import Adam, SGD

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.demos.boring_classes import BoringModel, DemoModel
from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_11, _TORCH_GREATER_EQUAL_1_13
Expand Down Expand Up @@ -446,15 +446,32 @@ def test_trainer_reference_recursively():
@RunIf(min_torch="1.14.0.dev20221202")
def test_compile_uncompile():

lit_model = DemoModel()
lit_model = BoringModel()
model_compiled = torch.compile(lit_model)

lit_model_compiled = LightningModule.from_compiled(model_compiled)

def has_dynamo(fn):
return any(el for el in dir(fn) if el.startswith("_torchdynamo"))

assert isinstance(lit_model_compiled, LightningModule)
assert lit_model_compiled._compiler_ctx is not None
assert has_dynamo(lit_model_compiled.forward)
assert has_dynamo(lit_model_compiled.training_step)
assert has_dynamo(lit_model_compiled.validation_step)
assert has_dynamo(lit_model_compiled.test_step)
assert has_dynamo(lit_model_compiled.predict_step)

lit_model_orig = LightningModule.to_uncompiled(lit_model)

assert lit_model_orig._compiler_ctx is None
assert lit_model_orig.forward == lit_model.forward
assert lit_model_orig.training_step == lit_model.training_step
assert lit_model_orig.validation_step == lit_model.validation_step
assert lit_model_orig.test_step == lit_model.test_step
assert lit_model_orig.predict_step == lit_model.predict_step
assert not has_dynamo(lit_model_orig.forward)
assert not has_dynamo(lit_model_orig.training_step)
assert not has_dynamo(lit_model_orig.validation_step)
assert not has_dynamo(lit_model_orig.test_step)
assert not has_dynamo(lit_model_orig.predict_step)
3 changes: 1 addition & 2 deletions tests/tests_pytorch/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
from pytorch_lightning.demos.boring_classes import (
BoringDataModule,
BoringModel,
DemoModel,
RandomDataset,
RandomIterableDataset,
RandomIterableDatasetWithLen,
Expand Down Expand Up @@ -2245,7 +2244,7 @@ def on_fit_start(self):
# TODO: replace with 1.14 when it is released
@RunIf(min_torch="1.14.0.dev20221202")
def test_trainer_compiled_model():
model = DemoModel()
model = BoringModel()

model = torch.compile(model)

Expand Down