diff --git a/src/otx/algorithms/common/adapters/torch/utils/bs_search_algo.py b/src/otx/algorithms/common/adapters/torch/utils/bs_search_algo.py
index 0e8b7343ac6..5b1457c6ede 100644
--- a/src/otx/algorithms/common/adapters/torch/utils/bs_search_algo.py
+++ b/src/otx/algorithms/common/adapters/torch/utils/bs_search_algo.py
@@ -6,6 +6,7 @@
 from typing import Callable, Dict, Tuple

 import torch
+import torch.distributed as dist

 from otx.algorithms.common.utils.logger import get_logger
@@ -40,7 +41,7 @@ def __init__(self, train_func: Callable[[int], None], default_bs: int, max_bs: i

     def _try_batch_size(self, bs: int) -> Tuple[bool, int]:
         cuda_oom = False
-        torch.cuda.reset_max_memory_allocated(device=None)
+        torch.cuda.reset_max_memory_cached(device=None)
         torch.cuda.empty_cache()

         try:
@@ -51,18 +52,42 @@ def _try_batch_size(self, bs: int) -> Tuple[bool, int]:
             else:
                 raise e

-        max_memory_allocated = torch.cuda.max_memory_allocated(device=None)
+        max_memory_reserved = torch.cuda.max_memory_reserved(device=None)
+
+        if dist.is_initialized():  # Aggregate all results and broadcast to all processes
+            rank = dist.get_rank()
+            try_result = torch.tensor([int(cuda_oom), max_memory_reserved], dtype=torch.int64).cuda()
+
+            if rank == 0:
+                try_result_arr = [torch.empty(2, dtype=torch.int64).cuda() for _ in range(dist.get_world_size())]
+                dist.gather(try_result, gather_list=try_result_arr, dst=0)
+            else:
+                dist.gather(try_result, dst=0)
+
+            if rank == 0:
+                try_result_arr = torch.stack(try_result_arr)
+                cuda_oom = torch.any(try_result_arr[:, 0])  # type: ignore
+                max_memory_reserved = torch.max(try_result_arr[:, 1])  # type: ignore
+                total_try_result = torch.tensor([cuda_oom, max_memory_reserved], dtype=torch.int64).cuda()
+            else:
+                total_try_result = torch.empty(2, dtype=torch.int64).cuda()
+
+            dist.broadcast(total_try_result, src=0)
+
+            cuda_oom = total_try_result[0].bool().item()
+            max_memory_reserved = total_try_result[1].item()
+
         if not cuda_oom:
             # Because heapq only supports min heap, use negatized batch size
-            self._bs_try_history[bs] = max_memory_allocated
+            self._bs_try_history[bs] = max_memory_reserved

         logger.debug(
             f"Adapting Batch size => bs : {bs}, CUDA_OOM : {cuda_oom}, "
-            f"GPU memory usage : {max_memory_allocated / self._total_mem}%"
+            f"GPU memory usage : {max_memory_reserved / self._total_mem}%"
         )
         torch.cuda.empty_cache()

-        return cuda_oom, max_memory_allocated
+        return cuda_oom, max_memory_reserved

     @staticmethod
     def _get_even_center_val(val1: int, val2: int) -> int:
@@ -82,10 +107,10 @@ def auto_decrease_batch_size(self) -> int:
         lowest_unavailable_bs = self._default_bs + 2

         while True:
-            cuda_oom, max_memory_allocated = self._try_batch_size(current_bs)
+            cuda_oom, max_memory_reserved = self._try_batch_size(current_bs)

             # If GPU memory usage is too close to limit, CUDA OOM can be raised during training
-            if cuda_oom or max_memory_allocated > self._mem_upper_bound:
+            if cuda_oom or max_memory_reserved > self._mem_upper_bound:
                 if current_bs < lowest_unavailable_bs:
                     lowest_unavailable_bs = current_bs
                 current_bs = self._get_even_center_val(current_bs, available_bs)
diff --git a/tests/unit/algorithms/common/adapters/torch/utils/test_bs_search_algo.py b/tests/unit/algorithms/common/adapters/torch/utils/test_bs_search_algo.py
index d0649dc29bf..a347968dc5e 100644
--- a/tests/unit/algorithms/common/adapters/torch/utils/test_bs_search_algo.py
+++ b/tests/unit/algorithms/common/adapters/torch/utils/test_bs_search_algo.py
@@ -1,4 +1,7 @@
+from typing import Optional, List
+
 import pytest
+import torch

 from tests.test_suite.e2e_test_system import e2e_pytest_unit
 from otx.algorithms.common.adapters.torch.utils import BsSearchAlgo
@@ -11,6 +14,8 @@ class TestBsSearchAlgo:
     def setup_test(self, mocker):
         self.mock_torch = mocker.patch.object(bs_search_algo, "torch")
         self.mock_torch.cuda.mem_get_info.return_value = (1, 10000)
+        self.mock_dist = mocker.patch.object(bs_search_algo, "dist")
+        self.mock_dist.is_initialized.return_value = False

     def test_init(self, mocker):
         BsSearchAlgo(mocker.MagicMock(), 4, 10)
@@ -35,11 +40,122 @@ def mock_train_func(batch_size):
             else:
                 mem_usage = 8500 * batch_size / max_runnable_bs

-            self.mock_torch.cuda.max_memory_allocated.return_value = mem_usage
+            self.mock_torch.cuda.max_memory_reserved.return_value = mem_usage
             return mem_usage

         return mock_train_func

+    def test_try_batch_size(self):
+        mock_train_func = self.get_mock_train_func(cuda_oom_bound=10000, max_runnable_bs=80)
+        bs_search_algo = BsSearchAlgo(mock_train_func, 128, 1000)
+        batch_size = 40
+
+        cuda_oom, max_memory_reserved = bs_search_algo._try_batch_size(batch_size)
+
+        assert cuda_oom is False
+        assert max_memory_reserved == mock_train_func(batch_size)
+        self.mock_torch.cuda.reset_max_memory_cached.assert_called()
+        self.mock_torch.cuda.empty_cache.assert_called()
+
+    def test_try_batch_size_cuda_oom(self):
+        mock_train_func = self.get_mock_train_func(cuda_oom_bound=100, max_runnable_bs=80)
+        bs_search_algo = BsSearchAlgo(mock_train_func, 128, 1000)
+        batch_size = 200
+
+        cuda_oom, _ = bs_search_algo._try_batch_size(batch_size)
+
+        assert cuda_oom is True
+        self.mock_torch.cuda.reset_max_memory_cached.assert_called()
+        self.mock_torch.cuda.empty_cache.assert_called()
+
+    def _prepare_dist_test(self, broadcast_val: torch.Tensor, gather_val: Optional[List[torch.Tensor]] = None):
+        self.mock_dist.is_initialized.return_value = True
+
+        # mock torch.distributed.broadcast
+        def mock_broadcast(tensor: torch.Tensor, src: int):
+            tensor.copy_(broadcast_val)
+
+        self.mock_dist.broadcast.side_effect = mock_broadcast
+
+        # mock torch.distributed.gather if gather_val is given
+        def mock_gather(tensor: torch.Tensor, gather_list: Optional[List[torch.Tensor]] = None, dst: int = 0):
+            for i in range(len(gather_list)):
+                gather_list[i].copy_(gather_val[i])
+
+        if gather_val is not None:
+            self.mock_dist.gather.side_effect = mock_gather
+
+        # restore some of the real torch functions needed by _try_batch_size
+        def mock_tensor_cuda(self, *args, **kwargs):
+            return self
+
+        torch.Tensor.cuda = mock_tensor_cuda
+        self.mock_torch.tensor = torch.tensor
+        self.mock_torch.int64 = torch.int64
+        self.mock_torch.max = torch.max
+        self.mock_torch.any = torch.any
+        self.mock_torch.stack = torch.stack
+        self.mock_torch.empty = torch.empty
+
+    def test_try_batch_size_distributed_not_rank_0(self):
+        self.mock_dist.get_rank.return_value = 1
+        broadcasted_cuda_oom = False
+        broadcasted_max_memory_reserved = 4000
+        self._prepare_dist_test(
+            broadcast_val=torch.tensor([broadcasted_cuda_oom, broadcasted_max_memory_reserved], dtype=torch.int64)
+        )
+        mock_train_func = self.get_mock_train_func(cuda_oom_bound=10000, max_runnable_bs=80)
+        batch_size = 40
+        bs_search_algo = BsSearchAlgo(mock_train_func, 128, 1000)
+        w1_max_memory_reserved = mock_train_func(batch_size)
+
+        cuda_oom, max_memory_reserved = bs_search_algo._try_batch_size(batch_size)
+
+        # check dist.gather is called and gets [cuda_oom, max_memory_reserved] as arguments
+        self.mock_dist.gather.assert_called_once()
+        assert self.mock_dist.gather.call_args.args[0][0].item() == False
+        assert self.mock_dist.gather.call_args.args[0][1].item() == w1_max_memory_reserved
+        assert self.mock_dist.gather.call_args.kwargs["dst"] == 0
+        # check dist.broadcast is called
+        self.mock_dist.broadcast.assert_called_once()
+        assert self.mock_dist.broadcast.call_args.kwargs["src"] == 0
+        # check the broadcasted values are returned
+        assert cuda_oom is broadcasted_cuda_oom
+        assert max_memory_reserved == broadcasted_max_memory_reserved
+
+    def test_try_batch_size_distributed_rank_0(self):
+        self.mock_dist.get_rank.return_value = 0
+        self.mock_dist.get_world_size.return_value = 2
+        self._prepare_dist_test(
+            broadcast_val=torch.tensor([True, 4000], dtype=torch.int64),
+            gather_val=[
+                torch.tensor([False, 3000], dtype=torch.int64),
+                torch.tensor([True, 4000], dtype=torch.int64),
+            ],
+        )
+        mock_train_func = self.get_mock_train_func(cuda_oom_bound=10000, max_runnable_bs=80)
+        batch_size = 40
+        bs_search_algo = BsSearchAlgo(mock_train_func, 128, 1000)
+        w0_max_memory_reserved = mock_train_func(batch_size)
+
+        cuda_oom, max_memory_reserved = bs_search_algo._try_batch_size(batch_size)
+
+        # check dist.gather is called and gets [cuda_oom, max_memory_reserved] as arguments
+        self.mock_dist.gather.assert_called_once()
+        assert self.mock_dist.gather.call_args.args[0][0].item() == False
+        assert self.mock_dist.gather.call_args.args[0][1].item() == w0_max_memory_reserved
+        assert self.mock_dist.gather.call_args.kwargs["dst"] == 0
+        # check that if any process hits CUDA OOM, cuda_oom becomes True and
+        # max_memory_reserved becomes the maximum value across processes
+        self.mock_dist.broadcast.assert_called_once()
+        assert self.mock_dist.broadcast.call_args.kwargs["src"] == 0
+        assert self.mock_dist.broadcast.call_args.args[0][0].item() == True
+        assert self.mock_dist.broadcast.call_args.args[0][1].item() == 4000
+        # check proper values are returned
+        assert cuda_oom is True
+        assert max_memory_reserved == 4000
+
     def test_auto_decrease_batch_size(self):
         mock_train_func = self.get_mock_train_func(cuda_oom_bound=10000, max_runnable_bs=80)
@@ -91,7 +207,7 @@ def mock_train_func(batch_size):
                 mem_usage = 9000
             else:
                 mem_usage = 1000
-            self.mock_torch.cuda.max_memory_allocated.return_value = mem_usage
+            self.mock_torch.cuda.max_memory_reserved.return_value = mem_usage
             return mem_usage

         bs_search_algo = BsSearchAlgo(mock_train_func, 64, 1000)
@@ -108,7 +224,7 @@ def mock_train_func(batch_size):
                 mem_usage = 9000
             else:
                 mem_usage = 1000 + batch_size / 1000
-            self.mock_torch.cuda.max_memory_allocated.return_value = mem_usage
+            self.mock_torch.cuda.max_memory_reserved.return_value = mem_usage
             return mem_usage

         bs_search_algo = BsSearchAlgo(mock_train_func, 64, 1000)
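Note on the aggregation added to `_try_batch_size` above: when `torch.distributed` is initialized, every rank packs `(cuda_oom, max_memory_reserved)` into a tensor, rank 0 gathers and combines them (any OOM wins, the largest reservation wins), and the combined result is broadcast back so all ranks take the same branch of the batch-size search. The snippet below is a minimal, self-contained sketch of that gather/reduce/broadcast pattern, not the repository code; the gloo/CPU setup, the port number, and the helper names (`aggregate_try_result`, `_worker`) are illustrative assumptions.

```python
"""Sketch of the rank-0 gather + broadcast aggregation used in _try_batch_size."""
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def aggregate_try_result(cuda_oom: bool, max_memory_reserved: int):
    """Gather (cuda_oom, max_memory_reserved) from every rank on rank 0,
    combine them (any OOM -> OOM, take the max memory), then broadcast."""
    rank = dist.get_rank()
    try_result = torch.tensor([int(cuda_oom), max_memory_reserved], dtype=torch.int64)

    if rank == 0:
        gathered = [torch.empty(2, dtype=torch.int64) for _ in range(dist.get_world_size())]
        dist.gather(try_result, gather_list=gathered, dst=0)
        stacked = torch.stack(gathered)
        total = torch.tensor(
            [int(torch.any(stacked[:, 0])), int(torch.max(stacked[:, 1]))], dtype=torch.int64
        )
    else:
        dist.gather(try_result, dst=0)
        total = torch.empty(2, dtype=torch.int64)

    # Every rank ends up with the same aggregated result.
    dist.broadcast(total, src=0)
    return bool(total[0].item()), int(total[1].item())


def _worker(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"  # assumed free port for the demo
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Pretend rank 1 hit CUDA OOM and reserved more memory than rank 0.
    cuda_oom, mem = (False, 3000) if rank == 0 else (True, 4000)
    print(rank, aggregate_try_result(cuda_oom, mem))  # both ranks print (True, 4000)

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2)
```

This mirrors what the unit tests above assert: the per-rank tensor handed to `dist.gather`, the `src=0` broadcast, and the rule that one OOM on any rank marks the batch size as failed for all ranks.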