From 4140af5796c1bf2bc9c97d10d50328eb5d95c9fe Mon Sep 17 00:00:00 2001 From: Parth Raut Date: Thu, 12 Dec 2024 22:16:34 -0500 Subject: [PATCH] move to cuda --- zeus/utils/framework.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zeus/utils/framework.py b/zeus/utils/framework.py index 86b9bd8a..c938c6e1 100644 --- a/zeus/utils/framework.py +++ b/zeus/utils/framework.py @@ -122,7 +122,7 @@ def all_reduce( return object # wrap object in a tensor - tensor = torch.Tensor(object, device="cuda") + tensor = torch.Tensor(object).cuda() # determine operation if operation == "sum":