
Commit cd470f8

TroyGarden authored and facebook-github-bot committed
util function for marking input KJT dynamic (#2058)
Summary:
Pull Request resolved: #2058

# context
* In the IR export workflow, the module takes KJTs as input and produces an `ExportedProgram` of the module
* A KJT actually has variable-length values and weights
* This dynamic nature of the KJT needs to be explicitly passed to torch.export

# changes
* add a util function to mark the input KJT's dynamic shape
* add a test showing how to correctly specify the dynamic shapes of the input KJT

# results
* input KJTs with different value lengths
```
(Pdb) feature1.values()
tensor([0, 1, 2, 3, 2, 3])
(Pdb) feature2.values()
tensor([0, 1, 2, 3, 2, 3, 4])
```
* exported_program can take those input KJTs
```
(Pdb) ep.module()(feature1)
[tensor([[-2.8735e-16, -2.8735e-16, -2.8735e-16, -2.8735e-16],
        [-2.8735e-16, -2.8735e-16, -2.8735e-16, -2.8735e-16]]),
 tensor([[-2.8735e-16, -2.8735e-16, -2.8735e-16, -2.8735e-16, -1.4368e-15,
         -1.4368e-15, -1.4368e-15, -1.4368e-15],
        [-2.8735e-16, -2.8735e-16, -2.8735e-16, -2.8735e-16, -1.4368e-15,
         -1.4368e-15, -1.4368e-15, -1.4368e-15]])]
(Pdb) ep.module()(feature2)
[tensor([[-2.8735e-16, -2.8735e-16, -2.8735e-16, -2.8735e-16],
        [-2.8735e-16, -2.8735e-16, -2.8735e-16, -2.8735e-16]]),
 tensor([[-2.8735e-16, -2.8735e-16, -2.8735e-16, -2.8735e-16, -1.4368e-15,
         -1.4368e-15, -1.4368e-15, -1.4368e-15],
        [-2.8735e-16, -2.8735e-16, -2.8735e-16, -2.8735e-16, -1.4368e-15,
         -1.4368e-15, -1.4368e-15, -1.4368e-15]])]
```
* deserialized module can take those input KJTs
```
(Pdb) deserialized_model(feature1)
[tensor([[ 0.2630,  0.1473, -0.3691,  0.2261],
        [ 0.0000,  0.0000,  0.0000,  0.0000]], grad_fn=<SplitWithSizesBackward0>),
 tensor([[ 0.2198, -0.1648, -0.0121,  0.1998, -0.0384, -0.2458, -0.6844,  0.8741],
        [ 0.1313,  0.2968, -0.2979, -0.2150, -0.2593,  0.6758,  1.0010,  0.9052]],
       grad_fn=<SplitWithSizesBackward0>)]
(Pdb) deserialized_model(feature2)
[tensor([[ 0.2630,  0.1473, -0.3691,  0.2261],
        [ 0.0000,  0.0000,  0.0000,  0.0000]], grad_fn=<SplitWithSizesBackward0>),
 tensor([[ 0.2198, -0.1648, -0.0121,  0.1998, -0.0384, -0.2458, -0.6844,  0.8741],
        [ 0.1313,  0.2968, -0.2979, -0.2150, -0.9359,  0.1123,  0.5834, -0.1357]],
       grad_fn=<SplitWithSizesBackward0>)]
```

Reviewed By: PaulZhang12

Differential Revision: D57824907

fbshipit-source-id: 615f602314e6517dba37e83eea5066de5950dc42
1 parent 9ce1982 commit cd470f8

2 files changed: +116 −3 lines changed
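Before the diffs, a minimal sketch of the call pattern this commit enables. It assumes only that KJT inputs are pytree-registered so `torch.export` can flatten them (torchrec provides this); `SumValues` and the toy KJT are illustrative stand-ins for the real model and inputs used in the test below:

```python
import torch
from torch import nn

from torchrec.ir.utils import mark_dynamic_kjt
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


class SumValues(nn.Module):
    # Stand-in for the real embedding model: reduces the KJT's values.
    def forward(self, kjt: KeyedJaggedTensor) -> torch.Tensor:
        return kjt.values().sum()


kjt = KeyedJaggedTensor.from_offsets_sync(
    keys=["f1", "f2"],
    values=torch.tensor([0, 1, 2, 3]),
    offsets=torch.tensor([0, 2, 3, 4]),
)

model = SumValues()
# Mark the values (and weights) length as one dynamic dim, then hand the
# resulting ShapesCollection to torch.export.
collection = mark_dynamic_kjt(kjt)
ep = torch.export.export(
    model,
    (kjt,),
    dynamic_shapes=collection.dynamic_shapes(model, (kjt,)),
    strict=False,
)
# The exported module should now accept KJTs with a different number of values.
```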

torchrec/ir/tests/test_serializer.py

Lines changed: 51 additions & 1 deletion
```diff
@@ -17,7 +17,11 @@
 from torch import nn
 from torchrec.ir.serializer import JsonSerializer
 
-from torchrec.ir.utils import deserialize_embedding_modules, serialize_embedding_modules
+from torchrec.ir.utils import (
+    deserialize_embedding_modules,
+    mark_dynamic_kjt,
+    serialize_embedding_modules,
+)
 
 from torchrec.modules.embedding_configs import EmbeddingBagConfig
 from torchrec.modules.embedding_modules import EmbeddingBagCollection
@@ -174,6 +178,52 @@ def test_serialize_deserialize_ebc(self) -> None:
             assert eager_out[i].shape == tensor.shape
             assert torch.allclose(eager_out[i], tensor)
 
+    def test_dynamic_shape_ebc(self) -> None:
+        model = self.generate_model()
+        feature1 = KeyedJaggedTensor.from_offsets_sync(
+            keys=["f1", "f2", "f3"],
+            values=torch.tensor([0, 1, 2, 3, 2, 3]),
+            offsets=torch.tensor([0, 2, 2, 3, 4, 5, 6]),
+        )
+
+        feature2 = KeyedJaggedTensor.from_offsets_sync(
+            keys=["f1", "f2", "f3"],
+            values=torch.tensor([0, 1, 2, 3, 2, 3, 4]),
+            offsets=torch.tensor([0, 2, 2, 3, 4, 5, 7]),
+        )
+        eager_out = model(feature2)
+
+        # Serialize EBC
+        collection = mark_dynamic_kjt(feature1)
+        model, sparse_fqns = serialize_embedding_modules(model, JsonSerializer)
+        ep = torch.export.export(
+            model,
+            (feature1,),
+            {},
+            dynamic_shapes=collection.dynamic_shapes(model, (feature1,)),
+            strict=False,
+            # Allows KJT to not be unflattened and run a forward on unflattened EP
+            preserve_module_call_signature=(tuple(sparse_fqns)),
+        )
+
+        # Run forward on ExportedProgram
+        ep_output = ep.module()(feature2)
+
+        # Check that ExportedProgram outputs match eager output shapes
+        for i, tensor in enumerate(ep_output):
+            self.assertEqual(eager_out[i].shape, tensor.shape)
+
+        # Deserialize EBC
+        deserialized_model = deserialize_embedding_modules(ep, JsonSerializer)
+
+        deserialized_model.load_state_dict(model.state_dict())
+        # Run forward on deserialized model
+        deserialized_out = deserialized_model(feature2)
+
+        for i, tensor in enumerate(deserialized_out):
+            self.assertEqual(eager_out[i].shape, tensor.shape)
+            assert torch.allclose(eager_out[i], tensor)
+
     def test_deserialized_device(self) -> None:
         model = self.generate_model()
         id_list_features = KeyedJaggedTensor.from_offsets_sync(
```
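Why exporting with `feature1` and then running `ep.module()` on `feature2` works: only the total values length is marked dynamic, so the two fixtures keep the same keys and batch layout and differ only in value count. A quick restatement of those fixture facts as asserts (illustrative, not part of the commit):

```python
# Both fixtures share keys and batch layout; only the total value count
# differs, which is exactly the dimension mark_dynamic_kjt makes dynamic.
assert feature1.keys() == feature2.keys() == ["f1", "f2", "f3"]
assert feature1.values().numel() == 6 and feature2.values().numel() == 7
assert feature1.offsets()[-1].item() == feature1.values().numel()
assert feature2.offsets()[-1].item() == feature2.values().numel()
```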

torchrec/ir/utils.py

Lines changed: 65 additions & 2 deletions
```diff
@@ -9,17 +9,21 @@
 
 #!/usr/bin/env python3
 
-from typing import List, Optional, Tuple, Type
+from collections import defaultdict
+from typing import Dict, List, Optional, Tuple, Type, Union
 
 import torch
 
 from torch import nn
-from torch.export.exported_program import ExportedProgram
+from torch.export import Dim, ExportedProgram, ShapesCollection
+from torch.export.dynamic_shapes import _Dim as DIM
+from torchrec import KeyedJaggedTensor
 from torchrec.ir.types import SerializerInterface
 
 
 # TODO: Replace the default interface with the python dataclass interface
 DEFAULT_SERIALIZER_CLS = SerializerInterface
+DYNAMIC_DIMS: Dict[str, int] = defaultdict(int)
 
 
 def serialize_embedding_modules(
@@ -88,3 +92,62 @@ def deserialize_embedding_modules(
         setattr(parent, attrs[-1], new_module)
 
     return model
+
+
+def _get_dim(x: Union[DIM, str, None], s: str, max: Optional[int] = None) -> DIM:
+    if isinstance(x, DIM):
+        return x
+    elif isinstance(x, str):
+        if x in DYNAMIC_DIMS:
+            DYNAMIC_DIMS[x] += 1
+            x += str(DYNAMIC_DIMS[x])
+        dim = Dim(x, max=max)
+    else:
+        DYNAMIC_DIMS[s] += 1
+        dim = Dim(s + str(DYNAMIC_DIMS[s]), max=max)
+    return dim
+
+
+def mark_dynamic_kjt(
+    kjt: KeyedJaggedTensor,
+    shapes_collection: Optional[ShapesCollection] = None,
+    variable_length: bool = False,
+    vlen: Optional[Union[DIM, str]] = None,
+    batch_size: Optional[Union[DIM, str]] = None,
+) -> ShapesCollection:
+    """
+    Marks the given KJT as dynamic. If it is not variable length, it has only
+    one dynamic dimension: the length of the values (and weights).
+    If it is variable length, the lengths and offsets are dynamic as well.
+
+    If a shapes collection is provided, it is updated with the new shapes;
+    otherwise a new shapes collection is created. Passing in a shapes_collection
+    is useful when you have multiple KJTs or other dynamic shapes to trace.
+
+    If a dynamic dim/name is provided, it is used directly. Otherwise the
+    default name "vlen" is used for the values, and for variable-length KJTs
+    the lengths/offsets shapes are derived from a "batch_size" dim.
+    Passing in a dynamic dim is useful when that dim is already used elsewhere.
+
+    Args:
+        kjt (KeyedJaggedTensor): The KJT to make dynamic.
+        shapes_collection (Optional[ShapesCollection]): The collection to update.
+        variable_length (bool): Whether the KJT is variable length.
+        vlen (Optional[Union[DIM, str]]): The dynamic length of the values.
+        batch_size (Optional[Union[DIM, str]]): The dynamic batch size.
+    """
+    global DYNAMIC_DIMS
+    if shapes_collection is None:
+        shapes_collection = ShapesCollection()
+    vlen = _get_dim(vlen, "vlen")
+    shapes_collection[kjt._values] = (vlen,)
+    if kjt._weights is not None:
+        shapes_collection[kjt._weights] = (vlen,)
+    if variable_length:
+        batch_size = _get_dim(batch_size, "batch_size", max=4294967295)
+        llen = len(kjt.keys()) * batch_size
+        olen = llen + 1
+        if kjt._lengths is not None:
+            shapes_collection[kjt._lengths] = (llen,)
+        if kjt._offsets is not None:
+            shapes_collection[kjt._offsets] = (olen,)
+    return shapes_collection
```
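The docstring mentions two options the test does not exercise: reusing one dynamic dim across several KJTs via a shared ShapesCollection, and `variable_length` mode. A sketch under toy single-key KJTs (`make_kjt`, `kjt_a`/`kjt_b`/`kjt_c`, and the max bound are illustrative, not part of the commit):

```python
import torch
from torch.export import Dim

from torchrec.ir.utils import mark_dynamic_kjt
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


def make_kjt(values):
    # Toy single-key KJT for this sketch.
    return KeyedJaggedTensor.from_lengths_sync(
        keys=["f1"],
        values=torch.tensor(values),
        lengths=torch.tensor([len(values)]),
    )


kjt_a, kjt_b, kjt_c = make_kjt([1, 2, 3]), make_kjt([4, 5]), make_kjt([6, 7, 8])

# Pattern 1: share one dynamic values-length dim across two KJTs by passing
# the same Dim and threading the ShapesCollection through both calls
# (_get_dim returns a DIM argument unchanged).
shared_vlen = Dim("shared_vlen", max=1_000_000)  # assumed upper bound
collection = mark_dynamic_kjt(kjt_a, vlen=shared_vlen)
collection = mark_dynamic_kjt(kjt_b, shapes_collection=collection, vlen=shared_vlen)

# Pattern 2: variable_length=True also marks lengths (len(keys()) * batch_size
# entries) and offsets (one more) as dynamic, derived from a batch_size dim.
collection = mark_dynamic_kjt(
    kjt_c, shapes_collection=collection, variable_length=True
)
```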
