@@ -197,26 +197,26 @@ def __init__(
197
197
self ._batch_ip2 : Optional [In ] = None
198
198
self ._context : TrainPipelineContext = context_type (version = 0 )
199
199
200
def _set_module_context(self, context: TrainPipelineContext) -> None:
    """
    Point the forward of every pipelined module at ``context``.

    Args:
        context: context object handed to each ``module.forward.set_context``.
    """
    for pipelined_module in self._pipelined_modules:
        pipelined_forward = pipelined_module.forward
        pipelined_forward.set_context(context)
203
+
200
204
def enqueue_batch(self, dataloader_iter: Iterator[In]) -> bool:
    """
    Fetch one batch via ``copy_batch_to_gpu`` and queue it with its context.

    Args:
        dataloader_iter: iterator forwarded unchanged to ``copy_batch_to_gpu``.

    Returns:
        True when a batch was appended to ``self.batches`` (and its context to
        ``self.contexts``); False when ``copy_batch_to_gpu`` produced ``None``
        for the batch, in which case neither queue is modified.
    """
    copied_batch, batch_context = self.copy_batch_to_gpu(dataloader_iter)
    got_batch = copied_batch is not None
    if got_batch:
        # Batch and context are appended together so that the same index in
        # both queues refers to the same pipeline entry.
        self.batches.append(copied_batch)
        # pyre-ignore [6]
        self.contexts.append(batch_context)
    return got_batch
208
213
209
214
def dequeue_batch(self) -> None:
    """
    Discard the oldest queued batch together with its context.

    If at least one batch is still queued afterwards, the pipelined modules
    are retargeted (via ``_set_module_context``) at the context that is now
    first in ``self.contexts``.
    """
    self.batches.popleft()
    self.contexts.popleft()
    if not self.batches:
        # Nothing left to run, so there is no context to switch to.
        return
    # update PipelineForwards context to match next forward pass
    self._set_module_context(self.contexts[0])
220
220
221
221
def fill_pipeline (self , dataloader_iter : Iterator [In ]) -> None :
222
222
# pipeline is already filled
@@ -247,6 +247,9 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
247
247
if not self .batches :
248
248
raise StopIteration
249
249
250
+ # TODO: Remove once Bulk Eval migrated (needed for bwd compat, this class only)
251
+ self ._set_module_context (self .contexts [0 ])
252
+
250
253
if self ._model .training :
251
254
with record_function ("## zero_grad ##" ):
252
255
self ._optimizer .zero_grad ()
@@ -298,8 +301,7 @@ def _init_pipelined_modules(
298
301
the splits collective in the input dist.
299
302
"""
300
303
if self ._pipelined_modules :
301
- for module in self ._pipelined_modules :
302
- module .forward .set_context (context )
304
+ self ._set_module_context (context )
303
305
self .start_sparse_data_dist (batch , context )
304
306
return
305
307
@@ -385,6 +387,7 @@ def _copy_batch_to_gpu(self, dataloader_iter: Iterator[In]) -> Optional[In]:
385
387
"""
386
388
DEPRECATED: exists for backward compatibility on TrainPipelineContext.version 0
387
389
"""
390
+ self ._set_module_context (self ._context )
388
391
batch , _ = self .copy_batch_to_gpu (dataloader_iter )
389
392
return batch
390
393
@@ -393,6 +396,7 @@ def _start_sparse_data_dist(self, batch: Optional[In]) -> None:
393
396
DEPRECATED: exists for backward compatibility
394
397
Waits for batch to finish getting copied to GPU, then starts the input dist.
395
398
"""
399
+ self ._set_module_context (self ._context )
396
400
self .start_sparse_data_dist (batch , self ._context )
397
401
398
402
def _wait_sparse_data_dist (self ) -> None :
@@ -401,7 +405,7 @@ def _wait_sparse_data_dist(self) -> None:
401
405
Waits on the input dist splits requests to get the input dist tensors requests,
402
406
and populates the context with them.
403
407
"""
404
- assert self ._context . version == 0 , "Context version == 0 is required"
408
+ self ._set_module_context ( self . _context )
405
409
with record_function ("## wait_sparse_data_dist ##" ):
406
410
with torch .cuda .stream (self ._data_dist_stream ):
407
411
self ._context .module_contexts = (
@@ -676,73 +680,67 @@ def __init__(
676
680
apply_jit = apply_jit ,
677
681
context_type = PrefetchTrainPipelineContext ,
678
682
)
683
+ self ._context = PrefetchTrainPipelineContext (version = 0 )
679
684
self ._prefetch_stream : Optional [torch .cuda .streams .Stream ] = (
680
685
(torch .cuda .Stream ()) if self ._device .type == "cuda" else None
681
686
)
682
687
self ._default_stream : Optional [torch .cuda .streams .Stream ] = (
683
688
(torch .cuda .Stream ()) if self ._device .type == "cuda" else None
684
689
)
690
+ self ._batch_ip3 : Optional [In ] = None
685
691
686
def _fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
    """
    Populate the ``_batch_i`` / ``_batch_ip1`` / ``_batch_ip2`` slots.

    Returns immediately when all three slots already hold a batch, or when
    ``_execute_all_batches`` is set and ``_batch_i`` or ``_batch_ip1`` still
    holds one. Otherwise three batches are pulled through
    ``_copy_batch_to_gpu``; the first is additionally used to initialise the
    pipelined modules and is prefetched, and the sparse data dist is started
    and waited on for the first two.

    Slot occupancy is tested with ``is not None`` rather than truthiness: the
    slots are ``Optional[In]``, and a batch object that defines ``__len__`` or
    ``__bool__`` must not be mistaken for an empty slot.

    Args:
        dataloader_iter: iterator forwarded to ``_copy_batch_to_gpu``.

    Raises:
        StopIteration: if ``_copy_batch_to_gpu`` returns ``None`` for the
            first batch.
    """
    # pipeline is already filled
    if (
        self._batch_i is not None
        and self._batch_ip1 is not None
        and self._batch_ip2 is not None
    ):
        return
    # executes last batch in pipeline
    if self._execute_all_batches and (
        self._batch_i is not None or self._batch_ip1 is not None
    ):
        return

    # batch 1
    self._batch_i = self._copy_batch_to_gpu(dataloader_iter)
    if self._batch_i is None:
        raise StopIteration

    self._init_pipelined_modules(
        self._batch_i,
        self._context,
        # pyre-ignore
        PrefetchPipelinedForward,
    )
    self._start_sparse_data_dist(self._batch_i)
    self._wait_sparse_data_dist()
    self._prefetch(self._batch_i)

    # batch 2
    self._batch_ip1 = self._copy_batch_to_gpu(dataloader_iter)
    self._start_sparse_data_dist(self._batch_ip1)
    self._wait_sparse_data_dist()

    # batch 3
    self._batch_ip2 = self._copy_batch_to_gpu(dataloader_iter)
716
722
717
723
def progress (self , dataloader_iter : Iterator [In ]) -> Out :
718
- self .fill_pipeline (dataloader_iter )
719
- if not self .batches :
720
- raise StopIteration
724
+ self ._fill_pipeline (dataloader_iter )
721
725
722
726
if self ._model .training :
723
727
with record_function ("## zero_grad ##" ):
724
728
self ._optimizer .zero_grad ()
725
729
726
730
with record_function ("## wait_for_batch ##" ):
727
- _wait_for_batch (cast (In , self .batches [ 0 ] ), self ._prefetch_stream )
731
+ _wait_for_batch (cast (In , self ._batch_i ), self ._prefetch_stream )
728
732
729
- if len (self .batches ) >= 3 :
730
- self .start_sparse_data_dist (self .batches [2 ], self .contexts [2 ])
733
+ self ._start_sparse_data_dist (self ._batch_ip2 )
731
734
732
- # batch 3
733
- self .enqueue_batch (dataloader_iter )
735
+ self ._batch_ip3 = self ._copy_batch_to_gpu (dataloader_iter )
734
736
735
737
# forward
736
738
with record_function ("## forward ##" ):
737
- losses , output = cast (
738
- Tuple [torch .Tensor , Out ], self ._model (self .batches [0 ])
739
- )
739
+ losses , output = cast (Tuple [torch .Tensor , Out ], self ._model (self ._batch_i ))
740
740
741
- if len (self .batches ) >= 2 :
742
- self ._prefetch (self .batches [1 ], self .contexts [1 ])
741
+ self ._prefetch (self ._batch_ip1 )
743
742
744
- if len (self .batches ) >= 3 :
745
- self .wait_sparse_data_dist (self .contexts [2 ])
743
+ self ._wait_sparse_data_dist ()
746
744
747
745
if self ._model .training :
748
746
# backward
@@ -753,24 +751,30 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
753
751
with record_function ("## optimizer ##" ):
754
752
self ._optimizer .step ()
755
753
756
- self .dequeue_batch ()
754
+ self ._batch_i = self ._batch_ip1
755
+ self ._batch_ip1 = self ._batch_ip2
756
+ self ._batch_ip2 = self ._batch_ip3
757
+
757
758
return output
758
759
759
- def _prefetch (self , batch : Optional [In ], context : TrainPipelineContext ) -> None :
760
+ def _prefetch (self , batch : Optional [In ]) -> None :
760
761
"""
761
762
Waits for input dist to finish, then prefetches data.
762
763
"""
763
764
if batch is None :
764
765
return
765
- with record_function (f"## sharded_module_prefetch { context .index } ##" ):
766
+ self ._context .module_input_post_prefetch .clear ()
767
+ self ._context .module_contexts_post_prefetch .clear ()
768
+
769
+ with record_function ("## sharded_module_prefetch ##" ):
766
770
with torch .cuda .stream (self ._prefetch_stream ):
767
771
batch .record_stream (torch .cuda .current_stream ())
768
772
for sharded_module in self ._pipelined_modules :
769
773
forward = sharded_module .forward
770
774
assert isinstance (forward , PrefetchPipelinedForward )
771
775
772
- assert forward ._name in context .input_dist_tensors_requests
773
- request = context .input_dist_tensors_requests [forward ._name ]
776
+ assert forward ._name in self . _context .input_dist_tensors_requests
777
+ request = self . _context .input_dist_tensors_requests [forward ._name ]
774
778
assert isinstance (request , Awaitable )
775
779
with record_function ("## wait_sparse_data_dist ##" ):
776
780
# Finish waiting on the dist_stream,
@@ -790,16 +794,16 @@ def _prefetch(self, batch: Optional[In], context: TrainPipelineContext) -> None:
790
794
data .record_stream (cur_stream )
791
795
data .record_stream (self ._default_stream )
792
796
793
- module_context = context .module_contexts [forward ._name ]
794
- module_context .record_stream (cur_stream )
795
- module_context .record_stream (self ._default_stream )
797
+ ctx = self . _context .module_contexts [forward ._name ]
798
+ ctx .record_stream (cur_stream )
799
+ ctx .record_stream (self ._default_stream )
796
800
797
801
sharded_module .prefetch (
798
802
dist_input = data , forward_stream = self ._default_stream
799
803
)
800
- context .module_input_post_prefetch [forward ._name ] = data
801
- context .module_contexts_post_prefetch [forward ._name ] = (
802
- context .module_contexts [forward ._name ]
804
+ self . _context .module_input_post_prefetch [forward ._name ] = data
805
+ self . _context .module_contexts_post_prefetch [forward ._name ] = (
806
+ self . _context .module_contexts [forward ._name ]
803
807
)
804
808
805
809
0 commit comments