Skip to content

Commit

Permalink
align with pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
eunwoosh committed Oct 10, 2023
1 parent 0883e4f commit a72c1c0
Showing 1 changed file with 4 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -74,18 +74,21 @@ def _prepare_dist_test(self, broadcast_val: torch.Tensor, gather_val: Optional[L
# mocking torch.distributed.broadcast
def mock_broadcast(tensor: torch.Tensor, src: int):
tensor.copy_(broadcast_val)

self.mock_dist.broadcast.side_effect = mock_broadcast

# mocking torch.distributed.gather if gather_val is given
def mock_gather(tensor: torch.Tensor, gather_list: Optional[List[torch.Tensor]] = None, dst: int = 0):
for i in range(len(gather_list)):
gather_list[i].copy_(gather_val[i])

if gather_val is not None:
self.mock_dist.gather.side_effect = mock_gather

# revert some of torch function
def mock_tensor_cuda(self, *args, **kwargs):
return self

torch.Tensor.cuda = mock_tensor_cuda
self.mock_torch.tensor = torch.tensor
self.mock_torch.int64 = torch.int64
Expand Down Expand Up @@ -128,7 +131,7 @@ def test_try_batch_size_distributed_rank_0(self):
gather_val=[
torch.tensor([False, 3000], dtype=torch.int64),
torch.tensor([True, 4000], dtype=torch.int64),
]
],
)
mock_train_func = self.get_mock_train_func(cuda_oom_bound=10000, max_runnable_bs=80)
batch_size = 40
Expand Down

0 comments on commit a72c1c0

Please sign in to comment.