Skip to content

Commit 8bb74cd

Browse files
committed
simple
1 parent 48c16cd commit 8bb74cd

File tree

1 file changed

+22
-20
lines changed

1 file changed

+22
-20
lines changed

tests/trainer/optimization/test_parity_optimization.py

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def parity_automatic_train_with_one_optimizer(ctx):
163163
accumulate_grad_batches=accumulate_grad_batches if not ctx["vanilla"] else 1,
164164
amp_backend=ctx["amp_backend"],
165165
precision=ctx["precision"],
166-
gpus=1
166+
gpus=1 if ctx["device"] == 'cuda' else 0,
167167
)
168168
trainer.fit(model)
169169

@@ -174,35 +174,37 @@ def parity_automatic_train_with_one_optimizer(ctx):
174174
################## TESTS ################## # noqa E266
175175

176176

177-
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
178-
@pytest.mark.parametrize(["precision", "amp_backend"], [
179-
pytest.param(16, "native"),
180-
pytest.param(32, "native"),
177+
@pytest.mark.parametrize(["precision", "amp_backend", "device"], [
178+
pytest.param(16, "native", "cuda",
179+
marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")),
180+
pytest.param(32, "native", "cpu"),
181181
])
182-
@pytest.mark.parametrize('accumulate_grad_batches', [1, 2])
183-
def test_parity_automatic_training_with_one_optimizer(tmpdir, amp_backend, precision, accumulate_grad_batches):
182+
@pytest.mark.parametrize('accumulate_grad_batches', [1])
183+
def test_parity_automatic_training_with_one_optimizer(tmpdir, amp_backend, precision, device, accumulate_grad_batches):
184184
"""
185185
Test training with accumulated gradients with and within enable_pl_optimizer reaches the same weights
186186
"""
187187
# prepare arguments
188188
if accumulate_grad_batches > 1:
189189
accumulate_grad_batches = np.random.randint(2, accumulate_grad_batches + 1)
190190

191-
ctx = {}
192-
ctx["tmpdir"] = tmpdir
193-
ctx["accumulate_grad_batches"] = accumulate_grad_batches
194-
ctx["amp_backend"] = amp_backend
195-
ctx["precision"] = precision
196-
ctx["using_amp"] = (amp_backend in ["native"]) and precision == 16
197-
ctx["max_epochs"] = np.random.randint(1, 3)
198-
ctx["limit_train_batches"] = np.random.randint(11, 27)
191+
ctx = dict(
192+
tmpdir=tmpdir,
193+
accumulate_grad_batches=accumulate_grad_batches,
194+
amp_backend=amp_backend,
195+
precision=precision,
196+
using_amp=(amp_backend in ["native"]) and precision == 16,
197+
max_epochs=np.random.randint(1, 3),
198+
limit_train_batches=np.random.randint(11, 27),
199+
limit_val_batches=0,
200+
initial_weights={},
201+
enable_pl_optimizer=True,
202+
mocked=False,
203+
vanilla=False,
204+
device=device,
205+
)
199206
expected_batches = ctx["max_epochs"] * ctx["limit_train_batches"]
200207
ctx["expected_batches"] = expected_batches
201-
ctx["limit_val_batches"] = 0
202-
ctx["initial_weights"] = {}
203-
ctx["enable_pl_optimizer"] = True
204-
ctx["mocked"] = False
205-
ctx["vanilla"] = False
206208

207209
model_wi_pl_optimizer = parity_automatic_train_with_one_optimizer(ctx)
208210

0 commit comments

Comments
 (0)