diff --git a/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py b/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py
index b419b6ce5..40ec51dcb 100644
--- a/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py
+++ b/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py
@@ -297,7 +297,7 @@ def _func_to_benchmark(
         )
         if rank == 0:
             print(
-                f" {pipeline_clazz.__name__: <{35}} | Runtime (P90): {result.runtime_percentile(90)/1000:5.1f} s | Memory (P90): {result.max_mem_percentile(90)/1000:5.1f} GB"
+                f" {pipeline_clazz.__name__: <{35}} | Runtime (P90): {result.runtime_percentile(90)/1000:5.3f} s | Memory (P90): {result.max_mem_percentile(90)/1000:5.3f} GB"
             )


diff --git a/torchrec/distributed/train_pipeline/train_pipelines.py b/torchrec/distributed/train_pipeline/train_pipelines.py
index 8460cc836..3d49413fb 100644
--- a/torchrec/distributed/train_pipeline/train_pipelines.py
+++ b/torchrec/distributed/train_pipeline/train_pipelines.py
@@ -378,6 +378,8 @@ def wait_sparse_data_dist(self, context: TrainPipelineContext) -> None:
                 for names, awaitable in context.fused_splits_awaitables:
                     for name, request in zip(names, awaitable.wait()):
                         context.input_dist_tensors_requests[name] = request
+        context.input_dist_splits_requests.clear()
+        context.fused_splits_awaitables.clear()

     def _copy_batch_to_gpu(self, dataloader_iter: Iterator[In]) -> Optional[In]:
         """
diff --git a/torchrec/distributed/train_pipeline/utils.py b/torchrec/distributed/train_pipeline/utils.py
index 549dd5c72..bb250aa57 100644
--- a/torchrec/distributed/train_pipeline/utils.py
+++ b/torchrec/distributed/train_pipeline/utils.py
@@ -187,8 +187,10 @@ def get_context(self) -> TrainPipelineContext:
 class PipelinedForward(BaseForward):
     # pyre-ignore [2, 24]
     def __call__(self, *input, **kwargs) -> Awaitable:
-        assert self._name in self._context.input_dist_tensors_requests
-        request = self._context.input_dist_tensors_requests[self._name]
+        assert (
+            self._name in self._context.input_dist_tensors_requests
+        ), "Invalid PipelinedForward usage, please do not directly call model.forward()"
+        request = self._context.input_dist_tensors_requests.pop(self._name)
         assert isinstance(request, Awaitable)
         with record_function("## wait_sparse_data_dist ##"):
             # Finish waiting on the dist_stream,
@@ -198,6 +200,8 @@ def __call__(self, *input, **kwargs) -> Awaitable:

         # Make sure that both result of input_dist and context
         # are properly transferred to the current stream.
+        ctx = self._context.module_contexts.pop(self._name)
+
         if self._stream is not None:
             torch.cuda.current_stream().wait_stream(self._stream)
             cur_stream = torch.cuda.current_stream()
@@ -206,13 +210,9 @@ def __call__(self, *input, **kwargs) -> Awaitable:
                 data, (torch.Tensor, Multistreamable)
             ), f"{type(data)} must implement Multistreamable interface"
             data.record_stream(cur_stream)
-
-            ctx = self._context.module_contexts[self._name]
             ctx.record_stream(cur_stream)

-        return self._module.compute_and_output_dist(
-            self._context.module_contexts[self._name], data
-        )
+        return self._module.compute_and_output_dist(ctx, data)


 class EmbeddingPipelinedForward(BaseForward):
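
Note on the rationale: replacing dict indexing with `pop()` in `PipelinedForward.__call__` and adding the `clear()` calls in `wait_sparse_data_dist` both drop the pipeline context's reference to per-iteration requests and module contexts as soon as they are consumed. Without that, the context keeps each awaitable result (and the tensors it holds) alive until the next iteration overwrites the entry. The sketch below is a minimal illustration of that lifetime difference; `Context` and `FakeRequest` are hypothetical stand-ins, not TorchRec types:

```python
# Minimal sketch (hypothetical names, not TorchRec code): shows why popping
# a consumed entry out of a long-lived context dict frees it immediately,
# while plain indexing keeps it alive until the next overwrite.
import weakref


class FakeRequest:
    """Stand-in for an input-dist awaitable result holding large tensors."""


class Context:
    """Stand-in for TrainPipelineContext's per-iteration request storage."""

    def __init__(self) -> None:
        self.input_dist_tensors_requests: dict = {}


def consume(context: Context, name: str, use_pop: bool) -> weakref.ref:
    context.input_dist_tensors_requests[name] = FakeRequest()
    ref = weakref.ref(context.input_dist_tensors_requests[name])
    if use_pop:
        request = context.input_dist_tensors_requests.pop(name)  # context releases it
    else:
        request = context.input_dist_tensors_requests[name]  # context still holds it
    del request  # the forward pass is done with the request
    return ref


ctx = Context()
print(consume(ctx, "ebc", use_pop=False)() is None)  # False: still referenced by ctx
print(consume(ctx, "ebc", use_pop=True)() is None)   # True: freed as soon as consumed
```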