Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@
TrainPipelineBase,
TrainPipelineSparseDist,
)
from torchrec.distributed.train_pipeline.train_pipelines import TrainPipelineSemiSync
from torchrec.distributed.train_pipeline.train_pipelines import (
PrefetchTrainPipelineSparseDist,
TrainPipelineSemiSync,
)
from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingType
from torchrec.modules.embedding_configs import EmbeddingBagConfig

Expand Down Expand Up @@ -260,6 +263,7 @@ def runner(
TrainPipelineBase,
TrainPipelineSparseDist,
TrainPipelineSemiSync,
PrefetchTrainPipelineSparseDist,
]:
pipeline = pipeline_clazz(
model=sharded_model,
Expand Down Expand Up @@ -297,7 +301,7 @@ def _func_to_benchmark(
)
if rank == 0:
print(
f" {pipeline_clazz.__name__: <{35}} | Runtime (P90): {result.runtime_percentile(90)/1000:5.1f} s | Memory (P90): {result.max_mem_percentile(90)/1000:5.1f} GB"
f" {pipeline_clazz.__name__: <{35}} | Runtime (P90): {result.runtime_percentile(90)/1000:5.3f} s | Memory (P90): {result.max_mem_percentile(90)/1000:5.3f} GB"
)


Expand Down
120 changes: 63 additions & 57 deletions torchrec/distributed/train_pipeline/train_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,26 +197,26 @@ def __init__(
self._batch_ip2: Optional[In] = None
self._context: TrainPipelineContext = context_type(version=0)

def _set_module_context(self, context: TrainPipelineContext) -> None:
    """
    Hand ``context`` to the forward of every pipelined module.

    Args:
        context: pipeline context passed to ``forward.set_context`` on each
            entry of ``self._pipelined_modules``.
    """
    for pipelined_module in self._pipelined_modules:
        pipelined_module.forward.set_context(context)

def enqueue_batch(self, dataloader_iter: Iterator[In]) -> bool:
    """
    Obtain one batch via ``copy_batch_to_gpu`` and queue it with its context.

    Args:
        dataloader_iter: iterator forwarded to ``copy_batch_to_gpu``.

    Returns:
        ``True`` after the batch and its context have been appended to
        ``self.batches`` and ``self.contexts``; ``False`` (with nothing
        queued) when ``copy_batch_to_gpu`` yields ``None`` for the batch.
    """
    next_batch, next_context = self.copy_batch_to_gpu(dataloader_iter)
    if next_batch is not None:
        self.batches.append(next_batch)
        # pyre-ignore [6]
        self.contexts.append(next_context)
        return True

    return False

def dequeue_batch(self) -> None:
self.batches.popleft()
self.contexts.popleft()
# update PipelineForwards context to match next forward pass
if len(self.batches) >= 1:
for module in self._pipelined_modules:
module.forward.set_context(self.contexts[0])

# legacy support
self._context = self.contexts[0]
self._context.version = 0
self._set_module_context(self.contexts[0])

def fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
# pipeline is already filled
Expand Down Expand Up @@ -247,6 +247,9 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
if not self.batches:
raise StopIteration

# TODO: Remove once Bulk Eval migrated (needed for bwd compat, this class only)
self._set_module_context(self.contexts[0])

if self._model.training:
with record_function("## zero_grad ##"):
self._optimizer.zero_grad()
Expand Down Expand Up @@ -298,8 +301,7 @@ def _init_pipelined_modules(
the splits collective in the input dist.
"""
if self._pipelined_modules:
for module in self._pipelined_modules:
module.forward.set_context(context)
self._set_module_context(context)
self.start_sparse_data_dist(batch, context)
return

Expand Down Expand Up @@ -378,11 +380,14 @@ def wait_sparse_data_dist(self, context: TrainPipelineContext) -> None:
for names, awaitable in context.fused_splits_awaitables:
for name, request in zip(names, awaitable.wait()):
context.input_dist_tensors_requests[name] = request
context.input_dist_splits_requests.clear()
context.fused_splits_awaitables.clear()

def _copy_batch_to_gpu(self, dataloader_iter: Iterator[In]) -> Optional[In]:
    """
    DEPRECATED: exists for backward compatibility on TrainPipelineContext.version 0

    Re-applies ``self._context`` to the pipelined modules, then delegates to
    ``copy_batch_to_gpu`` and discards the context it returns.

    Args:
        dataloader_iter: iterator forwarded to ``copy_batch_to_gpu``.

    Returns:
        The batch element of the pair returned by ``copy_batch_to_gpu``.
    """
    # The context must be set before the copy, as in the original ordering.
    self._set_module_context(self._context)
    copied_batch, _unused_context = self.copy_batch_to_gpu(dataloader_iter)
    return copied_batch

Expand All @@ -391,6 +396,7 @@ def _start_sparse_data_dist(self, batch: Optional[In]) -> None:
DEPRECATED: exists for backward compatibility
Waits for batch to finish getting copied to GPU, then starts the input dist.
"""
self._set_module_context(self._context)
self.start_sparse_data_dist(batch, self._context)

def _wait_sparse_data_dist(self) -> None:
Expand All @@ -399,7 +405,7 @@ def _wait_sparse_data_dist(self) -> None:
Waits on the input dist splits requests to get the input dist tensors requests,
and populates the context with them.
"""
assert self._context.version == 0, "Context version == 0 is required"
self._set_module_context(self._context)
with record_function("## wait_sparse_data_dist ##"):
with torch.cuda.stream(self._data_dist_stream):
self._context.module_contexts = (
Expand Down Expand Up @@ -674,73 +680,67 @@ def __init__(
apply_jit=apply_jit,
context_type=PrefetchTrainPipelineContext,
)
self._context = PrefetchTrainPipelineContext(version=0)
self._prefetch_stream: Optional[torch.cuda.streams.Stream] = (
(torch.cuda.Stream()) if self._device.type == "cuda" else None
)
self._default_stream: Optional[torch.cuda.streams.Stream] = (
(torch.cuda.Stream()) if self._device.type == "cuda" else None
)
self._batch_ip3: Optional[In] = None

def fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
# pipeline is full
if len(self.batches) >= 3:
def _fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
# pipeline is already filled
if self._batch_i and self._batch_ip1 and self._batch_ip2:
return
# executes last batch(es) in pipeline
if self.batches and self._execute_all_batches:
# executes last batch in pipeline
if self._execute_all_batches and (self._batch_i or self._batch_ip1):
return

# batch 0
if not self.enqueue_batch(dataloader_iter):
return
# batch 1
self._batch_i = self._copy_batch_to_gpu(dataloader_iter)
if self._batch_i is None:
raise StopIteration

self._init_pipelined_modules(
# pyre-ignore [6]
self.batches[0],
self.contexts[0],
# pyre-ignore [6]
self._batch_i,
self._context,
# pyre-ignore
PrefetchPipelinedForward,
)
self.wait_sparse_data_dist(self.contexts[0])
self._prefetch(self.batches[0], self.contexts[0])

# batch 1
if not self.enqueue_batch(dataloader_iter):
return
self.start_sparse_data_dist(self.batches[1], self.contexts[1])
self.wait_sparse_data_dist(self.contexts[1])
self._start_sparse_data_dist(self._batch_i)
self._wait_sparse_data_dist()
self._prefetch(self._batch_i)

# batch 2
if not self.enqueue_batch(dataloader_iter):
return
self._batch_ip1 = self._copy_batch_to_gpu(dataloader_iter)
self._start_sparse_data_dist(self._batch_ip1)
self._wait_sparse_data_dist()

# batch 3
self._batch_ip2 = self._copy_batch_to_gpu(dataloader_iter)

def progress(self, dataloader_iter: Iterator[In]) -> Out:
self.fill_pipeline(dataloader_iter)
if not self.batches:
raise StopIteration
self._fill_pipeline(dataloader_iter)

if self._model.training:
with record_function("## zero_grad ##"):
self._optimizer.zero_grad()

with record_function("## wait_for_batch ##"):
_wait_for_batch(cast(In, self.batches[0]), self._prefetch_stream)
_wait_for_batch(cast(In, self._batch_i), self._prefetch_stream)

if len(self.batches) >= 3:
self.start_sparse_data_dist(self.batches[2], self.contexts[2])
self._start_sparse_data_dist(self._batch_ip2)

# batch 3
self.enqueue_batch(dataloader_iter)
self._batch_ip3 = self._copy_batch_to_gpu(dataloader_iter)

# forward
with record_function("## forward ##"):
losses, output = cast(
Tuple[torch.Tensor, Out], self._model(self.batches[0])
)
losses, output = cast(Tuple[torch.Tensor, Out], self._model(self._batch_i))

if len(self.batches) >= 2:
self._prefetch(self.batches[1], self.contexts[1])
self._prefetch(self._batch_ip1)

if len(self.batches) >= 3:
self.wait_sparse_data_dist(self.contexts[2])
self._wait_sparse_data_dist()

if self._model.training:
# backward
Expand All @@ -751,24 +751,30 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
with record_function("## optimizer ##"):
self._optimizer.step()

self.dequeue_batch()
self._batch_i = self._batch_ip1
self._batch_ip1 = self._batch_ip2
self._batch_ip2 = self._batch_ip3

return output

def _prefetch(self, batch: Optional[In], context: TrainPipelineContext) -> None:
def _prefetch(self, batch: Optional[In]) -> None:
"""
Waits for input dist to finish, then prefetches data.
"""
if batch is None:
return
with record_function(f"## sharded_module_prefetch {context.index} ##"):
self._context.module_input_post_prefetch.clear()
self._context.module_contexts_post_prefetch.clear()

with record_function("## sharded_module_prefetch ##"):
with torch.cuda.stream(self._prefetch_stream):
batch.record_stream(torch.cuda.current_stream())
for sharded_module in self._pipelined_modules:
forward = sharded_module.forward
assert isinstance(forward, PrefetchPipelinedForward)

assert forward._name in context.input_dist_tensors_requests
request = context.input_dist_tensors_requests[forward._name]
assert forward._name in self._context.input_dist_tensors_requests
request = self._context.input_dist_tensors_requests[forward._name]
assert isinstance(request, Awaitable)
with record_function("## wait_sparse_data_dist ##"):
# Finish waiting on the dist_stream,
Expand All @@ -788,16 +794,16 @@ def _prefetch(self, batch: Optional[In], context: TrainPipelineContext) -> None:
data.record_stream(cur_stream)
data.record_stream(self._default_stream)

module_context = context.module_contexts[forward._name]
module_context.record_stream(cur_stream)
module_context.record_stream(self._default_stream)
ctx = self._context.module_contexts[forward._name]
ctx.record_stream(cur_stream)
ctx.record_stream(self._default_stream)

sharded_module.prefetch(
dist_input=data, forward_stream=self._default_stream
)
context.module_input_post_prefetch[forward._name] = data
context.module_contexts_post_prefetch[forward._name] = (
context.module_contexts[forward._name]
self._context.module_input_post_prefetch[forward._name] = data
self._context.module_contexts_post_prefetch[forward._name] = (
self._context.module_contexts[forward._name]
)


Expand Down
14 changes: 7 additions & 7 deletions torchrec/distributed/train_pipeline/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,10 @@ def get_context(self) -> TrainPipelineContext:
class PipelinedForward(BaseForward):
# pyre-ignore [2, 24]
def __call__(self, *input, **kwargs) -> Awaitable:
assert self._name in self._context.input_dist_tensors_requests
request = self._context.input_dist_tensors_requests[self._name]
assert (
self._name in self._context.input_dist_tensors_requests
), "Invalid PipelinedForward usage, please do not directly call model.forward()"
request = self._context.input_dist_tensors_requests.pop(self._name)
assert isinstance(request, Awaitable)
with record_function("## wait_sparse_data_dist ##"):
# Finish waiting on the dist_stream,
Expand All @@ -198,6 +200,8 @@ def __call__(self, *input, **kwargs) -> Awaitable:

# Make sure that both result of input_dist and context
# are properly transferred to the current stream.
ctx = self._context.module_contexts.pop(self._name)

if self._stream is not None:
torch.cuda.current_stream().wait_stream(self._stream)
cur_stream = torch.cuda.current_stream()
Expand All @@ -206,13 +210,9 @@ def __call__(self, *input, **kwargs) -> Awaitable:
data, (torch.Tensor, Multistreamable)
), f"{type(data)} must implement Multistreamable interface"
data.record_stream(cur_stream)

ctx = self._context.module_contexts[self._name]
ctx.record_stream(cur_stream)

return self._module.compute_and_output_dist(
self._context.module_contexts[self._name], data
)
return self._module.compute_and_output_dist(ctx, data)


class EmbeddingPipelinedForward(BaseForward):
Expand Down