Commit 915606b

dstaay-fb authored and facebook-github-bot committed
More aggressive memory freeing from TrainPipelineContext (#1967)
Summary:
X-link: facebookresearch/recipes#43

As users highlighted, the TrainPipeline refactoring introduced a ~2% memory regression caused by the extra context management added for code readability: entries held in a TrainPipelineContext take longer to drop out of refcount, which raises peak memory. It is relatively easy to be much more aggressive about releasing the memory stored in TrainPipelineContext, which this change does.

Broader internal discussion: https://fb.workplace.com/groups/970281557043698/permalink/1664528510952329/

Differential Revision: D57123339

Privacy Context Container: 1203980333745195
1 parent 4190f50 commit 915606b
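A minimal sketch (not part of the commit) of why popping entries out of the context lowers peak memory: a reference held in a long-lived context dict keeps its tensor alive until the whole context drops out of refcount, while `pop()` releases the context's reference immediately. The `Context` class here is a hypothetical stand-in for TrainPipelineContext.

```python
import torch

# Hypothetical stand-in for TrainPipelineContext: a long-lived object whose
# dict keeps every stored tensor alive via refcount until the context dies.
class Context:
    def __init__(self) -> None:
        self.input_dist_tensors_requests: dict = {}

ctx = Context()
ctx.input_dist_tensors_requests["emb"] = torch.empty(1024, 1024)  # ~4 MB

# Before: indexing left the dict entry alive, so the tensor survived until
# the whole context dropped out of refcount.
# request = ctx.input_dist_tensors_requests["emb"]

# After: pop() removes the context's reference immediately; once the local
# `request` is gone, the storage can be reclaimed right away.
request = ctx.input_dist_tensors_requests.pop("emb")
del request  # tensor memory is now freeable, long before ctx is discarded
```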

File tree

3 files changed: +10 -8 lines changed


torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py

Lines changed: 1 addition & 1 deletion
@@ -297,7 +297,7 @@ def _func_to_benchmark(
     )
     if rank == 0:
         print(
-            f" {pipeline_clazz.__name__: <{35}} | Runtime (P90): {result.runtime_percentile(90)/1000:5.1f} s | Memory (P90): {result.max_mem_percentile(90)/1000:5.1f} GB"
+            f" {pipeline_clazz.__name__: <{35}} | Runtime (P90): {result.runtime_percentile(90)/1000:5.3f} s | Memory (P90): {result.max_mem_percentile(90)/1000:5.3f} GB"
         )

torchrec/distributed/train_pipeline/train_pipelines.py

Lines changed: 2 additions & 0 deletions
@@ -378,6 +378,8 @@ def wait_sparse_data_dist(self, context: TrainPipelineContext) -> None:
         for names, awaitable in context.fused_splits_awaitables:
             for name, request in zip(names, awaitable.wait()):
                 context.input_dist_tensors_requests[name] = request
+        context.input_dist_splits_requests.clear()
+        context.fused_splits_awaitables.clear()
 
     def _copy_batch_to_gpu(self, dataloader_iter: Iterator[In]) -> Optional[In]:
         """

torchrec/distributed/train_pipeline/utils.py

Lines changed: 7 additions & 7 deletions
@@ -187,8 +187,10 @@ def get_context(self) -> TrainPipelineContext:
 class PipelinedForward(BaseForward):
     # pyre-ignore [2, 24]
     def __call__(self, *input, **kwargs) -> Awaitable:
-        assert self._name in self._context.input_dist_tensors_requests
-        request = self._context.input_dist_tensors_requests[self._name]
+        assert (
+            self._name in self._context.input_dist_tensors_requests
+        ), "Invalid PipelinedForward usage, please do not directly call model.forward()"
+        request = self._context.input_dist_tensors_requests.pop(self._name)
         assert isinstance(request, Awaitable)
         with record_function("## wait_sparse_data_dist ##"):
             # Finish waiting on the dist_stream,
@@ -198,6 +200,8 @@ def __call__(self, *input, **kwargs) -> Awaitable:
 
         # Make sure that both result of input_dist and context
         # are properly transferred to the current stream.
+        ctx = self._context.module_contexts.pop(self._name)
+
         if self._stream is not None:
             torch.cuda.current_stream().wait_stream(self._stream)
             cur_stream = torch.cuda.current_stream()
@@ -206,13 +210,9 @@ def __call__(self, *input, **kwargs) -> Awaitable:
                 data, (torch.Tensor, Multistreamable)
             ), f"{type(data)} must implement Multistreamable interface"
             data.record_stream(cur_stream)
-
-            ctx = self._context.module_contexts[self._name]
             ctx.record_stream(cur_stream)
 
-        return self._module.compute_and_output_dist(
-            self._context.module_contexts[self._name], data
-        )
+        return self._module.compute_and_output_dist(ctx, data)
 
 
 class EmbeddingPipelinedForward(BaseForward):
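One behavioral consequence worth noting: because `__call__` now pops from `input_dist_tensors_requests` and `module_contexts`, each pipelined forward is single-use per batch, and the new assertion message points at the most common misuse. A hypothetical minimal repro of the new semantics (`FakePipelinedForward` is illustrative, not the torchrec class):

```python
from typing import Any

class FakePipelinedForward:
    # Mimics only the pop-based lookup added by this commit.
    def __init__(self, name: str, input_dist_tensors_requests: dict) -> None:
        self._name = name
        self._requests = input_dist_tensors_requests

    def __call__(self) -> Any:
        assert (
            self._name in self._requests
        ), "Invalid PipelinedForward usage, please do not directly call model.forward()"
        # pop() releases the context's reference as soon as it is consumed
        return self._requests.pop(self._name)

requests = {"sparse_arch": object()}
fwd = FakePipelinedForward("sparse_arch", requests)
fwd()        # first call succeeds and releases the stored entry
try:
    fwd()    # second call now fails fast instead of reusing stale state
except AssertionError as e:
    print(e)
```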
