Fix get_backend bug and add clear_device_cache function (#2857)
* added clear_device_cache

* set lambda: 0 for mps and cpu
NurmaU authored Jul 3, 2024
1 parent 92404fb commit 8330b37
Showing 3 changed files with 23 additions and 40 deletions.
6 changes: 3 additions & 3 deletions src/accelerate/test_utils/testing.py
@@ -66,17 +66,17 @@ def get_backend():
     elif is_cuda_available():
         return "cuda", torch.cuda.device_count(), torch.cuda.memory_allocated
     elif is_mps_available(min_version="2.0"):
-        return "mps", 1, torch.mps.current_allocated_memory()
+        return "mps", 1, torch.mps.current_allocated_memory
     elif is_mps_available():
-        return "mps", 1, 0
+        return "mps", 1, lambda: 0
     elif is_mlu_available():
         return "mlu", torch.mlu.device_count(), torch.mlu.memory_allocated
     elif is_npu_available():
         return "npu", torch.npu.device_count(), torch.npu.memory_allocated
     elif is_xpu_available():
         return "xpu", torch.xpu.device_count(), torch.xpu.memory_allocated
     else:
-        return "cpu", 1, 0
+        return "cpu", 1, lambda: 0


 torch_device, device_count, memory_allocated_func = get_backend()
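The bug fix in this file: get_backend() returns a callable as its third element (bound to memory_allocated_func below), so the MPS branch must return torch.mps.current_allocated_memory itself rather than its result, and the MPS/CPU fallbacks must return a zero-returning lambda instead of the literal 0. A minimal sketch of the call pattern, assuming nothing beyond the public torch API shown in the diff (get_backend_sketch is a hypothetical stand-in, not the real helper):

import torch

def get_backend_sketch():
    # Mirrors only the CUDA/CPU cases of the patched logic, for illustration.
    if torch.cuda.is_available():
        return "cuda", torch.cuda.device_count(), torch.cuda.memory_allocated
    return "cpu", 1, lambda: 0

torch_device, device_count, memory_allocated_func = get_backend_sketch()
# Callers invoke the third element; returning a plain 0 (the old behavior) would
# raise "TypeError: 'int' object is not callable" here.
print(memory_allocated_func())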
47 changes: 18 additions & 29 deletions src/accelerate/utils/memory.py
@@ -26,6 +26,21 @@
 from .imports import is_mlu_available, is_mps_available, is_npu_available, is_xpu_available


+def clear_device_cache():
+    gc.collect()
+
+    if is_xpu_available():
+        torch.xpu.empty_cache()
+    elif is_mlu_available():
+        torch.mlu.empty_cache()
+    elif is_npu_available():
+        torch.npu.empty_cache()
+    elif is_mps_available(min_version="2.0"):
+        torch.mps.empty_cache()
+    else:
+        torch.cuda.empty_cache()
+
+
 def release_memory(*objects):
     """
     Releases memory from `objects` by setting them to `None` and calls `gc.collect()` and `torch.cuda.empty_cache()`.
@@ -52,17 +67,7 @@ def release_memory(*objects):
     objects = list(objects)
     for i in range(len(objects)):
         objects[i] = None
-    gc.collect()
-    if is_xpu_available():
-        torch.xpu.empty_cache()
-    elif is_mlu_available():
-        torch.mlu.empty_cache()
-    elif is_npu_available():
-        torch.npu.empty_cache()
-    elif is_mps_available(min_version="2.0"):
-        torch.mps.empty_cache()
-    else:
-        torch.cuda.empty_cache()
+    clear_device_cache()
     return objects
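With the shared helper in place, release_memory keeps its documented behavior: it drops the passed-in references, then clears the device cache. A minimal usage sketch, assuming release_memory is importable from accelerate.utils.memory as shown in this diff:

import torch
from accelerate.utils.memory import release_memory

model = torch.nn.Linear(16, 16)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# Reassigning from the return value drops the caller's references, so the
# subsequent gc.collect() and clear_device_cache() can actually free memory.
model, optimizer = release_memory(model, optimizer)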


@@ -118,15 +123,7 @@ def find_executable_batch_size(function: callable = None, starting_batch_size: i

     def decorator(*args, **kwargs):
         nonlocal batch_size
-        gc.collect()
-        if is_xpu_available():
-            torch.xpu.empty_cache()
-        elif is_mlu_available():
-            torch.mlu.empty_cache()
-        elif is_npu_available():
-            torch.npu.empty_cache()
-        else:
-            torch.cuda.empty_cache()
+        clear_device_cache()
         params = list(inspect.signature(function).parameters.keys())
         # Guard against user error
         if len(params) < (len(args) + 1):
@@ -142,15 +139,7 @@ def decorator(*args, **kwargs):
                 return function(batch_size, *args, **kwargs)
             except Exception as e:
                 if should_reduce_batch_size(e):
-                    gc.collect()
-                    if is_xpu_available():
-                        torch.xpu.empty_cache()
-                    elif is_mlu_available():
-                        torch.mlu.empty_cache()
-                    elif is_npu_available():
-                        torch.npu.empty_cache()
-                    else:
-                        torch.cuda.empty_cache()
+                    clear_device_cache()
                     batch_size //= 2
                 else:
                     raise
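For context, find_executable_batch_size wraps a training function and, on failures that should_reduce_batch_size recognizes, clears the device cache and retries with a halved batch size. A minimal usage sketch, assuming the decorator is importable from accelerate.utils.memory as in this diff and that the starting_batch_size keyword matches the (truncated) signature above:

from accelerate.utils.memory import find_executable_batch_size

@find_executable_batch_size(starting_batch_size=128)
def train(batch_size):
    # The decorator injects batch_size as the first argument and halves it on
    # each retry, calling clear_device_cache() between attempts.
    print(f"training with batch size {batch_size}")

train()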
10 changes: 2 additions & 8 deletions src/accelerate/utils/modeling.py
@@ -39,6 +39,7 @@
     is_torch_xla_available,
     is_xpu_available,
 )
+from .memory import clear_device_cache
 from .offload import load_offloaded_weight, offload_weight, save_offload_index
 from .tqdm import is_tqdm_available, tqdm
 from .versions import compare_versions
@@ -456,14 +457,7 @@ def set_module_tensor_to_device(
             module.weight = module.weight.cuda(device_index)
     # clean pre and post foward hook
     if device != "cpu":
-        if is_npu_available():
-            torch.npu.empty_cache()
-        elif is_mlu_available():
-            torch.mlu.empty_cache()
-        elif is_xpu_available():
-            torch.xpu.empty_cache()
-        else:
-            torch.cuda.empty_cache()
+        clear_device_cache()

     # When handling tied weights, we update tied_params_map to keep track of the tied weights that have already been allocated on the device in
     # order to avoid duplicating memory, see above.
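One behavioral note on this last hunk: the inline branch it replaces had no MPS case, while clear_device_cache does (via is_mps_available(min_version="2.0")), so moving tensors off the CPU on Apple silicon now also empties the MPS cache. A hedged sketch of the assumed call pattern for set_module_tensor_to_device (module, tensor name, target device), guarded so it only runs when a CUDA device is present:

import torch
from accelerate.utils import set_module_tensor_to_device

model = torch.nn.Linear(4, 4)
if torch.cuda.is_available():
    # Moves the existing weight tensor to the GPU; because device != "cpu",
    # the patched code path calls clear_device_cache() afterwards.
    set_module_tensor_to_device(model, "weight", "cuda:0")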
