revert gpu memory upper bound
eunwoosh committed Oct 6, 2023
1 parent f7e77d3 commit 72329ac
Showing 2 changed files with 10 additions and 10 deletions.
File 1 of 2 (batch size search algorithm):
@@ -36,8 +36,8 @@ def __init__(self, train_func: Callable[[int], None], default_bs: int, max_bs: i
         self._max_bs = max_bs
         self._bs_try_history: Dict[int, int] = {}
         _, self._total_mem = torch.cuda.mem_get_info()
-        self._mem_lower_bound = 0.85 * self._total_mem
-        self._mem_upper_bound = 0.9 * self._total_mem
+        self._mem_lower_bound = 0.8 * self._total_mem
+        self._mem_upper_bound = 0.85 * self._total_mem
 
     def _try_batch_size(self, bs: int) -> Tuple[bool, int]:
         cuda_oom = False
@@ -75,7 +75,7 @@ def _try_batch_size(self, bs: int) -> Tuple[bool, int]:
             dist.broadcast(total_try_result, src=0)
 
             cuda_oom = total_try_result[0].bool().item()
-            max_memory_allocated = total_try_result[1].to(torch.int64).item()
+            max_memory_allocated = total_try_result[1].item()
 
         if not cuda_oom:
             # Because heapq only supports min heap, use negatized batch size
@@ -163,7 +163,7 @@ def find_big_enough_batch_size(self, drop_last: bool = False) -> int:
             return self._default_bs
 
         # estimate batch size using equation
-        estimation_pct = 0.87
+        estimation_pct = 0.82
         while True:
            estimated_bs = self._estimate_batch_size(estimation_pct)
            if estimated_bs in self._bs_try_history:
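For context: the two bounds above define the target window of GPU memory usage for the adaptive batch size search (now 80-85% of total memory instead of 85-90%), and estimation_pct is the fraction of total memory the final estimate aims for. Below is a minimal sketch of how such an estimate can be computed, assuming memory usage grows roughly linearly with batch size; the helper name and signature are illustrative and not the repository's actual _estimate_batch_size.

# Illustrative only: a linear-scaling estimate toward a memory target.
# Names and signature are hypothetical, not the repository's code.
def estimate_batch_size(tried_bs: int, mem_used: int, total_mem: int, target_pct: float = 0.82) -> int:
    """Scale a previously tried batch size toward target_pct of total memory."""
    target_mem = target_pct * total_mem  # e.g. 82% of total GPU memory
    return max(1, int(tried_bs * target_mem / mem_used))

# Example: batch size 32 used 6 GiB of a 16 GiB GPU,
# so 32 * (0.82 * 16) / 6 ~= 69 should land near 82% usage.
print(estimate_batch_size(32, 6 * 1024**3, 16 * 1024**3))  # 69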
File 2 of 2 (unit tests):
@@ -31,9 +31,9 @@ def mock_train_func(batch_size):
                mem_usage = 10000
                raise RuntimeError("CUDA out of memory.")
            elif batch_size > max_runnable_bs:
-                mem_usage = 9000 + 1000 * batch_size / (cuda_oom_bound - max_runnable_bs)
+                mem_usage = 8500 + 1500 * batch_size / (cuda_oom_bound - max_runnable_bs)
            else:
-                mem_usage = 9000 * batch_size / max_runnable_bs
+                mem_usage = 8500 * batch_size / max_runnable_bs
 
            self.mock_torch.cuda.max_memory_reserved.return_value = mem_usage
            return mem_usage
@@ -71,7 +71,7 @@ def test_find_big_enough_batch_size(self, max_runnable_bs, max_bs, expected_bs):
        adapted_bs = bs_search_algo.find_big_enough_batch_size()
 
        if expected_bs is None:
-            assert 8000 <= mock_train_func(adapted_bs) <= 9000
+            assert 7500 <= mock_train_func(adapted_bs) <= 8500
        else:
            assert adapted_bs == expected_bs
 
@@ -88,7 +88,7 @@ def mock_train_func(batch_size):
                mem_usage = 10000
                raise RuntimeError("CUDA out of memory.")
            elif batch_size > 100:
-                mem_usage = 9500
+                mem_usage = 9000
            else:
                mem_usage = 1000
            self.mock_torch.cuda.max_memory_reserved.return_value = mem_usage
@@ -105,7 +105,7 @@ def mock_train_func(batch_size):
                mem_usage = 10000
                raise RuntimeError("CUDA out of memory.")
            elif batch_size > 100:
-                mem_usage = 9500
+                mem_usage = 9000
            else:
                mem_usage = 1000 + batch_size / 1000
            self.mock_torch.cuda.max_memory_reserved.return_value = mem_usage
@@ -114,7 +114,7 @@ def mock_train_func(batch_size):
        bs_search_algo = BsSearchAlgo(mock_train_func, 64, 1000)
        adapted_bs = bs_search_algo.find_big_enough_batch_size()
 
-        assert mock_train_func(adapted_bs) <= 9000
+        assert mock_train_func(adapted_bs) <= 8500
 
    def test_find_big_enough_batch_size_drop_last(self):
        mock_train_func = self.get_mock_train_func(cuda_oom_bound=10000, max_runnable_bs=180)
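The test constants track the reverted bounds: the mocked memory usage rises linearly until max_runnable_bs, where it reaches 8500 (85% of the 10000-unit budget), so a successful search should land in the 7500-8500 band the assertions check. A rough, self-contained sketch of that memory model follows, with hypothetical names and a made-up max_runnable_bs of 80; it is not the test file's actual mock.

# Hypothetical stand-in for the test's mocked training function; the
# 10000-unit budget mirrors the test, the other numbers are illustrative.
def fake_memory_usage(batch_size: int, max_runnable_bs: int = 80, cuda_oom_bound: int = 10000) -> int:
    """Memory grows linearly up to max_runnable_bs, then spikes toward OOM."""
    if batch_size >= cuda_oom_bound:
        raise RuntimeError("CUDA out of memory.")
    if batch_size > max_runnable_bs:
        return int(8500 + 1500 * batch_size / (cuda_oom_bound - max_runnable_bs))
    # 8500 (85% of the budget) is reached exactly at max_runnable_bs.
    return int(8500 * batch_size / max_runnable_bs)

# A batch size just under max_runnable_bs sits inside the asserted band.
assert 7500 <= fake_memory_usage(75) <= 8500  # 75/80 * 8500 = 7968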
