Fix OOMs during server rebalancing #150

Merged 3 commits on Dec 12, 2022
10 changes: 10 additions & 0 deletions src/petals/server/backend.py
@@ -85,3 +85,13 @@ def get_pools(self) -> Sequence[PrioritizedTaskPool]:
     def get_info(self) -> Dict[str, Any]:
         """Get module parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration."""
         return dict(super().get_info(), inference_schema=self.inference_schema)
+
+    def shutdown(self):
+        # Break the cyclic references, otherwise TransformerBackend may not be garbage-collected
+        self.forward_pool = self.backward_pool = self.inference_pool = None
+
+        # Explicitly free the GPU memory. This is not necessary at the time this code is written,
+        # but it may help to avoid future issues if the module is not garbage-collected for some reason
+        dummy = torch.tensor([])
+        for p in self.module.parameters():
+            p.data = dummy
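
The new TransformerBackend.shutdown() relies on two effects: clearing the pool attributes breaks the reference cycle so the backend can be garbage-collected, and re-pointing every parameter's .data at an empty tensor drops the last reference to the CUDA storage even if the module object itself lingers somewhere. A minimal standalone sketch of the second effect (illustrative only, not part of the PR; the layer size is arbitrary):

import torch

if torch.cuda.is_available():
    module = torch.nn.Linear(4096, 4096).cuda()
    print(torch.cuda.memory_allocated() // 2**20, "MiB allocated")  # ~64 MiB of fp32 weights

    dummy = torch.tensor([])  # empty CPU tensor shared by all parameters
    for p in module.parameters():
        p.data = dummy  # drops the last reference to the CUDA storage

    torch.cuda.empty_cache()  # return the freed, cached blocks to the driver
    print(torch.cuda.memory_allocated() // 2**20, "MiB allocated")  # ~0 MiB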
30 changes: 23 additions & 7 deletions src/petals/server/server.py
@@ -235,8 +235,8 @@ def run(self):
             if self.stop.wait(timeout):
                 return

-            if not self.module_container.handlers_alive:
-                logger.warning("One of connection handlers crashed, restarting the server")
+            if not self.module_container.is_healthy():
+                logger.warning("One of subprocesses crashed, restarting the server")
                 break

             if self._should_choose_other_blocks():
@@ -252,8 +252,19 @@ def _clean_memory_and_fds(self):
         gc.collect()  # In particular, this closes unused file descriptors

         cur_proc = psutil.Process()
-        num_fds = [proc.num_fds() for proc in [cur_proc] + psutil.Process().children(recursive=True)]
-        logger.info(f"Cleanup complete, {sum(num_fds)} open file descriptors left")
+        num_fds = [proc.num_fds() for proc in [cur_proc] + cur_proc.children(recursive=True)]
+        logger.info(f"Cleaning up, left {sum(num_fds)} open file descriptors")
+
+        if self.device.type == "cuda":
+            torch.cuda.empty_cache()
+
+            allocated_vram = torch.cuda.memory_allocated(self.device)
+            reserved_vram = torch.cuda.memory_reserved(self.device)
+            gib = 1024**3
+            logger.info(
+                f"Cleaning up, left {allocated_vram / gib:.1f} GiB allocated memory, "
+                f"{reserved_vram / gib:.1f} GiB reserved memory"
+            )

     def _choose_blocks(self) -> List[int]:
         if self.strict_block_indices is not None:
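
The two numbers logged above are different things: torch.cuda.memory_allocated() counts memory backing live tensors, while torch.cuda.memory_reserved() also includes blocks the caching allocator keeps for reuse, which only torch.cuda.empty_cache() hands back to the driver. A self-contained sketch of the same cleanup-and-report pattern (the function name is illustrative, not from the PR):

import gc
import torch

def report_cuda_memory(device: torch.device) -> None:
    gc.collect()  # drop unreachable Python objects and the tensors they hold
    if device.type != "cuda":
        return
    torch.cuda.empty_cache()  # return unused cached blocks to the driver
    gib = 1024**3
    allocated = torch.cuda.memory_allocated(device) / gib  # live tensor storage
    reserved = torch.cuda.memory_reserved(device) / gib    # caching allocator's pool
    print(f"{allocated:.1f} GiB allocated, {reserved:.1f} GiB reserved")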
@@ -470,9 +481,10 @@ def ready(self) -> mp.synchronize.Event:
         """
         return self.runtime.ready  # mp.Event that is true if self is ready to process batches

-    @property
-    def handlers_alive(self) -> bool:
-        return all(handler.is_alive() for handler in self.conn_handlers)
+    def is_healthy(self) -> bool:
+        return all(handler.is_alive() for handler in self.conn_handlers) and all(
+            pool.is_alive() for pool in self.runtime.pools
+        )

     def shutdown(self):
         """
@@ -510,6 +522,10 @@ def shutdown(self):
         logger.debug(f"Shutting down runtime")
         self.runtime.shutdown()

+        logger.debug("Shutting down backends")
+        for backend in self.module_backends.values():
+            backend.shutdown()
+
         logger.info("Module container shut down successfully")

