From d2589fc817fe0c413e18c71eb52da47b0eed1015 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 3 Jul 2024 02:07:14 +0200 Subject: [PATCH] update test --- tests/tests_pytorch/trainer/test_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_pytorch/trainer/test_trainer.py b/tests/tests_pytorch/trainer/test_trainer.py index 802c1a17bc448..f5e90fdabf944 100644 --- a/tests/tests_pytorch/trainer/test_trainer.py +++ b/tests/tests_pytorch/trainer/test_trainer.py @@ -1766,7 +1766,7 @@ def current_memory(): trainer.fit(model) assert trainer.strategy.model is model - assert list(trainer.optimizers[0].state.values())[0]["exp_avg_sq"].device == torch.device("cpu") + assert list(trainer.optimizers[0].state.values())[0]["exp_avg_sq"].device == torch.device("cuda", 0) assert trainer.callback_metrics["train_loss"].device == torch.device("cpu") assert current_memory() <= initial