Commit 25d9ccc

Update model.py
Any-Winter-4079 authored and lstein committed Sep 12, 2022
1 parent 9cdf3ac commit 25d9ccc
Showing 1 changed file with 14 additions and 10 deletions: ldm/modules/diffusionmodules/model.py

@@ -210,25 +210,29 @@ def forward(self, x):
         h_ = torch.zeros_like(k, device=q.device)

         device_type = 'mps' if q.device.type == 'mps' else 'cuda'

-        if device_type == 'mps':
-            mem_free_total = psutil.virtual_memory().available
-        else:
+        if device_type == 'cuda':
             stats = torch.cuda.memory_stats(q.device)
             mem_active = stats['active_bytes.all.current']
             mem_reserved = stats['reserved_bytes.all.current']
             mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
             mem_free_torch = mem_reserved - mem_active
             mem_free_total = mem_free_cuda + mem_free_torch

-        tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * 4
-        mem_required = tensor_size * 2.5
-        steps = 1
+            tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * 4
+            mem_required = tensor_size * 2.5
+            steps = 1

-        if mem_required > mem_free_total:
-            steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
+            if mem_required > mem_free_total:
+                steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))

-        slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
+            slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
+        else:
+            if psutil.virtual_memory().available / (1024**3) < 12:
+                slice_size = 1
+            else:
+                slice_size = min(q.shape[1], math.floor(2**30 / (q.shape[0] * q.shape[1])))

         for i in range(0, q.shape[1], slice_size):
             end = i + slice_size
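For reference, here is a self-contained sketch of the slicing heuristic as it stands after this change. It is an illustration, not code from the repository: the helper name estimate_slice_size and the demo numbers below are invented, and note that the commit itself routes every non-MPS device down the CUDA branch.

import math

import psutil
import torch

def estimate_slice_size(q: torch.Tensor, k: torch.Tensor) -> int:
    # q: (batch, q_len, channels); k: (batch, channels, k_len), so the
    # attention matrix q @ k has q.shape[0] * q.shape[1] * k.shape[2] entries.
    if q.device.type == 'cuda':
        # Free VRAM = free memory per the driver, plus memory PyTorch has
        # reserved in its caching allocator but is not actively using.
        stats = torch.cuda.memory_stats(q.device)
        mem_active = stats['active_bytes.all.current']
        mem_reserved = stats['reserved_bytes.all.current']
        mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
        mem_free_total = mem_free_cuda + (mem_reserved - mem_active)

        # float32 attention matrix (4 bytes per entry) with a 2.5x safety margin.
        tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * 4
        mem_required = tensor_size * 2.5

        steps = 1
        if mem_required > mem_free_total:
            # Smallest power of two that fits one slice in the memory budget.
            steps = 2 ** math.ceil(math.log(mem_required / mem_free_total, 2))

        # Only slice when q_len divides evenly; otherwise run in one pass.
        return q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]

    # MPS shares unified memory, so key the heuristic off system RAM.
    if psutil.virtual_memory().available / (1024 ** 3) < 12:
        return 1  # under ~12 GiB available: one query row per step
    # Otherwise cap the slice so roughly 2**30 sliced elements are in flight.
    return min(q.shape[1], math.floor(2 ** 30 / (q.shape[0] * q.shape[1])))

Plugging in illustrative numbers: for the attention block on a 64x64 feature map, q has 4096 tokens, so the matrix costs 1 * 4096 * 4096 * 4 bytes, about 64 MiB, and mem_required is about 160 MiB. If only about 100 MiB were free on CUDA, steps = 2**ceil(log2(1.6)) = 2 and the loop attends 2048 query rows at a time; on MPS with at least 12 GiB available, slice_size = min(4096, 2**30 / 4096) = 4096, i.e. no slicing.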