diff --git a/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py b/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py
index b419b6ce5..16913a901 100644
--- a/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py
+++ b/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py
@@ -36,7 +36,10 @@
     TrainPipelineBase,
     TrainPipelineSparseDist,
 )
-from torchrec.distributed.train_pipeline.train_pipelines import TrainPipelineSemiSync
+from torchrec.distributed.train_pipeline.train_pipelines import (
+    PrefetchTrainPipelineSparseDist,
+    TrainPipelineSemiSync,
+)
 from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingType
 from torchrec.modules.embedding_configs import EmbeddingBagConfig
@@ -260,6 +263,7 @@ def runner(
         TrainPipelineBase,
         TrainPipelineSparseDist,
         TrainPipelineSemiSync,
+        PrefetchTrainPipelineSparseDist,
     ]:
         pipeline = pipeline_clazz(
             model=sharded_model,
@@ -297,7 +301,7 @@ def _func_to_benchmark(
     )
     if rank == 0:
         print(
-            f" {pipeline_clazz.__name__: <{35}} | Runtime (P90): {result.runtime_percentile(90)/1000:5.1f} s | Memory (P90): {result.max_mem_percentile(90)/1000:5.1f} GB"
+            f" {pipeline_clazz.__name__: <{35}} | Runtime (P90): {result.runtime_percentile(90)/1000:5.3f} s | Memory (P90): {result.max_mem_percentile(90)/1000:5.3f} GB"
         )
diff --git a/torchrec/distributed/train_pipeline/train_pipelines.py b/torchrec/distributed/train_pipeline/train_pipelines.py
index 8460cc836..f8685ab1b 100644
--- a/torchrec/distributed/train_pipeline/train_pipelines.py
+++ b/torchrec/distributed/train_pipeline/train_pipelines.py
@@ -197,6 +197,10 @@ def __init__(
         self._batch_ip2: Optional[In] = None
         self._context: TrainPipelineContext = context_type(version=0)

+    def _set_module_context(self, context: TrainPipelineContext) -> None:
+        for module in self._pipelined_modules:
+            module.forward.set_context(context)
+
     def enqueue_batch(self, dataloader_iter: Iterator[In]) -> bool:
         batch, context = self.copy_batch_to_gpu(dataloader_iter)
         if batch is None:
@@ -204,6 +208,7 @@ def enqueue_batch(self, dataloader_iter: Iterator[In]) -> bool:
         self.batches.append(batch)
         # pyre-ignore [6]
         self.contexts.append(context)
+
         return True

     def dequeue_batch(self) -> None:
@@ -211,12 +216,7 @@ def dequeue_batch(self) -> None:
         self.contexts.popleft()
         # update PipelineForwards context to match next forward pass
         if len(self.batches) >= 1:
-            for module in self._pipelined_modules:
-                module.forward.set_context(self.contexts[0])
-
-            # legacy support
-            self._context = self.contexts[0]
-            self._context.version = 0
+            self._set_module_context(self.contexts[0])

     def fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
         # pipeline is already filled
@@ -247,6 +247,9 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
         if not self.batches:
             raise StopIteration

+        # TODO: Remove once Bulk Eval migrated (needed for bwd compat, this class only)
+        self._set_module_context(self.contexts[0])
+
         if self._model.training:
             with record_function("## zero_grad ##"):
                 self._optimizer.zero_grad()
@@ -298,8 +301,7 @@ def _init_pipelined_modules(
         the splits collective in the input dist.
         """
         if self._pipelined_modules:
-            for module in self._pipelined_modules:
-                module.forward.set_context(context)
+            self._set_module_context(context)
             self.start_sparse_data_dist(batch, context)
             return
@@ -378,11 +380,14 @@ def wait_sparse_data_dist(self, context: TrainPipelineContext) -> None:
                 for names, awaitable in context.fused_splits_awaitables:
                     for name, request in zip(names, awaitable.wait()):
                         context.input_dist_tensors_requests[name] = request
+        context.input_dist_splits_requests.clear()
+        context.fused_splits_awaitables.clear()

     def _copy_batch_to_gpu(self, dataloader_iter: Iterator[In]) -> Optional[In]:
         """
         DEPRECATED: exists for backward compatibility on TrainPipelineContext.version 0
         """
+        self._set_module_context(self._context)
         batch, _ = self.copy_batch_to_gpu(dataloader_iter)
         return batch
@@ -391,6 +396,7 @@ def _start_sparse_data_dist(self, batch: Optional[In]) -> None:
         DEPRECATED: exists for backward compatibility
         Waits for batch to finish getting copied to GPU, then starts the input dist.
         """
+        self._set_module_context(self._context)
         self.start_sparse_data_dist(batch, self._context)

     def _wait_sparse_data_dist(self) -> None:
@@ -399,7 +405,7 @@ def _wait_sparse_data_dist(self) -> None:
         DEPRECATED: exists for backward compatibility
         Waits on the input dist splits requests to get the input dist tensors requests,
         and populates the context with them.
         """
-        assert self._context.version == 0, "Context version == 0 is required"
+        self._set_module_context(self._context)
         with record_function("## wait_sparse_data_dist ##"):
             with torch.cuda.stream(self._data_dist_stream):
                 self._context.module_contexts = (
@@ -674,73 +680,67 @@ def __init__(
             apply_jit=apply_jit,
             context_type=PrefetchTrainPipelineContext,
         )
+        self._context = PrefetchTrainPipelineContext(version=0)
         self._prefetch_stream: Optional[torch.cuda.streams.Stream] = (
             (torch.cuda.Stream()) if self._device.type == "cuda" else None
         )
         self._default_stream: Optional[torch.cuda.streams.Stream] = (
             (torch.cuda.Stream()) if self._device.type == "cuda" else None
         )
+        self._batch_ip3: Optional[In] = None

-    def fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
-        # pipeline is full
-        if len(self.batches) >= 3:
+    def _fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
+        # pipeline is already filled
+        if self._batch_i and self._batch_ip1 and self._batch_ip2:
             return

-        # executes last batch(es) in pipeline
-        if self.batches and self._execute_all_batches:
+        # executes last batch in pipeline
+        if self._execute_all_batches and (self._batch_i or self._batch_ip1):
             return

-        # batch 0
-        if not self.enqueue_batch(dataloader_iter):
-            return
+        # batch 1
+        self._batch_i = self._copy_batch_to_gpu(dataloader_iter)
+        if self._batch_i is None:
+            raise StopIteration
+
         self._init_pipelined_modules(
-            # pyre-ignore [6]
-            self.batches[0],
-            self.contexts[0],
-            # pyre-ignore [6]
+            self._batch_i,
+            self._context,
+            # pyre-ignore
             PrefetchPipelinedForward,
         )
-        self.wait_sparse_data_dist(self.contexts[0])
-        self._prefetch(self.batches[0], self.contexts[0])
-
-        # batch 1
-        if not self.enqueue_batch(dataloader_iter):
-            return
-        self.start_sparse_data_dist(self.batches[1], self.contexts[1])
-        self.wait_sparse_data_dist(self.contexts[1])
+        self._start_sparse_data_dist(self._batch_i)
+        self._wait_sparse_data_dist()
+        self._prefetch(self._batch_i)

         # batch 2
-        if not self.enqueue_batch(dataloader_iter):
-            return
+        self._batch_ip1 = self._copy_batch_to_gpu(dataloader_iter)
+        self._start_sparse_data_dist(self._batch_ip1)
+        self._wait_sparse_data_dist()
+
+        # batch 3
+        self._batch_ip2 = self._copy_batch_to_gpu(dataloader_iter)

     def progress(self, dataloader_iter: Iterator[In]) -> Out:
-        self.fill_pipeline(dataloader_iter)
-        if not self.batches:
-            raise StopIteration
+        self._fill_pipeline(dataloader_iter)

         if self._model.training:
             with record_function("## zero_grad ##"):
                 self._optimizer.zero_grad()

         with record_function("## wait_for_batch ##"):
-            _wait_for_batch(cast(In, self.batches[0]), self._prefetch_stream)
+            _wait_for_batch(cast(In, self._batch_i), self._prefetch_stream)

-        if len(self.batches) >= 3:
-            self.start_sparse_data_dist(self.batches[2], self.contexts[2])
+        self._start_sparse_data_dist(self._batch_ip2)

-        # batch 3
-        self.enqueue_batch(dataloader_iter)
+        self._batch_ip3 = self._copy_batch_to_gpu(dataloader_iter)

         # forward
         with record_function("## forward ##"):
-            losses, output = cast(
-                Tuple[torch.Tensor, Out], self._model(self.batches[0])
-            )
+            losses, output = cast(Tuple[torch.Tensor, Out], self._model(self._batch_i))

-        if len(self.batches) >= 2:
-            self._prefetch(self.batches[1], self.contexts[1])
+        self._prefetch(self._batch_ip1)

-        if len(self.batches) >= 3:
-            self.wait_sparse_data_dist(self.contexts[2])
+        self._wait_sparse_data_dist()

         if self._model.training:
             # backward
@@ -751,24 +751,30 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
             with record_function("## optimizer ##"):
                 self._optimizer.step()

-        self.dequeue_batch()
+        self._batch_i = self._batch_ip1
+        self._batch_ip1 = self._batch_ip2
+        self._batch_ip2 = self._batch_ip3
+
         return output

-    def _prefetch(self, batch: Optional[In], context: TrainPipelineContext) -> None:
+    def _prefetch(self, batch: Optional[In]) -> None:
         """
         Waits for input dist to finish, then prefetches data.
         """
         if batch is None:
             return

-        with record_function(f"## sharded_module_prefetch {context.index} ##"):
+        self._context.module_input_post_prefetch.clear()
+        self._context.module_contexts_post_prefetch.clear()
+
+        with record_function("## sharded_module_prefetch ##"):
             with torch.cuda.stream(self._prefetch_stream):
                 batch.record_stream(torch.cuda.current_stream())
                 for sharded_module in self._pipelined_modules:
                     forward = sharded_module.forward
                     assert isinstance(forward, PrefetchPipelinedForward)
-                    assert forward._name in context.input_dist_tensors_requests
-                    request = context.input_dist_tensors_requests[forward._name]
+                    assert forward._name in self._context.input_dist_tensors_requests
+                    request = self._context.input_dist_tensors_requests[forward._name]
                     assert isinstance(request, Awaitable)
                     with record_function("## wait_sparse_data_dist ##"):
                         # Finish waiting on the dist_stream,
@@ -788,16 +794,16 @@ def _prefetch(self, batch: Optional[In], context: TrainPipelineContext) -> None:
                         data.record_stream(cur_stream)
                         data.record_stream(self._default_stream)

-                    module_context = context.module_contexts[forward._name]
-                    module_context.record_stream(cur_stream)
-                    module_context.record_stream(self._default_stream)
+                    ctx = self._context.module_contexts[forward._name]
+                    ctx.record_stream(cur_stream)
+                    ctx.record_stream(self._default_stream)

                     sharded_module.prefetch(
                         dist_input=data, forward_stream=self._default_stream
                     )
-                    context.module_input_post_prefetch[forward._name] = data
-                    context.module_contexts_post_prefetch[forward._name] = (
-                        context.module_contexts[forward._name]
+                    self._context.module_input_post_prefetch[forward._name] = data
+                    self._context.module_contexts_post_prefetch[forward._name] = (
+                        self._context.module_contexts[forward._name]
                     )
diff --git a/torchrec/distributed/train_pipeline/utils.py b/torchrec/distributed/train_pipeline/utils.py
index 549dd5c72..bb250aa57 100644
--- a/torchrec/distributed/train_pipeline/utils.py
+++ b/torchrec/distributed/train_pipeline/utils.py
@@ -187,8 +187,10 @@ def get_context(self) -> TrainPipelineContext:
 class PipelinedForward(BaseForward):
     # pyre-ignore [2, 24]
     def __call__(self, *input, **kwargs) -> Awaitable:
-        assert self._name in self._context.input_dist_tensors_requests
-        request = self._context.input_dist_tensors_requests[self._name]
+        assert (
+            self._name in self._context.input_dist_tensors_requests
+        ), "Invalid PipelinedForward usage, please do not directly call model.forward()"
+        request = self._context.input_dist_tensors_requests.pop(self._name)
         assert isinstance(request, Awaitable)
         with record_function("## wait_sparse_data_dist ##"):
             # Finish waiting on the dist_stream,
@@ -198,6 +200,8 @@ def __call__(self, *input, **kwargs) -> Awaitable:
         # Make sure that both result of input_dist and context
         # are properly transferred to the current stream.
+        ctx = self._context.module_contexts.pop(self._name)
+
         if self._stream is not None:
             torch.cuda.current_stream().wait_stream(self._stream)
             cur_stream = torch.cuda.current_stream()
@@ -206,13 +210,9 @@ def __call__(self, *input, **kwargs) -> Awaitable:
                 data, (torch.Tensor, Multistreamable)
             ), f"{type(data)} must implement Multistreamable interface"
             data.record_stream(cur_stream)
-
-            ctx = self._context.module_contexts[self._name]
             ctx.record_stream(cur_stream)

-        return self._module.compute_and_output_dist(
-            self._context.module_contexts[self._name], data
-        )
+        return self._module.compute_and_output_dist(ctx, data)


 class EmbeddingPipelinedForward(BaseForward):
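
Usage sketch (not part of the diff above): a minimal example of how a caller might drive the prefetch pipeline exercised in pipeline_benchmarks.py. The model/optimizer/device keyword arguments mirror the call visible in the benchmark; the function name, dataloader handling, and any other details here are illustrative assumptions, not torchrec API guarantees.

# Hypothetical driver loop, assuming a sharded model, optimizer, device, and
# dataloader have already been built as in pipeline_benchmarks.py.
from torchrec.distributed.train_pipeline.train_pipelines import (
    PrefetchTrainPipelineSparseDist,
)

def run_pipeline(sharded_model, optimizer, device, dataloader) -> None:
    pipeline = PrefetchTrainPipelineSparseDist(
        model=sharded_model,
        optimizer=optimizer,
        device=device,
    )
    dataloader_iter = iter(dataloader)
    while True:
        try:
            # progress() fills the pipeline, waits for the prefetched batch,
            # runs forward/backward and the optimizer step, then rotates batches.
            pipeline.progress(dataloader_iter)
        except StopIteration:
            # Raised once the dataloader is exhausted and the pipeline has drained.
            break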