diff --git a/src/petals/server/backend.py b/src/petals/server/backend.py
index 4fadda276..5fa93390f 100644
--- a/src/petals/server/backend.py
+++ b/src/petals/server/backend.py
@@ -85,3 +85,13 @@ def get_pools(self) -> Sequence[PrioritizedTaskPool]:
     def get_info(self) -> Dict[str, Any]:
         """Get module parameters and stats. Used by RemoteExpert to check shapes and for DMoE orchestration."""
         return dict(super().get_info(), inference_schema=self.inference_schema)
+
+    def shutdown(self):
+        # Break the cyclic references, otherwise TransformerBackend may not be garbage-collected
+        self.forward_pool = self.backward_pool = self.inference_pool = None
+
+        # Explicitly free the GPU memory. This is not necessary at the time of writing,
+        # but may help to avoid future issues if the module is not garbage-collected for some reason
+        dummy = torch.tensor([])
+        for p in self.module.parameters():
+            p.data = dummy
diff --git a/src/petals/server/server.py b/src/petals/server/server.py
index 97b03e040..caaeb6b67 100644
--- a/src/petals/server/server.py
+++ b/src/petals/server/server.py
@@ -235,8 +235,8 @@ def run(self):
             if self.stop.wait(timeout):
                 return
 
-            if not self.module_container.handlers_alive:
-                logger.warning("One of connection handlers crashed, restarting the server")
+            if not self.module_container.is_healthy():
+                logger.warning("One of the subprocesses crashed, restarting the server")
                 break
 
             if self._should_choose_other_blocks():
@@ -252,8 +252,19 @@ def _clean_memory_and_fds(self):
         gc.collect()  # In particular, this closes unused file descriptors
 
         cur_proc = psutil.Process()
-        num_fds = [proc.num_fds() for proc in [cur_proc] + psutil.Process().children(recursive=True)]
-        logger.info(f"Cleanup complete, {sum(num_fds)} open file descriptors left")
+        num_fds = [proc.num_fds() for proc in [cur_proc] + cur_proc.children(recursive=True)]
+        logger.info(f"Cleaning up, {sum(num_fds)} open file descriptors left")
+
+        if self.device.type == "cuda":
+            torch.cuda.empty_cache()
+
+            allocated_vram = torch.cuda.memory_allocated(self.device)
+            reserved_vram = torch.cuda.memory_reserved(self.device)
+            gib = 1024**3
+            logger.info(
+                f"Cleaning up, {allocated_vram / gib:.1f} GiB of allocated memory "
+                f"and {reserved_vram / gib:.1f} GiB of reserved memory left"
+            )
 
     def _choose_blocks(self) -> List[int]:
         if self.strict_block_indices is not None:
@@ -470,9 +481,10 @@ def ready(self) -> mp.synchronize.Event:
         """
         return self.runtime.ready  # mp.Event that is true if self is ready to process batches
 
-    @property
-    def handlers_alive(self) -> bool:
-        return all(handler.is_alive() for handler in self.conn_handlers)
+    def is_healthy(self) -> bool:
+        return all(handler.is_alive() for handler in self.conn_handlers) and all(
+            pool.is_alive() for pool in self.runtime.pools
+        )
 
     def shutdown(self):
         """
@@ -510,6 +522,10 @@ def shutdown(self):
 
         logger.debug(f"Shutting down runtime")
         self.runtime.shutdown()
 
+        logger.debug("Shutting down backends")
+        for backend in self.module_backends.values():
+            backend.shutdown()
+
         logger.info("Module container shut down successfully")
 
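
Note on the VRAM-freeing trick in TransformerBackend.shutdown(): a minimal, self-contained sketch (not part of this PR; assumes a CUDA device is available) showing that reassigning p.data drops the last reference to a parameter's CUDA storage, so the memory is released even while the module object itself is still referenced elsewhere.

import torch

if torch.cuda.is_available():
    module = torch.nn.Linear(4096, 4096, device="cuda")
    print(f"{torch.cuda.memory_allocated() / 1024**2:.0f} MiB allocated")  # ~64 MiB for the weight

    dummy = torch.tensor([])  # empty CPU tensor, same trick as TransformerBackend.shutdown()
    for p in module.parameters():
        p.data = dummy  # nothing references the CUDA storage anymore, so it is freed

    print(f"{torch.cuda.memory_allocated() / 1024**2:.0f} MiB allocated")  # ~0 MiB
    torch.cuda.empty_cache()  # additionally return the cached blocks to the driver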
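
Note on the is_healthy() change: the watchdog loop in Server.run() previously restarted the server only when a connection handler died; it now also checks the runtime's task pools. A hedged sketch of that wait-with-timeout watchdog pattern (worker names and the restart action here are illustrative, not the actual petals API):

import multiprocessing as mp
import time

def worker_loop():
    time.sleep(3)  # stand-in for a connection handler or task pool process

if __name__ == "__main__":
    stop = mp.Event()
    workers = [mp.Process(target=worker_loop, daemon=True) for _ in range(2)]
    for w in workers:
        w.start()

    while True:
        if stop.wait(timeout=1.0):
            break  # graceful shutdown was requested
        if not all(w.is_alive() for w in workers):
            # mirrors module_container.is_healthy(): a dead connection handler
            # *or* a dead task pool should trigger a restart
            print("One of the subprocesses crashed, restarting the server")
            break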