Skip to content

Commit 9ee9e64

Browse files
TroyGardenmeta-codesync[bot]
authored andcommitted
split test sharders into a separate file (#3453)
Summary: Pull Request resolved: #3453 # context * refactor the overloaded test_utils.test_model.py file * split the sharder classes into a separate file Reviewed By: spmex Differential Revision: D84458052 fbshipit-source-id: db7d1e930f2bc44abb7d632836dd99b43d3c66fc
1 parent 8f59580 commit 9ee9e64

File tree

9 files changed

+306
-291
lines changed

9 files changed

+306
-291
lines changed

torchrec/distributed/benchmark/benchmark_train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
get_tables,
3030
)
3131
from torchrec.distributed.embedding_types import EmbeddingComputeKernel, ShardingType
32-
from torchrec.distributed.test_utils.test_model import TestEBCSharder
32+
from torchrec.distributed.test_utils.emb_sharder import TestEBCSharder
3333
from torchrec.distributed.types import DataType
3434
from torchrec.modules.embedding_modules import EmbeddingBagCollection
3535
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

torchrec/distributed/planner/tests/test_shard_estimators.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,9 @@
4343
Topology,
4444
)
4545
from torchrec.distributed.quant_embeddingbag import QuantEmbeddingBagCollectionSharder
46+
from torchrec.distributed.test_utils.emb_sharder import TestEBCSharder
4647
from torchrec.distributed.test_utils.infer_utils import quantize
47-
from torchrec.distributed.test_utils.test_model import TestEBCSharder, TestSparseNN
48+
from torchrec.distributed.test_utils.test_model import TestSparseNN
4849
from torchrec.distributed.tests.test_sequence_model import TestSequenceSparseNN
4950
from torchrec.distributed.types import (
5051
CacheParams,
Lines changed: 291 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,291 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# pyre-strict
9+
10+
from typing import Any, Dict, List, Optional, Type
11+
12+
import torch
13+
from torchrec.distributed.embedding import EmbeddingCollectionSharder
14+
from torchrec.distributed.embedding_tower_sharding import (
15+
EmbeddingTowerCollectionSharder,
16+
EmbeddingTowerSharder,
17+
)
18+
19+
from torchrec.distributed.embeddingbag import (
20+
EmbeddingBagCollectionSharder,
21+
EmbeddingBagSharder,
22+
)
23+
from torchrec.distributed.fused_embedding import FusedEmbeddingCollectionSharder
24+
from torchrec.distributed.fused_embeddingbag import FusedEmbeddingBagCollectionSharder
25+
from torchrec.distributed.mc_embedding_modules import (
26+
BaseManagedCollisionEmbeddingCollectionSharder,
27+
)
28+
from torchrec.distributed.mc_embeddingbag import (
29+
ShardedManagedCollisionEmbeddingBagCollection,
30+
)
31+
from torchrec.distributed.mc_modules import ManagedCollisionCollectionSharder
32+
from torchrec.distributed.types import (
33+
ParameterSharding,
34+
QuantizedCommCodecs,
35+
ShardingEnv,
36+
)
37+
from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingBagCollection
38+
39+
40+
class TestECSharder(EmbeddingCollectionSharder):
41+
def __init__(
42+
self,
43+
sharding_type: str,
44+
kernel_type: str,
45+
fused_params: Optional[Dict[str, Any]] = None,
46+
qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
47+
) -> None:
48+
if fused_params is None:
49+
fused_params = {}
50+
51+
self._sharding_type = sharding_type
52+
self._kernel_type = kernel_type
53+
super().__init__(fused_params, qcomm_codecs_registry)
54+
55+
"""
56+
Restricts sharding to single type only.
57+
"""
58+
59+
def sharding_types(self, compute_device_type: str) -> List[str]:
60+
return [self._sharding_type]
61+
62+
"""
63+
Restricts to single impl.
64+
"""
65+
66+
def compute_kernels(
67+
self, sharding_type: str, compute_device_type: str
68+
) -> List[str]:
69+
return [self._kernel_type]
70+
71+
72+
class TestEBCSharder(EmbeddingBagCollectionSharder):
73+
def __init__(
74+
self,
75+
sharding_type: str,
76+
kernel_type: str,
77+
fused_params: Optional[Dict[str, Any]] = None,
78+
qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
79+
) -> None:
80+
if fused_params is None:
81+
fused_params = {}
82+
83+
self._sharding_type = sharding_type
84+
self._kernel_type = kernel_type
85+
super().__init__(fused_params, qcomm_codecs_registry)
86+
87+
"""
88+
Restricts sharding to single type only.
89+
"""
90+
91+
def sharding_types(self, compute_device_type: str) -> List[str]:
92+
return [self._sharding_type]
93+
94+
"""
95+
Restricts to single impl.
96+
"""
97+
98+
def compute_kernels(
99+
self, sharding_type: str, compute_device_type: str
100+
) -> List[str]:
101+
return [self._kernel_type]
102+
103+
104+
class TestMCSharder(ManagedCollisionCollectionSharder):
105+
def __init__(
106+
self,
107+
sharding_type: str,
108+
qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
109+
) -> None:
110+
self._sharding_type = sharding_type
111+
super().__init__(qcomm_codecs_registry=qcomm_codecs_registry)
112+
113+
def sharding_types(self, compute_device_type: str) -> List[str]:
114+
return [self._sharding_type]
115+
116+
117+
class TestEBCSharderMCH(
118+
BaseManagedCollisionEmbeddingCollectionSharder[
119+
ManagedCollisionEmbeddingBagCollection
120+
]
121+
):
122+
def __init__(
123+
self,
124+
sharding_type: str,
125+
kernel_type: str,
126+
fused_params: Optional[Dict[str, Any]] = None,
127+
qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
128+
) -> None:
129+
super().__init__(
130+
TestEBCSharder(
131+
sharding_type, kernel_type, fused_params, qcomm_codecs_registry
132+
),
133+
TestMCSharder(sharding_type, qcomm_codecs_registry),
134+
qcomm_codecs_registry=qcomm_codecs_registry,
135+
)
136+
137+
@property
138+
def module_type(self) -> Type[ManagedCollisionEmbeddingBagCollection]:
139+
return ManagedCollisionEmbeddingBagCollection
140+
141+
def shard(
142+
self,
143+
module: ManagedCollisionEmbeddingBagCollection,
144+
params: Dict[str, ParameterSharding],
145+
env: ShardingEnv,
146+
device: Optional[torch.device] = None,
147+
module_fqn: Optional[str] = None,
148+
) -> ShardedManagedCollisionEmbeddingBagCollection:
149+
if device is None:
150+
device = torch.device("cuda")
151+
return ShardedManagedCollisionEmbeddingBagCollection(
152+
module,
153+
params,
154+
# pyre-ignore [6]
155+
ebc_sharder=self._e_sharder,
156+
mc_sharder=self._mc_sharder,
157+
env=env,
158+
device=device,
159+
)
160+
161+
162+
class TestFusedEBCSharder(FusedEmbeddingBagCollectionSharder):
163+
def __init__(
164+
self,
165+
sharding_type: str,
166+
qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
167+
) -> None:
168+
super().__init__(fused_params={}, qcomm_codecs_registry=qcomm_codecs_registry)
169+
self._sharding_type = sharding_type
170+
171+
"""
172+
Restricts sharding to single type only.
173+
"""
174+
175+
def sharding_types(self, compute_device_type: str) -> List[str]:
176+
return [self._sharding_type]
177+
178+
179+
class TestFusedECSharder(FusedEmbeddingCollectionSharder):
180+
def __init__(
181+
self,
182+
sharding_type: str,
183+
) -> None:
184+
super().__init__()
185+
self._sharding_type = sharding_type
186+
187+
"""
188+
Restricts sharding to single type only.
189+
"""
190+
191+
def sharding_types(self, compute_device_type: str) -> List[str]:
192+
return [self._sharding_type]
193+
194+
195+
class TestEBSharder(EmbeddingBagSharder):
196+
def __init__(
197+
self,
198+
sharding_type: str,
199+
kernel_type: str,
200+
fused_params: Dict[str, Any],
201+
qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
202+
) -> None:
203+
super().__init__(fused_params, qcomm_codecs_registry)
204+
self._sharding_type = sharding_type
205+
self._kernel_type = kernel_type
206+
207+
"""
208+
Restricts sharding to single type only.
209+
"""
210+
211+
def sharding_types(self, compute_device_type: str) -> List[str]:
212+
return [self._sharding_type]
213+
214+
"""
215+
Restricts to single impl.
216+
"""
217+
218+
def compute_kernels(
219+
self, sharding_type: str, compute_device_type: str
220+
) -> List[str]:
221+
return [self._kernel_type]
222+
223+
@property
224+
def fused_params(self) -> Optional[Dict[str, Any]]:
225+
return self._fused_params
226+
227+
228+
class TestETSharder(EmbeddingTowerSharder):
229+
def __init__(
230+
self,
231+
sharding_type: str,
232+
kernel_type: str,
233+
fused_params: Dict[str, Any],
234+
qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
235+
) -> None:
236+
super().__init__(fused_params, qcomm_codecs_registry=qcomm_codecs_registry)
237+
self._sharding_type = sharding_type
238+
self._kernel_type = kernel_type
239+
240+
"""
241+
Restricts sharding to single type only.
242+
"""
243+
244+
def sharding_types(self, compute_device_type: str) -> List[str]:
245+
return [self._sharding_type]
246+
247+
"""
248+
Restricts to single impl.
249+
"""
250+
251+
def compute_kernels(
252+
self, sharding_type: str, compute_device_type: str
253+
) -> List[str]:
254+
return [self._kernel_type]
255+
256+
@property
257+
def fused_params(self) -> Optional[Dict[str, Any]]:
258+
return self._fused_params
259+
260+
261+
class TestETCSharder(EmbeddingTowerCollectionSharder):
262+
def __init__(
263+
self,
264+
sharding_type: str,
265+
kernel_type: str,
266+
fused_params: Dict[str, Any],
267+
qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
268+
) -> None:
269+
super().__init__(fused_params, qcomm_codecs_registry=qcomm_codecs_registry)
270+
self._sharding_type = sharding_type
271+
self._kernel_type = kernel_type
272+
273+
"""
274+
Restricts sharding to single type only.
275+
"""
276+
277+
def sharding_types(self, compute_device_type: str) -> List[str]:
278+
return [self._sharding_type]
279+
280+
"""
281+
Restricts to single impl.
282+
"""
283+
284+
def compute_kernels(
285+
self, sharding_type: str, compute_device_type: str
286+
) -> List[str]:
287+
return [self._kernel_type]
288+
289+
@property
290+
def fused_params(self) -> Optional[Dict[str, Any]]:
291+
return self._fused_params

0 commit comments

Comments
 (0)