@@ -95,6 +95,9 @@ def __init__(
9595 self .gpu_cache : Optional [List [List [torch .Tensor ]]] = None
9696 self ._seq_group_metadata_cache : Dict [str , SequenceGroupMetadata ] = {}
9797
98+ # Buffers saved before sleep
99+ self ._sleep_saved_buffers : Dict [str , torch .Tensor ] = {}
100+
98101 # Torch profiler. Enabled and configured through env vars:
99102 # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
100103 if envs .VLLM_TORCH_PROFILER_DIR :
@@ -124,6 +127,15 @@ def stop_profile(self):
124127
125128 def sleep (self , level : int = 1 ) -> None :
126129 free_bytes_before_sleep = torch .cuda .mem_get_info ()[0 ]
130+
131+ # Save the buffers before level 2 sleep
132+ if level == 2 :
133+ model = self .model_runner .model
134+ self ._sleep_saved_buffers = {
135+ name : buffer .cpu ().clone ()
136+ for name , buffer in model .named_buffers ()
137+ }
138+
127139 allocator = CuMemAllocator .get_instance ()
128140 allocator .sleep (offload_tags = ("weights" , ) if level == 1 else tuple ())
129141 free_bytes_after_sleep , total = torch .cuda .mem_get_info ()
@@ -139,6 +151,14 @@ def wake_up(self, tags: Optional[list[str]] = None) -> None:
139151 allocator = CuMemAllocator .get_instance ()
140152 allocator .wake_up (tags = tags )
141153
154+ # Restore the buffers after level 2 sleep
155+ if len (self ._sleep_saved_buffers ):
156+ model = self .model_runner .model
157+ for name , buffer in model .named_buffers ():
158+ if name in self ._sleep_saved_buffers :
159+ buffer .data .copy_ (self ._sleep_saved_buffers [name ].data )
160+ self ._sleep_saved_buffers = {}
161+
142162 def init_device (self ) -> None :
143163 if self .device_config .device .type == "cuda" :
144164 # torch.distributed.all_reduce does not free the input tensor until
0 commit comments