|
13 | 13 | # limitations under the License. |
14 | 14 | import os |
15 | 15 | import sys |
16 | | -from contextlib import nullcontext |
17 | 16 | from unittest import mock |
18 | 17 |
|
19 | 18 | import pytest |
20 | 19 | import torch |
21 | | -from lightning_utilities.core.imports import RequirementCache |
22 | 20 |
|
23 | | -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2, _TORCH_GREATER_EQUAL_2_4 |
| 21 | +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2 |
24 | 22 | from lightning.pytorch import LightningModule, Trainer |
25 | 23 | from lightning.pytorch.demos.boring_classes import BoringModel |
26 | 24 | from lightning.pytorch.utilities.compile import from_compiled, to_uncompiled |
|
34 | 32 | @pytest.mark.skipif(sys.platform == "darwin", reason="fatal error: 'omp.h' file not found") |
35 | 33 | @RunIf(dynamo=True, deepspeed=True) |
36 | 34 | @mock.patch("lightning.pytorch.trainer.call._call_and_handle_interrupt") |
37 | | -def test_trainer_compiled_model(_, tmp_path, monkeypatch, mps_count_0): |
| 35 | +def test_trainer_compiled_model_deepspeed(_, tmp_path, monkeypatch, mps_count_0): |
38 | 36 | trainer_kwargs = { |
39 | 37 | "default_root_dir": tmp_path, |
40 | 38 | "fast_dev_run": True, |
@@ -69,22 +67,52 @@ def test_trainer_compiled_model(_, tmp_path, monkeypatch, mps_count_0): |
69 | 67 | assert trainer.model._compiler_ctx is None |
70 | 68 |
|
71 | 69 | # some strategies do not support it |
72 | | - if RequirementCache("deepspeed"): |
73 | | - compiled_model = torch.compile(model) |
74 | | - mock_cuda_count(monkeypatch, 2) |
75 | | - |
76 | | - # TODO: Update deepspeed to avoid deprecation warning for `torch.cuda.amp.custom_fwd` on import |
77 | | - warn_context = ( |
78 | | - pytest.warns(FutureWarning, match="torch.cuda.amp.*is deprecated") |
79 | | - if _TORCH_GREATER_EQUAL_2_4 |
80 | | - else nullcontext() |
81 | | - ) |
82 | | - |
83 | | - with warn_context: |
84 | | - trainer = Trainer(strategy="deepspeed", accelerator="cuda", **trainer_kwargs) |
85 | | - |
86 | | - with pytest.raises(RuntimeError, match="Using a compiled model is incompatible with the current strategy.*"): |
87 | | - trainer.fit(compiled_model) |
| 70 | + compiled_model = torch.compile(model) |
| 71 | + mock_cuda_count(monkeypatch, 2) |
| 72 | + |
| 73 | + trainer = Trainer(strategy="deepspeed", accelerator="cuda", **trainer_kwargs) |
| 74 | + |
| 75 | + with pytest.raises(RuntimeError, match="Using a compiled model is incompatible with the current strategy.*"): |
| 76 | + trainer.fit(compiled_model) |
| 77 | + |
| 78 | + |
| 79 | +# https://github.com/pytorch/pytorch/issues/95708 |
| 80 | +@pytest.mark.skipif(sys.platform == "darwin", reason="fatal error: 'omp.h' file not found") |
| 81 | +@RunIf(dynamo=True) |
| 82 | +@mock.patch("lightning.pytorch.trainer.call._call_and_handle_interrupt") |
| 83 | +def test_trainer_compiled_model_ddp(_, tmp_path, monkeypatch, mps_count_0): |
| 84 | + trainer_kwargs = { |
| 85 | + "default_root_dir": tmp_path, |
| 86 | + "fast_dev_run": True, |
| 87 | + "logger": False, |
| 88 | + "enable_checkpointing": False, |
| 89 | + "enable_model_summary": False, |
| 90 | + "enable_progress_bar": False, |
| 91 | + } |
| 92 | + |
| 93 | + model = BoringModel() |
| 94 | + compiled_model = torch.compile(model) |
| 95 | + assert model._compiler_ctx is compiled_model._compiler_ctx # shared reference |
| 96 | + |
| 97 | + # can train with compiled model |
| 98 | + trainer = Trainer(**trainer_kwargs) |
| 99 | + trainer.fit(compiled_model) |
| 100 | + assert trainer.model._compiler_ctx["compiler"] == "dynamo" |
| 101 | + |
| 102 | + # the compiled model can be uncompiled |
| 103 | + to_uncompiled_model = to_uncompiled(compiled_model) |
| 104 | + assert model._compiler_ctx is None |
| 105 | + assert compiled_model._compiler_ctx is None |
| 106 | + assert to_uncompiled_model._compiler_ctx is None |
| 107 | + |
| 108 | + # the compiled model needs to be passed |
| 109 | + with pytest.raises(ValueError, match="required to be a compiled LightningModule"): |
| 110 | + to_uncompiled(to_uncompiled_model) |
| 111 | + |
| 112 | + # the uncompiled model can be fitted |
| 113 | + trainer = Trainer(**trainer_kwargs) |
| 114 | + trainer.fit(model) |
| 115 | + assert trainer.model._compiler_ctx is None |
88 | 116 |
|
89 | 117 | # ddp does |
90 | 118 | trainer = Trainer(strategy="ddp", **trainer_kwargs) |
|
0 commit comments