From 072b9861468ae443da6489249e34320d5560277c Mon Sep 17 00:00:00 2001
From: Eugene Khvedchenya
Date: Mon, 9 Oct 2023 16:06:24 +0300
Subject: [PATCH] Small improvements to utils

---
 .../training/utils/distributed_training_utils.py | 3 ++-
 src/super_gradients/training/utils/utils.py      | 9 ++++++---
 2 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/src/super_gradients/training/utils/distributed_training_utils.py b/src/super_gradients/training/utils/distributed_training_utils.py
index da790170a5..2887be03c4 100755
--- a/src/super_gradients/training/utils/distributed_training_utils.py
+++ b/src/super_gradients/training/utils/distributed_training_utils.py
@@ -436,7 +436,8 @@ def maybe_all_reduce_tensor_average(tensor: torch.Tensor) -> torch.Tensor:
     :return:
     """
     if is_distributed():
-        tensor = distributed_all_reduce_tensor_average(tensor=tensor, n=torch.distributed.get_world_size())
+        # .to_dense() is required to ensure we can do maybe_all_reduce_tensor_average(some_vector[3])
+        tensor = distributed_all_reduce_tensor_average(tensor=tensor.to_dense(), n=torch.distributed.get_world_size())
     return tensor
 
 
diff --git a/src/super_gradients/training/utils/utils.py b/src/super_gradients/training/utils/utils.py
index c04f3e04ac..ab67e09ca4 100755
--- a/src/super_gradients/training/utils/utils.py
+++ b/src/super_gradients/training/utils/utils.py
@@ -38,9 +38,12 @@ def convert_to_tensor(array, dtype=None, device=None):
     :param array: torch.tensor / Numpy array / List
     """
     if not torch.is_tensor(array):
-        array = torch.tensor(array)
-
-    return array.to(device=device, dtype=dtype)
+        if isinstance(array, np.ndarray):
+            return torch.from_numpy(array).to(device=device, dtype=dtype)
+        else:
+            return torch.tensor(array, device=device, dtype=dtype)
+    else:
+        return array.to(device=device, dtype=dtype)
 
 
 class HpmStruct:
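
A minimal sketch of how the patched convert_to_tensor is expected to behave, assuming it is imported from super_gradients.training.utils.utils as in the diff above. The ndarray branch now goes through torch.from_numpy, which wraps the existing buffer instead of copying it, and the trailing .to() is a no-op when both dtype and device are None:

    import numpy as np
    import torch

    from super_gradients.training.utils.utils import convert_to_tensor

    arr = np.ones((2, 3), dtype=np.float32)

    # ndarray branch: torch.from_numpy shares memory with arr, and
    # .to(device=None, dtype=None) returns the tensor unchanged.
    t = convert_to_tensor(arr)
    assert t.data_ptr() == torch.from_numpy(arr).data_ptr()

    # Non-ndarray inputs (e.g. lists) still go through torch.tensor,
    # which now receives dtype and device directly.
    t2 = convert_to_tensor([1.0, 2.0, 3.0], dtype=torch.float64)
    assert t2.dtype == torch.float64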
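
A similar sketch of the indexed-view call pattern that the .to_dense() change guards. per_class_loss is a hypothetical metric buffer introduced here for illustration; in a single-process run is_distributed() is False and the input is returned unchanged:

    import torch

    from super_gradients.training.utils.distributed_training_utils import (
        maybe_all_reduce_tensor_average,
    )

    per_class_loss = torch.zeros(10)  # hypothetical per-class accumulator
    per_class_loss[3] = 0.25

    # per_class_loss[3] is a 0-dim view into the buffer; this is the
    # maybe_all_reduce_tensor_average(some_vector[3]) case the patch
    # comment refers to.
    avg = maybe_all_reduce_tensor_average(per_class_loss[3])
    print(avg.item())  # 0.25 when running without DDP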