Improve GPU VRAM cleanup during rebalancing
borzunov committed Dec 12, 2022
1 parent 1d31dd4 commit b251516
Showing 1 changed file with 22 additions and 6 deletions.
src/petals/server/server.py
@@ -250,10 +250,20 @@ def run(self):
     def _clean_memory_and_fds(self):
         del self.module_container
         gc.collect()  # In particular, this closes unused file descriptors
 
         cur_proc = psutil.Process()
-        num_fds = [proc.num_fds() for proc in [cur_proc] + psutil.Process().children(recursive=True)]
-        logger.info(f"Cleanup complete, {sum(num_fds)} open file descriptors left")
+        num_fds = [proc.num_fds() for proc in [cur_proc] + cur_proc.children(recursive=True)]
+        logger.info(f"Cleaning up, left {sum(num_fds)} open file descriptors")
+
+        if self.device.type == "cuda":
+            torch.cuda.empty_cache()
+
+            allocated_vram = torch.cuda.memory_allocated(self.device)
+            reserved_vram = torch.cuda.memory_reserved(self.device)
+            gib = 1024**3
+            logger.info(
+                f"Cleaning up, left {allocated_vram / gib:.1f} GiB allocated memory, "
+                f"{reserved_vram / gib:.1f} GiB reserved memory"
+            )
 
     def _choose_blocks(self) -> List[int]:
         if self.strict_block_indices is not None:
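For context, the following standalone sketch (not part of this commit; the 1 GiB tensor is an arbitrary illustration) shows the two measurements logged above: counting open file descriptors across the process tree with psutil, and the difference between allocated and reserved CUDA memory that torch.cuda.empty_cache() acts on.

import psutil
import torch

# Count open file descriptors for this process and all of its children,
# as _clean_memory_and_fds() does after gc.collect(). num_fds() is Unix-only.
cur_proc = psutil.Process()
num_fds = [proc.num_fds() for proc in [cur_proc] + cur_proc.children(recursive=True)]
print(f"{sum(num_fds)} open file descriptors")

# PyTorch's caching allocator keeps freed blocks reserved for reuse.
# empty_cache() returns unused reserved blocks to the driver, but cannot
# free memory that is still referenced by live tensors.
if torch.cuda.is_available():
    device = torch.device("cuda")
    gib = 1024**3
    x = torch.empty(gib, dtype=torch.uint8, device=device)  # ~1 GiB tensor
    del x  # its block is now cached: reserved, but no longer allocated
    print(f"allocated: {torch.cuda.memory_allocated(device) / gib:.1f} GiB")
    print(f"reserved:  {torch.cuda.memory_reserved(device) / gib:.1f} GiB")
    torch.cuda.empty_cache()  # hand cached blocks back to the CUDA driver
    print(f"reserved:  {torch.cuda.memory_reserved(device) / gib:.1f} GiB")

This is why the new log line reports both numbers: after empty_cache(), reserved memory should drop toward the allocated figure, and anything still allocated points at tensors that survived the cleanup.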
@@ -471,9 +481,8 @@ def ready(self) -> mp.synchronize.Event:
         return self.runtime.ready  # mp.Event that is true if self is ready to process batches
 
     def is_healthy(self) -> bool:
-        return (
-            all(handler.is_alive() for handler in self.conn_handlers) and
-            all(pool.is_alive() for pool in self.runtime.pools)
+        return all(handler.is_alive() for handler in self.conn_handlers) and all(
+            pool.is_alive() for pool in self.runtime.pools
         )
 
     def shutdown(self):
@@ -512,6 +521,13 @@ def shutdown(self):
         logger.debug(f"Shutting down runtime")
         self.runtime.shutdown()
 
+        logger.debug("Cleaning up memory")
+        # Necessary since links to `backend.module` may be left somewhere, so it is not garbage-collected properly
+        dummy = torch.tensor([])
+        for backend in self.module_backends.values():
+            for p in backend.module.parameters():
+                p.data = dummy
+
         logger.info("Module container shut down successfully")
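The cleanup block added above works around a subtlety: del self.module_container frees the weights only if nothing else still references backend.module. Re-pointing each parameter's .data at an empty tensor drops the last reference to the large weight storages even if the module object itself lives on. A minimal standalone sketch of the same idea (not part of the commit; the Linear layer size is arbitrary):

import gc

import torch
from torch import nn

module = nn.Linear(4096, 4096)  # ~64 MiB of fp32 weights
stray_reference = module  # simulates a leftover link to `backend.module`

dummy = torch.tensor([])
for p in module.parameters():
    p.data = dummy  # the original weight storage is now unreferenced

del module
gc.collect()  # `stray_reference` is still alive, but its weights are freed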

