Fix bug that auto batch size doesn't consider distributed training (#2533)

* consider distributed training while searching batch size (aggregation pattern sketched below)

* update unit test

* revert gpu memory upper bound

* fix typo

* change allocated to reserved

* add unit test for distributed training

* align with pre-commit
eunwoosh authored Oct 11, 2023
1 parent b0eac19 commit 419a0f2
Showing 2 changed files with 151 additions and 10 deletions.
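
The change below makes the batch-size search aware of distributed training: each rank packs its trial result into a [cuda_oom, max_memory_reserved] tensor, rank 0 gathers and reduces them (any OOM, maximum memory), and the reduced result is broadcast back so every process settles on the same batch size. The following is a condensed, standalone sketch of that gather/broadcast pattern, runnable on CPU with two processes over the gloo backend; the script layout and names such as _worker are illustrative and not part of the commit.

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def _worker(rank: int, world_size: int) -> None:
    # gloo lets the sketch run on CPU; the patch itself uses whatever backend training uses
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29501", rank=rank, world_size=world_size)

    # pretend rank 1 hit CUDA OOM and reserved more memory
    try_result = torch.tensor([int(rank == 1), 3000 + 1000 * rank], dtype=torch.int64)

    gather_list = [torch.empty(2, dtype=torch.int64) for _ in range(world_size)] if rank == 0 else None
    dist.gather(try_result, gather_list=gather_list, dst=0)

    if rank == 0:
        stacked = torch.stack(gather_list)
        total = torch.tensor([int(torch.any(stacked[:, 0])), int(torch.max(stacked[:, 1]))], dtype=torch.int64)
    else:
        total = torch.empty(2, dtype=torch.int64)

    dist.broadcast(total, src=0)  # every rank now sees the same OOM flag and peak memory
    print(f"rank {rank}: cuda_oom={bool(total[0])}, max_memory_reserved={int(total[1])}")
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2)

In the actual patch the same aggregation happens inside BsSearchAlgo._try_batch_size, shown in the first file's hunks below.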
39 changes: 32 additions & 7 deletions src/otx/algorithms/common/adapters/torch/utils/bs_search_algo.py
@@ -6,6 +6,7 @@
from typing import Callable, Dict, Tuple

import torch
import torch.distributed as dist

from otx.algorithms.common.utils.logger import get_logger

@@ -40,7 +41,7 @@ def __init__(self, train_func: Callable[[int], None], default_bs: int, max_bs: i

    def _try_batch_size(self, bs: int) -> Tuple[bool, int]:
        cuda_oom = False
        torch.cuda.reset_max_memory_allocated(device=None)
        torch.cuda.reset_max_memory_cached(device=None)
        torch.cuda.empty_cache()

        try:
@@ -51,18 +52,42 @@ def _try_batch_size(self, bs: int) -> Tuple[bool, int]:
            else:
                raise e

        max_memory_allocated = torch.cuda.max_memory_allocated(device=None)
        max_memory_reserved = torch.cuda.max_memory_reserved(device=None)

        if dist.is_initialized():  # Aggregate all results and broadcast to all processes
            rank = dist.get_rank()
            try_result = torch.tensor([int(cuda_oom), max_memory_reserved], dtype=torch.int64).cuda()

            if rank == 0:
                try_result_arr = [torch.empty(2, dtype=torch.int64).cuda() for _ in range(dist.get_world_size())]
                dist.gather(try_result, gather_list=try_result_arr, dst=0)
            else:
                dist.gather(try_result, dst=0)

            if rank == 0:
                try_result_arr = torch.stack(try_result_arr)
                cuda_oom = torch.any(try_result_arr[:, 0])  # type: ignore
                max_memory_reserved = torch.max(try_result_arr[:, 1])  # type: ignore
                total_try_result = torch.tensor([cuda_oom, max_memory_reserved], dtype=torch.int64).cuda()
            else:
                total_try_result = torch.empty(2, dtype=torch.int64).cuda()

            dist.broadcast(total_try_result, src=0)

            cuda_oom = total_try_result[0].bool().item()
            max_memory_reserved = total_try_result[1].item()

        if not cuda_oom:
            # Because heapq only supports min heap, use negated batch size
            self._bs_try_history[bs] = max_memory_allocated
            self._bs_try_history[bs] = max_memory_reserved

        logger.debug(
            f"Adapting Batch size => bs : {bs}, CUDA_OOM : {cuda_oom}, "
            f"GPU memory usage : {max_memory_allocated / self._total_mem}%"
            f"GPU memory usage : {max_memory_reserved / self._total_mem}%"
        )
        torch.cuda.empty_cache()

        return cuda_oom, max_memory_allocated
        return cuda_oom, max_memory_reserved

    @staticmethod
    def _get_even_center_val(val1: int, val2: int) -> int:
@@ -82,10 +107,10 @@ def auto_decrease_batch_size(self) -> int:
        lowest_unavailable_bs = self._default_bs + 2

        while True:
            cuda_oom, max_memory_allocated = self._try_batch_size(current_bs)
            cuda_oom, max_memory_reserved = self._try_batch_size(current_bs)

            # If GPU memory usage is too close to limit, CUDA OOM can be raised during training
            if cuda_oom or max_memory_allocated > self._mem_upper_bound:
            if cuda_oom or max_memory_reserved > self._mem_upper_bound:
                if current_bs < lowest_unavailable_bs:
                    lowest_unavailable_bs = current_bs
                current_bs = self._get_even_center_val(current_bs, available_bs)
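
The hunks above also carry the "change allocated to reserved" bullet: the search now records torch.cuda.max_memory_reserved(), the peak memory held by PyTorch's caching allocator, instead of torch.cuda.max_memory_allocated(), the peak occupied by live tensors. Reserved memory is at least as large as allocated memory, so comparing it against _mem_upper_bound is the more conservative OOM guard. A minimal illustration of the two metrics (assumes a CUDA device is available; numbers vary by machine):

import torch

torch.cuda.reset_peak_memory_stats()
x = torch.empty(1024, 1024, device="cuda")  # roughly 4 MB of float32

allocated = torch.cuda.max_memory_allocated()  # peak bytes requested by tensors
reserved = torch.cuda.max_memory_reserved()    # peak bytes held by the caching allocator (>= allocated)
print(f"allocated={allocated}, reserved={reserved}")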
@@ -1,4 +1,7 @@
from typing import Optional, List

import pytest
import torch

from tests.test_suite.e2e_test_system import e2e_pytest_unit
from otx.algorithms.common.adapters.torch.utils import BsSearchAlgo
@@ -11,6 +14,8 @@ class TestBsSearchAlgo:
    def setup_test(self, mocker):
        self.mock_torch = mocker.patch.object(bs_search_algo, "torch")
        self.mock_torch.cuda.mem_get_info.return_value = (1, 10000)
        self.mock_dist = mocker.patch.object(bs_search_algo, "dist")
        self.mock_dist.is_initialized.return_value = False

    def test_init(self, mocker):
        BsSearchAlgo(mocker.MagicMock(), 4, 10)
@@ -35,11 +40,122 @@ def mock_train_func(batch_size):
            else:
                mem_usage = 8500 * batch_size / max_runnable_bs

            self.mock_torch.cuda.max_memory_allocated.return_value = mem_usage
            self.mock_torch.cuda.max_memory_reserved.return_value = mem_usage
            return mem_usage

        return mock_train_func

    def test_try_batch_size(self):
        mock_train_func = self.get_mock_train_func(cuda_oom_bound=10000, max_runnable_bs=80)
        bs_search_algo = BsSearchAlgo(mock_train_func, 128, 1000)
        batch_size = 40

        cuda_oom, max_memory_reserved = bs_search_algo._try_batch_size(batch_size)

        assert cuda_oom is False
        assert max_memory_reserved == mock_train_func(batch_size)
        self.mock_torch.cuda.reset_max_memory_cached.assert_called()
        self.mock_torch.cuda.empty_cache.assert_called()

    def test_try_batch_size_cuda_oom(self):
        mock_train_func = self.get_mock_train_func(cuda_oom_bound=100, max_runnable_bs=80)
        bs_search_algo = BsSearchAlgo(mock_train_func, 128, 1000)
        batch_size = 200

        cuda_oom, _ = bs_search_algo._try_batch_size(batch_size)

        assert cuda_oom is True
        self.mock_torch.cuda.reset_max_memory_cached.assert_called()
        self.mock_torch.cuda.empty_cache.assert_called()

    def _prepare_dist_test(self, broadcast_val: torch.Tensor, gather_val: Optional[List[torch.Tensor]] = None):
        self.mock_dist.is_initialized.return_value = True

        # mocking torch.distributed.broadcast
        def mock_broadcast(tensor: torch.Tensor, src: int):
            tensor.copy_(broadcast_val)

        self.mock_dist.broadcast.side_effect = mock_broadcast

        # mocking torch.distributed.gather if gather_val is given
        def mock_gather(tensor: torch.Tensor, gather_list: Optional[List[torch.Tensor]] = None, dst: int = 0):
            for i in range(len(gather_list)):
                gather_list[i].copy_(gather_val[i])

        if gather_val is not None:
            self.mock_dist.gather.side_effect = mock_gather

        # restore some real torch functions on the mocked torch module
        def mock_tensor_cuda(self, *args, **kwargs):
            return self

        torch.Tensor.cuda = mock_tensor_cuda
        self.mock_torch.tensor = torch.tensor
        self.mock_torch.int64 = torch.int64
        self.mock_torch.max = torch.max
        self.mock_torch.any = torch.any
        self.mock_torch.stack = torch.stack
        self.mock_torch.empty = torch.empty
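
A note on _prepare_dist_test above: it overwrites torch.Tensor.cuda directly, so the replacement outlives the test that installed it. A hedged alternative, assuming setup_test keeps the pytest-mock fixture around as self._mocker (it currently does not), would let the patch be undone automatically:

        # Hypothetical variant of the last block of _prepare_dist_test; pytest-mock
        # restores torch.Tensor.cuda when the test that triggered the patch finishes.
        self._mocker.patch.object(torch.Tensor, "cuda", new=lambda tensor, *args, **kwargs: tensor)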

    def test_try_batch_size_distributed_not_rank_0(self):
        self.mock_dist.get_rank.return_value = 1
        broadcasted_cuda_oom = False
        broadcasted_max_memory_reserved = 4000
        self._prepare_dist_test(
            broadcast_val=torch.tensor([broadcasted_cuda_oom, broadcasted_max_memory_reserved], dtype=torch.int64)
        )
        mock_train_func = self.get_mock_train_func(cuda_oom_bound=10000, max_runnable_bs=80)
        batch_size = 40
        bs_search_algo = BsSearchAlgo(mock_train_func, 128, 1000)
        w1_max_memory_reserved = mock_train_func(batch_size)

        cuda_oom, max_memory_reserved = bs_search_algo._try_batch_size(batch_size)

        # check dist.gather is called with [cuda_oom, max_memory_reserved] as its argument
        self.mock_dist.gather.assert_called_once()
        assert self.mock_dist.gather.call_args.args[0][0].item() == False
        assert self.mock_dist.gather.call_args.args[0][1].item() == w1_max_memory_reserved
        assert self.mock_dist.gather.call_args.kwargs["dst"] == 0
        # check dist.broadcast is called
        self.mock_dist.broadcast.assert_called_once()
        assert self.mock_dist.broadcast.call_args.kwargs["src"] == 0
        # check broadcast values are returned
        assert cuda_oom is broadcasted_cuda_oom
        assert max_memory_reserved == broadcasted_max_memory_reserved

    def test_try_batch_size_distributed_rank_0(self):
        self.mock_dist.get_rank.return_value = 0
        self.mock_dist.get_world_size.return_value = 2
        self._prepare_dist_test(
            broadcast_val=torch.tensor([True, 4000], dtype=torch.int64),
            gather_val=[
                torch.tensor([False, 3000], dtype=torch.int64),
                torch.tensor([True, 4000], dtype=torch.int64),
            ],
        )
        mock_train_func = self.get_mock_train_func(cuda_oom_bound=10000, max_runnable_bs=80)
        batch_size = 40
        bs_search_algo = BsSearchAlgo(mock_train_func, 128, 1000)
        w0_max_memory_reserved = mock_train_func(batch_size)

        cuda_oom, max_memory_reserved = bs_search_algo._try_batch_size(batch_size)

        # check dist.gather is called with [cuda_oom, max_memory_reserved] as its argument
        self.mock_dist.gather.assert_called_once()
        assert self.mock_dist.gather.call_args.args[0][0].item() == False
        assert self.mock_dist.gather.call_args.args[0][1].item() == w0_max_memory_reserved
        assert self.mock_dist.gather.call_args.kwargs["dst"] == 0
        # check that if any process hits CUDA OOM, cuda_oom is set to True and
        # max_memory_reserved is set to the maximum across processes
        self.mock_dist.broadcast.assert_called_once()
        assert self.mock_dist.broadcast.call_args.kwargs["src"] == 0
        assert self.mock_dist.broadcast.call_args.args[0][0].item() == True
        assert self.mock_dist.broadcast.call_args.args[0][1].item() == 4000
        # check proper values are returned
        assert cuda_oom is True
        assert max_memory_reserved == 4000

    def test_auto_decrease_batch_size(self):
        mock_train_func = self.get_mock_train_func(cuda_oom_bound=10000, max_runnable_bs=80)

@@ -91,7 +207,7 @@ def mock_train_func(batch_size):
                mem_usage = 9000
            else:
                mem_usage = 1000
            self.mock_torch.cuda.max_memory_allocated.return_value = mem_usage
            self.mock_torch.cuda.max_memory_reserved.return_value = mem_usage
            return mem_usage

        bs_search_algo = BsSearchAlgo(mock_train_func, 64, 1000)
@@ -108,7 +224,7 @@ def mock_train_func(batch_size):
                mem_usage = 9000
            else:
                mem_usage = 1000 + batch_size / 1000
            self.mock_torch.cuda.max_memory_allocated.return_value = mem_usage
            self.mock_torch.cuda.max_memory_reserved.return_value = mem_usage
            return mem_usage

        bs_search_algo = BsSearchAlgo(mock_train_func, 64, 1000)
