@@ -3517,7 +3517,6 @@ def capture_model(self) -> int:
         compilation_counter.num_gpu_runner_capture_triggers += 1
 
         start_time = time.perf_counter()
-        start_free_gpu_memory = torch.cuda.mem_get_info()[0]
 
         @contextmanager
         def freeze_gc():
@@ -3540,6 +3539,7 @@ def freeze_gc():
         # can reuse the memory pool allocated for the large shapes.
         set_cudagraph_capturing_enabled(True)
         with freeze_gc(), graph_capture(device=self.device):
+            start_free_gpu_memory = torch.cuda.mem_get_info()[0]
             cudagraph_mode = self.compilation_config.cudagraph_mode
             assert cudagraph_mode is not None
             if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
@@ -3568,6 +3568,9 @@ def freeze_gc():
                     cudagraph_runtime_mode=CUDAGraphMode.FULL,
                     uniform_decode=True)
 
+            torch.cuda.synchronize()
+            end_free_gpu_memory = torch.cuda.mem_get_info()[0]
+
         # Disable cudagraph capturing globally, so any unexpected cudagraph
         # capturing will be detected and raise an error after here.
         # Note: We don't put it into graph_capture context manager because
@@ -3576,7 +3579,6 @@ def freeze_gc():
         set_cudagraph_capturing_enabled(False)
 
         end_time = time.perf_counter()
-        end_free_gpu_memory = torch.cuda.mem_get_info()[0]
         elapsed_time = end_time - start_time
         cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory
         # This usually takes 5~20 seconds.
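The net effect of the change is to narrow the window over which free GPU memory is sampled: the start reading is now taken inside the capture context, and the end reading is taken only after a torch.cuda.synchronize(), so the reported cuda_graph_size reflects allocations made during graph capture rather than unrelated work before or after it. Below is a minimal illustrative sketch of the same before/after measurement pattern; the names measure_capture_memory and capture_fn are hypothetical and not part of the source.

# Illustrative sketch only (not vLLM code); hypothetical names.
import torch

def measure_capture_memory(capture_fn) -> int:
    """Return the GPU memory consumed by capture_fn, in bytes."""
    free_before, _ = torch.cuda.mem_get_info()
    capture_fn()
    # Synchronize so allocations still pending on the CUDA stream are
    # reflected in the second free-memory reading.
    torch.cuda.synchronize()
    free_after, _ = torch.cuda.mem_get_info()
    return free_before - free_after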