Commit abca58c

dstaay-fb authored and facebook-github-bot committed
Use legacy logic + add to benchmark to avoid regression (#1969)
Summary:
Pull Request resolved: #1969

PrefetchTrainPipelineSparseDist: switch to the legacy TrainPipeline API and add the class to the benchmark to guard against regressions. The newer internals will be refactored later, on the assumption that memory use stays neutral or improves.

Reviewed By: henrylhtsang

Differential Revision: D57143337

fbshipit-source-id: 37df010a3e2fe16365b1190367869016ab386f72
1 parent: 0240073 · commit: abca58c
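For orientation, a minimal sketch of how a pipeline such as PrefetchTrainPipelineSparseDist is driven; the setup of sharded_model, optimizer, device, and dataloader is assumed and elided, and additional constructor arguments may apply depending on the TorchRec version:

from torchrec.distributed.train_pipeline.train_pipelines import (
    PrefetchTrainPipelineSparseDist,
)

# Minimal sketch; `sharded_model`, `optimizer`, `device`, and `dataloader`
# are assumed to be constructed elsewhere.
pipeline = PrefetchTrainPipelineSparseDist(
    model=sharded_model,
    optimizer=optimizer,
    device=device,
)

dataloader_iter = iter(dataloader)
while True:
    try:
        # Each call overlaps host-to-device copy, sparse input dist, and
        # embedding prefetch for upcoming batches with compute on the
        # current batch.
        output = pipeline.progress(dataloader_iter)
    except StopIteration:
        break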

File tree: 2 files changed, +66 −58 lines


torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py

Lines changed: 5 additions & 1 deletion
@@ -36,7 +36,10 @@
     TrainPipelineBase,
     TrainPipelineSparseDist,
 )
-from torchrec.distributed.train_pipeline.train_pipelines import TrainPipelineSemiSync
+from torchrec.distributed.train_pipeline.train_pipelines import (
+    PrefetchTrainPipelineSparseDist,
+    TrainPipelineSemiSync,
+)
 from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingType
 from torchrec.modules.embedding_configs import EmbeddingBagConfig

@@ -260,6 +263,7 @@ def runner(
         TrainPipelineBase,
         TrainPipelineSparseDist,
         TrainPipelineSemiSync,
+        PrefetchTrainPipelineSparseDist,
     ]:
         pipeline = pipeline_clazz(
             model=sharded_model,
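Since the stated goal is to stay memory neutral or better, a benchmark run over these pipeline classes would typically also compare peak device memory. A rough sketch, assuming sharded_model, optimizer, device, and dataloader are set up as in the benchmark (this is not the benchmark's actual reporting code):

import torch

for pipeline_clazz in [
    TrainPipelineBase,
    TrainPipelineSparseDist,
    TrainPipelineSemiSync,
    PrefetchTrainPipelineSparseDist,  # newly benchmarked to catch regressions
]:
    torch.cuda.reset_peak_memory_stats(device)
    pipeline = pipeline_clazz(model=sharded_model, optimizer=optimizer, device=device)
    dataloader_iter = iter(dataloader)
    try:
        while True:
            pipeline.progress(dataloader_iter)
    except StopIteration:
        pass
    peak_gb = torch.cuda.max_memory_allocated(device) / 1e9
    print(f"{pipeline_clazz.__name__}: peak allocated {peak_gb:.2f} GB")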

torchrec/distributed/train_pipeline/train_pipelines.py

Lines changed: 61 additions & 57 deletions
@@ -197,26 +197,26 @@ def __init__(
         self._batch_ip2: Optional[In] = None
         self._context: TrainPipelineContext = context_type(version=0)
 
+    def _set_module_context(self, context: TrainPipelineContext) -> None:
+        for module in self._pipelined_modules:
+            module.forward.set_context(context)
+
     def enqueue_batch(self, dataloader_iter: Iterator[In]) -> bool:
         batch, context = self.copy_batch_to_gpu(dataloader_iter)
         if batch is None:
             return False
         self.batches.append(batch)
         # pyre-ignore [6]
         self.contexts.append(context)
+
         return True
 
     def dequeue_batch(self) -> None:
         self.batches.popleft()
         self.contexts.popleft()
         # update PipelineForwards context to match next forward pass
         if len(self.batches) >= 1:
-            for module in self._pipelined_modules:
-                module.forward.set_context(self.contexts[0])
-
-            # legacy support
-            self._context = self.contexts[0]
-            self._context.version = 0
+            self._set_module_context(self.contexts[0])
 
     def fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
         # pipeline is already filled
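The new _set_module_context helper centralizes the loop that repoints every pipelined module's forward at a given context; dequeue_batch also drops the legacy self._context re-assignment it used to do inline. A standalone illustration of the bookkeeping (simplified stand-in classes, not TorchRec types):

from collections import deque

class PipelinedForward:
    """Stand-in: holds the context of the batch this forward will consume."""
    def __init__(self):
        self.ctx = None

    def set_context(self, ctx):
        self.ctx = ctx

class Module:
    def __init__(self):
        self.forward = PipelinedForward()

class MiniPipeline:
    def __init__(self, modules):
        self._pipelined_modules = modules
        self.batches = deque()
        self.contexts = deque()

    def _set_module_context(self, ctx):
        for module in self._pipelined_modules:
            module.forward.set_context(ctx)

    def dequeue_batch(self):
        self.batches.popleft()
        self.contexts.popleft()
        if len(self.batches) >= 1:
            # the next forward pass must read the next batch's context
            self._set_module_context(self.contexts[0])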
@@ -247,6 +247,9 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
         if not self.batches:
             raise StopIteration
 
+        # TODO: Remove once Bulk Eval migrated (needed for bwd compat, this class only)
+        self._set_module_context(self.contexts[0])
+
         if self._model.training:
             with record_function("## zero_grad ##"):
                 self._optimizer.zero_grad()
@@ -298,8 +301,7 @@ def _init_pipelined_modules(
         the splits collective in the input dist.
         """
         if self._pipelined_modules:
-            for module in self._pipelined_modules:
-                module.forward.set_context(context)
+            self._set_module_context(context)
             self.start_sparse_data_dist(batch, context)
             return
@@ -385,6 +387,7 @@ def _copy_batch_to_gpu(self, dataloader_iter: Iterator[In]) -> Optional[In]:
         """
         DEPRECATED: exists for backward compatibility on TrainPipelineContext.version 0
         """
+        self._set_module_context(self._context)
         batch, _ = self.copy_batch_to_gpu(dataloader_iter)
         return batch
@@ -393,6 +396,7 @@ def _start_sparse_data_dist(self, batch: Optional[In]) -> None:
         DEPRECATED: exists for backward compatibility
         Waits for batch to finish getting copied to GPU, then starts the input dist.
         """
+        self._set_module_context(self._context)
         self.start_sparse_data_dist(batch, self._context)
 
     def _wait_sparse_data_dist(self) -> None:
@@ -401,7 +405,7 @@ def _wait_sparse_data_dist(self) -> None:
         Waits on the input dist splits requests to get the input dist tensors requests,
         and populates the context with them.
         """
-        assert self._context.version == 0, "Context version == 0 is required"
+        self._set_module_context(self._context)
         with record_function("## wait_sparse_data_dist ##"):
             with torch.cuda.stream(self._data_dist_stream):
                 self._context.module_contexts = (
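The common thread in the three hunks above: each deprecated version-0 entry point now starts by calling _set_module_context(self._context), because the legacy flow reuses one mutable context for every batch while the pipelined forwards may still point at a per-batch context. A hypothetical helper (not part of this diff) showing the resulting call order:

def _legacy_step(pipeline, dataloader_iter):
    # Hypothetical illustration; each call below resyncs the pipelined
    # forwards to the shared legacy context before doing its work.
    batch = pipeline._copy_batch_to_gpu(dataloader_iter)
    pipeline._start_sparse_data_dist(batch)
    pipeline._wait_sparse_data_dist()
    return batch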
@@ -676,73 +680,67 @@ def __init__(
             apply_jit=apply_jit,
             context_type=PrefetchTrainPipelineContext,
         )
+        self._context = PrefetchTrainPipelineContext(version=0)
         self._prefetch_stream: Optional[torch.cuda.streams.Stream] = (
            (torch.cuda.Stream()) if self._device.type == "cuda" else None
         )
         self._default_stream: Optional[torch.cuda.streams.Stream] = (
            (torch.cuda.Stream()) if self._device.type == "cuda" else None
         )
+        self._batch_ip3: Optional[In] = None
 
-    def fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
-        # pipeline is full
-        if len(self.batches) >= 3:
+    def _fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
+        # pipeline is already filled
+        if self._batch_i and self._batch_ip1 and self._batch_ip2:
             return
-        # executes last batch(es) in pipeline
-        if self.batches and self._execute_all_batches:
+        # executes last batch in pipeline
+        if self._execute_all_batches and (self._batch_i or self._batch_ip1):
             return
 
-        # batch 0
-        if not self.enqueue_batch(dataloader_iter):
-            return
+        # batch 1
+        self._batch_i = self._copy_batch_to_gpu(dataloader_iter)
+        if self._batch_i is None:
+            raise StopIteration
+
         self._init_pipelined_modules(
-            # pyre-ignore [6]
-            self.batches[0],
-            self.contexts[0],
-            # pyre-ignore [6]
+            self._batch_i,
+            self._context,
+            # pyre-ignore
             PrefetchPipelinedForward,
         )
-        self.wait_sparse_data_dist(self.contexts[0])
-        self._prefetch(self.batches[0], self.contexts[0])
-
-        # batch 1
-        if not self.enqueue_batch(dataloader_iter):
-            return
-        self.start_sparse_data_dist(self.batches[1], self.contexts[1])
-        self.wait_sparse_data_dist(self.contexts[1])
+        self._start_sparse_data_dist(self._batch_i)
+        self._wait_sparse_data_dist()
+        self._prefetch(self._batch_i)
 
         # batch 2
-        if not self.enqueue_batch(dataloader_iter):
-            return
+        self._batch_ip1 = self._copy_batch_to_gpu(dataloader_iter)
+        self._start_sparse_data_dist(self._batch_ip1)
+        self._wait_sparse_data_dist()
+
+        # batch 3
+        self._batch_ip2 = self._copy_batch_to_gpu(dataloader_iter)
 
     def progress(self, dataloader_iter: Iterator[In]) -> Out:
-        self.fill_pipeline(dataloader_iter)
-        if not self.batches:
-            raise StopIteration
+        self._fill_pipeline(dataloader_iter)
 
         if self._model.training:
             with record_function("## zero_grad ##"):
                 self._optimizer.zero_grad()
 
         with record_function("## wait_for_batch ##"):
-            _wait_for_batch(cast(In, self.batches[0]), self._prefetch_stream)
+            _wait_for_batch(cast(In, self._batch_i), self._prefetch_stream)
 
-        if len(self.batches) >= 3:
-            self.start_sparse_data_dist(self.batches[2], self.contexts[2])
+        self._start_sparse_data_dist(self._batch_ip2)
 
-        # batch 3
-        self.enqueue_batch(dataloader_iter)
+        self._batch_ip3 = self._copy_batch_to_gpu(dataloader_iter)
 
         # forward
         with record_function("## forward ##"):
-            losses, output = cast(
-                Tuple[torch.Tensor, Out], self._model(self.batches[0])
-            )
+            losses, output = cast(Tuple[torch.Tensor, Out], self._model(self._batch_i))
 
-        if len(self.batches) >= 2:
-            self._prefetch(self.batches[1], self.contexts[1])
+        self._prefetch(self._batch_ip1)
 
-        if len(self.batches) >= 3:
-            self.wait_sparse_data_dist(self.contexts[2])
+        self._wait_sparse_data_dist()
 
         if self._model.training:
             # backward
@@ -753,24 +751,30 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
             with record_function("## optimizer ##"):
                 self._optimizer.step()
 
-        self.dequeue_batch()
+        self._batch_i = self._batch_ip1
+        self._batch_ip1 = self._batch_ip2
+        self._batch_ip2 = self._batch_ip3
+
         return output
 
-    def _prefetch(self, batch: Optional[In], context: TrainPipelineContext) -> None:
+    def _prefetch(self, batch: Optional[In]) -> None:
         """
         Waits for input dist to finish, then prefetches data.
         """
         if batch is None:
             return
-        with record_function(f"## sharded_module_prefetch {context.index} ##"):
+        self._context.module_input_post_prefetch.clear()
+        self._context.module_contexts_post_prefetch.clear()
+
+        with record_function("## sharded_module_prefetch ##"):
            with torch.cuda.stream(self._prefetch_stream):
                batch.record_stream(torch.cuda.current_stream())
                for sharded_module in self._pipelined_modules:
                    forward = sharded_module.forward
                    assert isinstance(forward, PrefetchPipelinedForward)
 
-                   assert forward._name in context.input_dist_tensors_requests
-                   request = context.input_dist_tensors_requests[forward._name]
+                   assert forward._name in self._context.input_dist_tensors_requests
+                   request = self._context.input_dist_tensors_requests[forward._name]
                    assert isinstance(request, Awaitable)
                    with record_function("## wait_sparse_data_dist ##"):
                        # Finish waiting on the dist_stream,
@@ -790,16 +794,16 @@ def _prefetch(self, batch: Optional[In], context: TrainPipelineContext) -> None:
                        data.record_stream(cur_stream)
                        data.record_stream(self._default_stream)
 
-                   module_context = context.module_contexts[forward._name]
-                   module_context.record_stream(cur_stream)
-                   module_context.record_stream(self._default_stream)
+                   ctx = self._context.module_contexts[forward._name]
+                   ctx.record_stream(cur_stream)
+                   ctx.record_stream(self._default_stream)
 
                    sharded_module.prefetch(
                        dist_input=data, forward_stream=self._default_stream
                    )
-                   context.module_input_post_prefetch[forward._name] = data
-                   context.module_contexts_post_prefetch[forward._name] = (
-                       context.module_contexts[forward._name]
+                   self._context.module_input_post_prefetch[forward._name] = data
+                   self._context.module_contexts_post_prefetch[forward._name] = (
+                       self._context.module_contexts[forward._name]
                    )
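Taken together, the rewritten _fill_pipeline, progress, and _prefetch restore the legacy four-slot window: batch i is computed while batch i+1 is prefetched, batch i+2 finishes its input dist, and batch i+3 is copied to the device. A schematic of the steady-state overlap (not the actual implementation; CUDA stream synchronization and the backward/optimizer step are elided):

def progress_schematic(pipeline, dataloader_iter):
    pipeline._start_sparse_data_dist(pipeline._batch_ip2)  # batch i+2: start input dist
    pipeline._batch_ip3 = pipeline._copy_batch_to_gpu(dataloader_iter)  # batch i+3: H2D copy
    output = pipeline._model(pipeline._batch_i)            # batch i: forward
    pipeline._prefetch(pipeline._batch_ip1)                # batch i+1: embedding prefetch
    pipeline._wait_sparse_data_dist()                      # batch i+2: finish input dist
    # rotate the window by one batch
    pipeline._batch_i = pipeline._batch_ip1
    pipeline._batch_ip1 = pipeline._batch_ip2
    pipeline._batch_ip2 = pipeline._batch_ip3
    return output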
