Skip to content

Commit

Permalink
Small improvements to utils (#1513)
Browse files Browse the repository at this point in the history
  • Loading branch information
BloodAxe authored Oct 9, 2023
1 parent 3c342bb commit 8de553f
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,8 @@ def maybe_all_reduce_tensor_average(tensor: torch.Tensor) -> torch.Tensor:
:return:
"""
if is_distributed():
tensor = distributed_all_reduce_tensor_average(tensor=tensor, n=torch.distributed.get_world_size())
# .to_dense() is required to ensure we can do maybe_all_reduce_tensor_average(some_vector[3])
tensor = distributed_all_reduce_tensor_average(tensor=tensor.to_dense(), n=torch.distributed.get_world_size())
return tensor


Expand Down
9 changes: 6 additions & 3 deletions src/super_gradients/training/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,12 @@ def convert_to_tensor(array, dtype=None, device=None):
:param array: torch.tensor / Numpy array / List
"""
if not torch.is_tensor(array):
array = torch.tensor(array)

return array.to(device=device, dtype=dtype)
if isinstance(array, np.ndarray):
return torch.from_numpy(array).to(device=device, dtype=dtype)
else:
return torch.tensor(array, device=device, dtype=dtype)
else:
return array.to(device=device, dtype=dtype)


class HpmStruct:
Expand Down

0 comments on commit 8de553f

Please sign in to comment.