Skip to content

Commit 3d7095d

Browse files
committed
update tests
1 parent 17257a6 commit 3d7095d

File tree

4 files changed

+10
-4
lines changed

4 files changed

+10
-4
lines changed

tests/tests_pytorch/accelerators/test_accelerator_connector.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,8 +225,9 @@ def test_ipython_compatible_dp_strategy_gpu(_, monkeypatch):
225225
assert trainer.strategy.launcher is None
226226

227227

228+
@mock.patch("pytorch_lightning.strategies.launchers.spawn.mp.get_all_start_methods", return_value=["fork"])
228229
@mock.patch("pytorch_lightning.accelerators.tpu.TPUAccelerator.is_available", return_value=True)
229-
def test_ipython_compatible_strategy_tpu(_, monkeypatch):
230+
def test_ipython_compatible_strategy_tpu(_, __, monkeypatch):
230231
monkeypatch.setattr(pytorch_lightning.utilities, "_IS_INTERACTIVE", True)
231232
trainer = Trainer(accelerator="tpu")
232233
assert trainer.strategy.launcher.is_interactive_compatible

tests/tests_pytorch/deprecated_api/test_remove_1-8.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1140,8 +1140,9 @@ def test_trainer_gpus(monkeypatch, trainer_kwargs):
11401140
assert trainer.gpus == trainer_kwargs["devices"]
11411141

11421142

1143-
def test_trainer_tpu_cores(monkeypatch):
1144-
monkeypatch.setattr(pytorch_lightning.accelerators.tpu.TPUAccelerator, "is_available", lambda _: True)
1143+
@mock.patch("pytorch_lightning.strategies.launchers.spawn.mp.get_all_start_methods", return_value=["fork"])
1144+
@mock.patch("pytorch_lightning.accelerators.tpu.TPUAccelerator.is_available", return_value=True)
1145+
def test_trainer_tpu_cores(*_):
11451146
trainer = Trainer(accelerator="tpu", devices=8)
11461147
with pytest.deprecated_call(
11471148
match=(

tests/tests_pytorch/deprecated_api/test_remove_2-0.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def test_v2_0_0_deprecated_gpus(*_):
3636

3737
@mock.patch("pytorch_lightning.accelerators.tpu.TPUAccelerator.is_available", return_value=True)
3838
@mock.patch("pytorch_lightning.accelerators.tpu.TPUAccelerator.parse_devices", return_value=8)
39+
@mock.patch("pytorch_lightning.strategies.launchers.spawn.mp.get_all_start_methods", return_value=["fork"])
3940
def test_v2_0_0_deprecated_tpu_cores(*_):
4041
with pytest.deprecated_call(match=r"is deprecated in v1.7 and will be removed in v2.0."):
4142
_ = Trainer(tpu_cores=8)

tests/tests_pytorch/strategies/test_strategy_registry.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
from unittest import mock
15+
1416
import pytest
1517

1618
from pytorch_lightning import Trainer
@@ -79,7 +81,8 @@ def test_deepspeed_strategy_registry_with_trainer(tmpdir, strategy):
7981
assert isinstance(trainer.strategy, DeepSpeedStrategy)
8082

8183

82-
def test_tpu_spawn_debug_strategy_registry(tmpdir):
84+
@mock.patch("pytorch_lightning.strategies.launchers.spawn.mp.get_all_start_methods", return_value=["fork"])
85+
def test_tpu_spawn_debug_strategy_registry(_):
8386

8487
strategy = "tpu_spawn_debug"
8588

0 commit comments

Comments
 (0)