52 changes: 51 additions & 1 deletion torchrec/ir/tests/test_serializer.py
@@ -17,7 +17,11 @@
 from torch import nn
 from torchrec.ir.serializer import JsonSerializer
 
-from torchrec.ir.utils import deserialize_embedding_modules, serialize_embedding_modules
+from torchrec.ir.utils import (
+    deserialize_embedding_modules,
+    mark_dynamic_kjt,
+    serialize_embedding_modules,
+)
 
 from torchrec.modules.embedding_configs import EmbeddingBagConfig
 from torchrec.modules.embedding_modules import EmbeddingBagCollection
@@ -174,6 +178,52 @@ def test_serialize_deserialize_ebc(self) -> None:
             assert eager_out[i].shape == tensor.shape
             assert torch.allclose(eager_out[i], tensor)
 
+    def test_dynamic_shape_ebc(self) -> None:
+        model = self.generate_model()
+        feature1 = KeyedJaggedTensor.from_offsets_sync(
+            keys=["f1", "f2", "f3"],
+            values=torch.tensor([0, 1, 2, 3, 2, 3]),
+            offsets=torch.tensor([0, 2, 2, 3, 4, 5, 6]),
+        )
+
+        feature2 = KeyedJaggedTensor.from_offsets_sync(
+            keys=["f1", "f2", "f3"],
+            values=torch.tensor([0, 1, 2, 3, 2, 3, 4]),
+            offsets=torch.tensor([0, 2, 2, 3, 4, 5, 7]),
+        )
+        eager_out = model(feature2)
+
+        # Mark the KJT as dynamic, then serialize the EBC for export
+        collection = mark_dynamic_kjt(feature1)
+        model, sparse_fqns = serialize_embedding_modules(model, JsonSerializer)
+        ep = torch.export.export(
+            model,
+            (feature1,),
+            {},
+            dynamic_shapes=collection.dynamic_shapes(model, (feature1,)),
+            strict=False,
+            # Preserving module call signatures keeps the KJT input intact
+            # (not unflattened), so a forward can run on the unflattened EP
+            preserve_module_call_signature=tuple(sparse_fqns),
+        )
+
+        # Run forward on ExportedProgram
+        ep_output = ep.module()(feature2)
+
+        # Check that ExportedProgram output shapes match eager mode
+        for i, tensor in enumerate(ep_output):
+            self.assertEqual(eager_out[i].shape, tensor.shape)
+
+        # Deserialize EBC
+        deserialized_model = deserialize_embedding_modules(ep, JsonSerializer)
+
+        deserialized_model.load_state_dict(model.state_dict())
+        # Run forward on deserialized model
+        deserialized_out = deserialized_model(feature2)
+
+        for i, tensor in enumerate(deserialized_out):
+            self.assertEqual(eager_out[i].shape, tensor.shape)
+            assert torch.allclose(eager_out[i], tensor)
+
     def test_deserialized_device(self) -> None:
         model = self.generate_model()
         id_list_features = KeyedJaggedTensor.from_offsets_sync(
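The test traces the exported program with `feature1` (6 values) but runs it with `feature2` (7 values); that only works because `mark_dynamic_kjt` registers the KJT's values dimension as dynamic. Below is a condensed sketch of the same flow outside the test harness; the single-table EBC config is illustrative, not taken from the PR.

```python
import torch
from torchrec import KeyedJaggedTensor
from torchrec.ir.serializer import JsonSerializer
from torchrec.ir.utils import mark_dynamic_kjt, serialize_embedding_modules
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection

# Illustrative single-table EBC; the PR's test uses self.generate_model() instead.
ebc = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            name="t1", embedding_dim=4, num_embeddings=10, feature_names=["f1"]
        )
    ]
)

kjt = KeyedJaggedTensor.from_offsets_sync(
    keys=["f1"],
    values=torch.tensor([0, 1, 2]),
    offsets=torch.tensor([0, 2, 3]),
)

# Mark values (and weights) dynamic, serialize, then export with the
# resulting ShapesCollection.
collection = mark_dynamic_kjt(kjt)
model, sparse_fqns = serialize_embedding_modules(ebc, JsonSerializer)
ep = torch.export.export(
    model,
    (kjt,),
    dynamic_shapes=collection.dynamic_shapes(model, (kjt,)),
    strict=False,
    preserve_module_call_signature=tuple(sparse_fqns),
)

# A KJT with a different number of values reuses the same exported graph.
bigger = KeyedJaggedTensor.from_offsets_sync(
    keys=["f1"],
    values=torch.tensor([0, 1, 2, 3, 4]),
    offsets=torch.tensor([0, 2, 5]),
)
ep.module()(bigger)
```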
67 changes: 65 additions & 2 deletions torchrec/ir/utils.py
@@ -9,17 +9,21 @@
 
 #!/usr/bin/env python3
 
-from typing import List, Optional, Tuple, Type
+from collections import defaultdict
+from typing import Dict, List, Optional, Tuple, Type, Union
 
 import torch
 
 from torch import nn
-from torch.export.exported_program import ExportedProgram
+from torch.export import Dim, ExportedProgram, ShapesCollection
+from torch.export.dynamic_shapes import _Dim as DIM
 from torchrec import KeyedJaggedTensor
 from torchrec.ir.types import SerializerInterface
 
 
 # TODO: Replace the default interface with the python dataclass interface
 DEFAULT_SERIALIZER_CLS = SerializerInterface
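+# Global counter for minting unique dynamic-dim names ("vlen1", "vlen2", ...)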
+DYNAMIC_DIMS: Dict[str, int] = defaultdict(int)
 
 
 def serialize_embedding_modules(
@@ -88,3 +92,62 @@ def deserialize_embedding_modules(
         setattr(parent, attrs[-1], new_module)
 
     return model


+def _get_dim(x: Union[DIM, str, None], s: str, max: Optional[int] = None) -> DIM:
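+    # Normalize x into a torch.export Dim: reuse a passed-in Dim as-is, build one
+    # from a given string name, or mint one from the default name s plus a counter.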
+    if isinstance(x, DIM):
+        return x
+    elif isinstance(x, str):
+        if x in DYNAMIC_DIMS:
+            DYNAMIC_DIMS[x] += 1
+            x += str(DYNAMIC_DIMS[x])
+        dim = Dim(x, max=max)
+    else:
+        DYNAMIC_DIMS[s] += 1
+        dim = Dim(s + str(DYNAMIC_DIMS[s]), max=max)
+    return dim


+def mark_dynamic_kjt(
+    kjt: KeyedJaggedTensor,
+    shapes_collection: Optional[ShapesCollection] = None,
+    variable_length: bool = False,
+    vlen: Optional[Union[DIM, str]] = None,
+    batch_size: Optional[Union[DIM, str]] = None,
+) -> ShapesCollection:
"""
Makes the given KJT dynamic. If it's not variable length, it will only have
one dynamic dimension, which is the length of the values (and weights).
If it is variable length, then the lengths and offsets will be dynamic.

If a shapes collection is provided, it will be updated with the new shapes,
otherwise a new shapes collection will be created. A passed-in shapes_collection is
useful if you have multiple KJTs or other dynamic shapes that you want to trace.

If a dynamic dim/name is provided, it will directly use that dim/name. Otherwise,
it will use the default name "vlen" for values, and "llen", "lofs" if variable length.
A passed-in dynamic dim is useful if the dynamic dim is already used in other places.

Args:
kjt (KeyedJaggedTensor): The KJT to make dynamic.
shapes_collection (Optional[ShapesCollection]): The collection to update.
variable_length (bool): Whether the KJT is variable length.
vlen (Optional[Union[DIM, str]]): The dynamic length for the values.
batch_size (Optional[Union[DIM, str]]): The dynamic length for the batch_size.
"""
+    global DYNAMIC_DIMS
+    if shapes_collection is None:
+        shapes_collection = ShapesCollection()
+    vlen = _get_dim(vlen, "vlen")
+    shapes_collection[kjt._values] = (vlen,)
+    if kjt._weights is not None:
+        shapes_collection[kjt._weights] = (vlen,)
+    if variable_length:
+        batch_size = _get_dim(batch_size, "batch_size", max=4294967295)
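+        # Derived dims: lengths hold one entry per (key, batch) pair, and
+        # offsets always have exactly one more element than lengths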
+        llen = len(kjt.keys()) * batch_size
+        olen = llen + 1
+        if kjt._lengths is not None:
+            shapes_collection[kjt._lengths] = (llen,)
+        if kjt._offsets is not None:
+            shapes_collection[kjt._offsets] = (olen,)
+    return shapes_collection
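In the variable-length case the lengths and offsets never get independent dims: they are derived from the single batch-size dim as `llen = len(kjt.keys()) * batch_size` and `olen = llen + 1`, so export learns the relationship between all three tensors. Below is a sketch of how a caller might share one `ShapesCollection` across two KJTs, as the docstring suggests; the helper name `export_with_two_kjts` and its arguments are hypothetical, not part of the PR.

```python
import torch
from torch import nn
from torch.export import ShapesCollection
from torchrec import KeyedJaggedTensor
from torchrec.ir.utils import mark_dynamic_kjt


def export_with_two_kjts(
    model: nn.Module, kjt_a: KeyedJaggedTensor, kjt_b: KeyedJaggedTensor
) -> torch.export.ExportedProgram:
    # One ShapesCollection accumulates the dynamic shapes for both inputs.
    sc = ShapesCollection()
    # kjt_a: only the values (and weights) length is dynamic.
    mark_dynamic_kjt(kjt_a, shapes_collection=sc)
    # kjt_b: lengths and offsets become dynamic too, derived from a
    # "batch_size" dim (llen = len(keys) * batch_size, olen = llen + 1).
    mark_dynamic_kjt(kjt_b, shapes_collection=sc, variable_length=True)
    return torch.export.export(
        model,
        (kjt_a, kjt_b),
        dynamic_shapes=sc.dynamic_shapes(model, (kjt_a, kjt_b)),
        strict=False,
    )
```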