@@ -1,7 +1,6 @@
import contextlib
import datetime
import os
import sys
from functools import partial
from unittest import mock

@@ -272,7 +271,6 @@ def _test_two_groups(strategy, left_collective, right_collective):


@skip_distributed_unavailable
@pytest.mark.skip(reason="TODO(carmocca): causing hangs in CI")
def test_two_groups():
if sys.platform == "win32" and (sys.version_info.major, sys.version_info.minor) == (3, 10):
pytest.skip("Unresolved hang")
collective_launch(_test_two_groups, [torch.device("cpu")] * 3, num_groups=2)
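
The hunk above drops the blanket @pytest.mark.skip marker, leaving only the narrower in-test skip for the Windows/Python 3.10 combination. A minimal sketch of that conditional-skip pattern, with a hypothetical test name and the body elided:

import sys

import pytest


def test_two_groups_sketch():  # hypothetical name; real body elided
    # Skip only on the affected platform/interpreter pair instead of
    # disabling the test everywhere with @pytest.mark.skip.
    if sys.platform == "win32" and sys.version_info[:2] == (3, 10):
        pytest.skip("Unresolved hang")
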
tests/tests_pytorch/utilities/test_parsing.py: 33 changes (17 additions & 16 deletions)
@@ -36,7 +36,7 @@
unpicklable_function = lambda: None


def model_cases():
def model_and_trainer_cases():
class TestHparamsNamespace(LightningModule):
learning_rate = 1

@@ -64,44 +64,45 @@ class TestModel4(LightningModule): # fail case
batch_size = 1

model4 = TestModel4()
trainer = Trainer()
model4.trainer = trainer
trainer1 = Trainer()
model4.trainer = trainer1
datamodule = LightningDataModule()
datamodule.batch_size = 8
trainer.datamodule = datamodule
trainer1.datamodule = datamodule

model5 = LightningModule()
model5.trainer = trainer
model5.trainer = trainer1

class TestModel6(LightningModule): # test for datamodule w/ hparams w/o attribute (should use datamodule)
hparams = TestHparamsDict

model6 = TestModel6()
model6.trainer = trainer
model6.trainer = trainer1

TestHparamsDict2 = {"batch_size": 2}

class TestModel7(LightningModule): # test for datamodule w/ hparams w/ attribute (should use datamodule)
hparams = TestHparamsDict2

model7 = TestModel7()
model7.trainer = trainer
model7.trainer = trainer1

class TestDataModule8(LightningDataModule): # test for hparams dict
hparams = TestHparamsDict2

model8 = TestModel1()
trainer = Trainer()
model8.trainer = trainer
trainer2 = Trainer()
model8.trainer = trainer2
datamodule = TestDataModule8()
trainer.datamodule = datamodule
trainer2.datamodule = datamodule

return model1, model2, model3, model4, model5, model6, model7, model8
return (model1, model2, model3, model4, model5, model6, model7, model8), (trainer1, trainer2)
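
Returning the trainers alongside the models is presumably what keeps them alive for the duration of a test: if LightningModule.trainer holds only a weak reference, Trainer objects left as locals inside the helper could be garbage-collected, taking the models' back-references with them. A hedged illustration of that failure mode with a stand-in class, not Lightning's actual implementation:

import weakref


class FakeTrainer:  # hypothetical stand-in for Trainer
    pass


trainer = FakeTrainer()
proxy = weakref.proxy(trainer)
del trainer  # the only strong reference is gone
# Any attribute access on `proxy` now raises ReferenceError, which is why
# the helper returns the real trainers: the caller holds strong references
# for the lifetime of the test.
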


def test_lightning_hasattr():
"""Test that the lightning_hasattr works in all cases."""
model1, model2, model3, model4, model5, model6, model7, model8 = models = model_cases()
models, _ = model_and_trainer_cases()
model1, model2, model3, model4, model5, model6, model7, model8 = models
assert lightning_hasattr(model1, "learning_rate"), "lightning_hasattr failed to find namespace variable"
assert lightning_hasattr(model2, "learning_rate"), "lightning_hasattr failed to find hparams namespace variable"
assert lightning_hasattr(model3, "learning_rate"), "lightning_hasattr failed to find hparams dict variable"
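
For context on what these assertions exercise: the parsing helpers look an attribute up on the module itself, then in module.hparams, and finally on trainer.datamodule, with the last holder found taking precedence, which is why a datamodule attribute overrides an hparams entry in the cases above. A simplified, hypothetical sketch of that lookup, not the library's actual code:

def sketch_lightning_getattr(model, name):
    # Hypothetical simplification: collect every holder of `name` and read
    # from the last one, so the datamodule overrides hparams, which in turn
    # overrides a plain attribute on the model.
    holders = []
    if hasattr(model, name):
        holders.append(model)
    hparams = getattr(model, "hparams", None)
    if isinstance(hparams, dict) and name in hparams:
        holders.append(hparams)
    elif hparams is not None and hasattr(hparams, name):
        holders.append(hparams)
    datamodule = getattr(getattr(model, "trainer", None), "datamodule", None)
    if datamodule is not None and hasattr(datamodule, name):
        holders.append(datamodule)
    if not holders:
        raise AttributeError(name)
    holder = holders[-1]
    return holder[name] if isinstance(holder, dict) else getattr(holder, name)
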
@@ -121,12 +122,12 @@ def test_lightning_hasattr():

def test_lightning_getattr():
"""Test that the lightning_getattr works in all cases."""
models = model_cases()
models, _ = model_and_trainer_cases()
*__, model5, model6, model7, model8 = models
for i, m in enumerate(models[:3]):
value = lightning_getattr(m, "learning_rate")
assert value == i, "attribute not correctly extracted"

model5, model6, model7, model8 = models[4:]
assert lightning_getattr(model5, "batch_size") == 8, "batch_size not correctly extracted"
assert lightning_getattr(model6, "batch_size") == 8, "batch_size not correctly extracted"
assert lightning_getattr(model7, "batch_size") == 8, "batch_size not correctly extracted"
@@ -142,12 +143,12 @@ def test_lightning_getattr():

def test_lightning_setattr(tmpdir):
"""Test that the lightning_setattr works in all cases."""
models = model_cases()
models, _ = model_and_trainer_cases()
*__, model5, model6, model7, model8 = models
for m in models[:3]:
lightning_setattr(m, "learning_rate", 10)
assert lightning_getattr(m, "learning_rate") == 10, "attribute not correctly set"

model5, model6, model7, model8 = models[4:]
lightning_setattr(model5, "batch_size", 128)
lightning_setattr(model6, "batch_size", 128)
lightning_setattr(model7, "batch_size", 128)
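
The setter side is presumably write-through to every holder rather than just the highest-precedence one, so hparams and the datamodule cannot drift apart after an update. A tiny self-contained sketch of that behavior, with hypothetical names and stand-in types, not Lightning's actual code:

def sketch_setattr_all_holders(holders, name, value):
    # Write `value` into every holder (dicts and objects alike) so the
    # different attribute locations stay consistent after an update.
    for holder in holders:
        if isinstance(holder, dict):
            holder[name] = value
        else:
            setattr(holder, name, value)


hparams = {"batch_size": 2}


class FakeDataModule:  # hypothetical stand-in
    batch_size = 8


dm = FakeDataModule()
sketch_setattr_all_holders([hparams, dm], "batch_size", 128)
assert hparams["batch_size"] == dm.batch_size == 128
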