
Commit 2fa7ea7

TroyGarden authored and facebook-github-bot committed
split train_pipeline.utils - pipeline_context (#2978)
Summary:
Pull Request resolved: #2978

# context
* the train_pipeline.utils file is overloaded
* split the functions, classes, etc. into three files, each at roughly under 1000 lines
* this diff: pipeline_context.py

Reviewed By: malaybag

Differential Revision: D73906059

fbshipit-source-id: 7b3e59279a5b27b1953d0e24cc206c8a395bbd8e
1 parent 1370c8c commit 2fa7ea7
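For downstream code, the practical effect of this diff is a new canonical module for the pipeline context types. A minimal sketch of the new import path, using only names that the added pipeline_context.py below exports (illustrative, not part of the commit):

from torchrec.distributed.train_pipeline.pipeline_context import (
    EmbeddingTrainPipelineContext,
    In,  # TypeVar bound to Pipelineable
    Out,  # unbound TypeVar
    PrefetchTrainPipelineContext,
    TrainPipelineContext,
)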

File tree

7 files changed: +121 -93 lines changed

torchrec/distributed/train_pipeline/__init__.py
torchrec/distributed/train_pipeline/pipeline_context.py
torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py
torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py
torchrec/distributed/train_pipeline/train_pipelines.py
torchrec/distributed/train_pipeline/utils.py


torchrec/distributed/train_pipeline/__init__.py

Lines changed: 5 additions & 3 deletions
@@ -8,6 +8,11 @@
# pyre-strict


+from torchrec.distributed.train_pipeline.pipeline_context import (  # noqa
+    In,
+    Out,
+    TrainPipelineContext,
+)
from torchrec.distributed.train_pipeline.train_pipelines import (  # noqa
    EvalPipelineSparseDist,  # noqa
    PrefetchTrainPipelineSparseDist,  # noqa
@@ -30,10 +35,7 @@
    ArgInfoStepFactory,  # noqa
    CallArgs,  # noqa
    DataLoadingThread,  # noqa
-    In,  # noqa
-    Out,  # noqa
    SparseDataDistUtil,  # noqa
    StageOut,  # noqa
    Tracer,  # noqa
-    TrainPipelineContext,  # noqa
)
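Because `__init__.py` now re-exports `In`, `Out`, and `TrainPipelineContext` from pipeline_context, package-level imports should be unaffected by the move. A hypothetical sanity check (not part of this commit), assuming only the re-export shown above:

# The package-level name and the new module-level name should be the same object.
from torchrec.distributed.train_pipeline import TrainPipelineContext
from torchrec.distributed.train_pipeline.pipeline_context import (
    TrainPipelineContext as PipelineContextCtx,
)

assert TrainPipelineContext is PipelineContextCtx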
torchrec/distributed/train_pipeline/pipeline_context.py

Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+import logging
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union
+
+import torch
+
+from torchrec.distributed.embedding_sharding import FusedKJTListSplitsAwaitable
+from torchrec.distributed.types import Awaitable, LazyAwaitable
+from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
+from torchrec.streamable import Multistreamable, Pipelineable
+
+logger: logging.Logger = logging.getLogger(__name__)
+
+
+In = TypeVar("In", bound=Pipelineable)
+Out = TypeVar("Out")
+
+
+@dataclass
+class TrainPipelineContext:
+    """
+    Context information for a `TrainPipelineSparseDist` instance.
+
+    Attributes:
+        input_dist_splits_requests (Dict[str, Awaitable[Any]]): Stores input dist
+            requests in the splits awaitable stage, which occurs after starting the
+            input dist.
+        input_dist_tensors_requests (Dict[str, Awaitable[Any]]): Stores input dist
+            requests in the tensors awaitable stage, which occurs after calling `wait()`
+            on the splits awaitable.
+        module_contexts (Dict[str, Multistreamable]): Stores module contexts from the
+            input dist for the current batch.
+        module_contexts_next_batch (Dict[str, Multistreamable]): Stores module contexts
+            from the input dist for the next batch. (only for version 0)
+        fused_splits_awaitables (List[Tuple[List[str], FusedKJTListSplitsAwaitable]]):
+            List of fused splits input dist awaitable and the corresponding module names
+            of each awaitable.
+        event: Optional[torch.cuda.Event]: Event to record the completion of this stage
+        index: Optional[int]: Index of the current batch.
+        version: int = 0; support for backward compatiblity
+    """
+
+    # pyre-ignore [4]
+    input_dist_splits_requests: Dict[str, Awaitable[Any]] = field(default_factory=dict)
+    # pyre-ignore [4]
+    input_dist_tensors_requests: Dict[str, Awaitable[Any]] = field(default_factory=dict)
+    module_contexts: Dict[str, Multistreamable] = field(default_factory=dict)
+    module_contexts_next_batch: Dict[str, Multistreamable] = field(
+        default_factory=dict
+    )  # deprecated: to support legacy code
+    fused_splits_awaitables: List[Tuple[List[str], FusedKJTListSplitsAwaitable]] = (
+        field(default_factory=list)
+    )
+    events: List[torch.Event] = field(default_factory=list)
+    postproc_fwd_results: Dict[str, Any] = field(default_factory=dict)
+    index: Optional[int] = None
+    version: int = (
+        0  # 1 is current version, 0 is deprecated but supported for backward compatibility
+    )
+
+
+@dataclass
+class PrefetchTrainPipelineContext(TrainPipelineContext):
+    module_input_post_prefetch: Dict[str, Multistreamable] = field(default_factory=dict)
+    module_contexts_post_prefetch: Dict[str, Multistreamable] = field(
+        default_factory=dict
+    )
+    module_input_post_prefetch_next_batch: Dict[str, Multistreamable] = field(
+        default_factory=dict
+    )
+    module_contexts_post_prefetch_next_batch: Dict[str, Multistreamable] = field(
+        default_factory=dict
+    )
+
+
+@dataclass
+class EmbeddingTrainPipelineContext(TrainPipelineContext):
+    embedding_a2a_requests: Dict[
+        str,
+        Union[
+            LazyAwaitable[Multistreamable],
+            # ManagedCollisionEC/EBC returns tuple of awaitables
+            Tuple[
+                LazyAwaitable[KeyedTensor], LazyAwaitable[Optional[KeyedJaggedTensor]]
+            ],
+        ],
+    ] = field(default_factory=dict)
+    embedding_tensors: List[List[torch.Tensor]] = field(default_factory=list)
+    embedding_features: List[List[Union[str, List[str]]]] = field(default_factory=list)
+    detached_embedding_tensors: List[List[torch.Tensor]] = field(default_factory=list)
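A short usage sketch of the dataclasses added above (hypothetical, for illustration; field names and defaults are read directly off the definitions in this file):

from torchrec.distributed.train_pipeline.pipeline_context import (
    EmbeddingTrainPipelineContext,
    TrainPipelineContext,
)

# Every field has a default, so a bare construction yields empty containers.
ctx = TrainPipelineContext(index=0)
assert ctx.input_dist_splits_requests == {}
assert ctx.events == []
assert ctx.version == 0  # 0 is the deprecated-but-supported default

# Subclasses layer additional defaulted fields on top of the base context.
emb_ctx = EmbeddingTrainPipelineContext(index=1)
assert emb_ctx.embedding_a2a_requests == {}
assert emb_ctx.embedding_tensors == []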

torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py

Lines changed: 1 addition & 2 deletions
@@ -10,8 +10,7 @@
#!/usr/bin/env python3

import copy
-import os
-from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, Union
+from typing import Any, cast, Dict, List, Optional, Tuple, Type, Union

import click

torchrec/distributed/train_pipeline/tests/test_train_pipelines.py

Lines changed: 1 addition & 1 deletion
@@ -47,6 +47,7 @@
from torchrec.distributed.tests.test_fp_embeddingbag_utils import (
    create_module_and_freeze,
)
+from torchrec.distributed.train_pipeline.pipeline_context import TrainPipelineContext
from torchrec.distributed.train_pipeline.tests.test_train_pipelines_base import (
    TrainPipelineSparseDistTestBase,
)
@@ -73,7 +74,6 @@
    PostprocArgInfoStep,
    SparseDataDistUtil,
    StageOut,
-    TrainPipelineContext,
)
from torchrec.distributed.types import (
    ModuleSharder,

torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py

Lines changed: 1 addition & 1 deletion
@@ -18,6 +18,7 @@

from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.test_utils.test_model import ModelInput, TestNegSamplingModule
+from torchrec.distributed.train_pipeline.pipeline_context import TrainPipelineContext

from torchrec.distributed.train_pipeline.tests.test_train_pipelines_base import (
    TrainPipelineSparseDistTestBase,
@@ -30,7 +31,6 @@
    NodeArgsHelper,
    PipelinedForward,
    PipelinedPostproc,
-    TrainPipelineContext,
)
from torchrec.distributed.types import ShardingType
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

torchrec/distributed/train_pipeline/train_pipelines.py

Lines changed: 7 additions & 5 deletions
@@ -33,6 +33,13 @@
from torch.autograd.profiler import record_function
from torchrec.distributed.dist_data import KJTAllToAllTensorsAwaitable
from torchrec.distributed.model_parallel import ShardedModule
+from torchrec.distributed.train_pipeline.pipeline_context import (
+    EmbeddingTrainPipelineContext,
+    In,
+    Out,
+    PrefetchTrainPipelineContext,
+    TrainPipelineContext,
+)
from torchrec.distributed.train_pipeline.utils import (
    _override_input_dist_forwards,
    _pipeline_detach_model,
@@ -45,19 +52,14 @@
    _wait_for_events,
    DataLoadingThread,
    EmbeddingPipelinedForward,
-    EmbeddingTrainPipelineContext,
-    In,
    InSyncEmbeddingPipelinedForward,
-    Out,
    PipelinedForward,
    PipelinedPostproc,
    PipelineStage,
    PrefetchPipelinedForward,
-    PrefetchTrainPipelineContext,
    RunnableType,
    StageOut,
    StageOutputWithEvent,
-    TrainPipelineContext,
    use_context_for_postprocs,
)
from torchrec.distributed.types import Awaitable

torchrec/distributed/train_pipeline/utils.py

Lines changed: 8 additions & 81 deletions
@@ -14,7 +14,7 @@
import logging
from collections import defaultdict, deque, OrderedDict
from contextlib import AbstractContextManager
-from dataclasses import dataclass, field
+from dataclasses import dataclass

from itertools import chain
from threading import Event, Thread
@@ -40,7 +40,6 @@
import torch
from torch import distributed as dist
from torch.utils.hooks import RemovableHandle
-from torchrec.distributed.types import LazyAwaitable

if not torch._running_with_deploy():
    from torch.distributed._composable.fsdp.fully_shard import FSDPModule as FSDP2
@@ -66,6 +65,13 @@ class FSDP2:
)
from torchrec.distributed.embedding_types import KJTList
from torchrec.distributed.model_parallel import DistributedModelParallel, ShardedModule
+from torchrec.distributed.train_pipeline.pipeline_context import (
+    EmbeddingTrainPipelineContext,
+    In,
+    Out,  # noqa
+    PrefetchTrainPipelineContext,
+    TrainPipelineContext,
+)

from torchrec.distributed.types import Awaitable, LazyNoWait

@@ -74,90 +80,11 @@ class FSDP2:

logger: logging.Logger = logging.getLogger(__name__)

-import torch
-
-In = TypeVar("In", bound=Pipelineable)
StageOut = TypeVar("StageOut", bound=Pipelineable)
-Out = TypeVar("Out")
-
RunnableType = Callable[..., StageOut]
StageOutputWithEvent = Tuple[Optional[StageOut], Optional[torch.Event]]


-@dataclass
-class TrainPipelineContext:
-    """
-    Context information for a `TrainPipelineSparseDist` instance.
-
-    Attributes:
-        input_dist_splits_requests (Dict[str, Awaitable[Any]]): Stores input dist
-            requests in the splits awaitable stage, which occurs after starting the
-            input dist.
-        input_dist_tensors_requests (Dict[str, Awaitable[Any]]): Stores input dist
-            requests in the tensors awaitable stage, which occurs after calling `wait()`
-            on the splits awaitable.
-        module_contexts (Dict[str, Multistreamable]): Stores module contexts from the
-            input dist for the current batch.
-        module_contexts_next_batch (Dict[str, Multistreamable]): Stores module contexts
-            from the input dist for the next batch. (only for version 0)
-        fused_splits_awaitables (List[Tuple[List[str], FusedKJTListSplitsAwaitable]]):
-            List of fused splits input dist awaitable and the corresponding module names
-            of each awaitable.
-        event: Optional[torch.cuda.Event]: Event to record the completion of this stage
-        index: Optional[int]: Index of the current batch.
-        version: int = 0; support for backward compatiblity
-    """
-
-    # pyre-ignore [4]
-    input_dist_splits_requests: Dict[str, Awaitable[Any]] = field(default_factory=dict)
-    # pyre-ignore [4]
-    input_dist_tensors_requests: Dict[str, Awaitable[Any]] = field(default_factory=dict)
-    module_contexts: Dict[str, Multistreamable] = field(default_factory=dict)
-    module_contexts_next_batch: Dict[str, Multistreamable] = field(
-        default_factory=dict
-    )  # deprecated: to support legacy code
-    fused_splits_awaitables: List[Tuple[List[str], FusedKJTListSplitsAwaitable]] = (
-        field(default_factory=list)
-    )
-    events: List[torch.Event] = field(default_factory=list)
-    postproc_fwd_results: Dict[str, Any] = field(default_factory=dict)
-    index: Optional[int] = None
-    version: int = (
-        0  # 1 is current version, 0 is deprecated but supported for backward compatibility
-    )
-
-
-@dataclass
-class PrefetchTrainPipelineContext(TrainPipelineContext):
-    module_input_post_prefetch: Dict[str, Multistreamable] = field(default_factory=dict)
-    module_contexts_post_prefetch: Dict[str, Multistreamable] = field(
-        default_factory=dict
-    )
-    module_input_post_prefetch_next_batch: Dict[str, Multistreamable] = field(
-        default_factory=dict
-    )
-    module_contexts_post_prefetch_next_batch: Dict[str, Multistreamable] = field(
-        default_factory=dict
-    )
-
-
-@dataclass
-class EmbeddingTrainPipelineContext(TrainPipelineContext):
-    embedding_a2a_requests: Dict[
-        str,
-        Union[
-            LazyAwaitable[Multistreamable],
-            # ManagedCollisionEC/EBC returns tuple of awaitables
-            Tuple[
-                LazyAwaitable[KeyedTensor], LazyAwaitable[Optional[KeyedJaggedTensor]]
-            ],
-        ],
-    ] = field(default_factory=dict)
-    embedding_tensors: List[List[torch.Tensor]] = field(default_factory=list)
-    embedding_features: List[List[Union[str, List[str]]]] = field(default_factory=list)
-    detached_embedding_tensors: List[List[torch.Tensor]] = field(default_factory=list)
-
-
@dataclass
class PipelineStage:
    """
