diff --git a/torchrec/ir/tests/test_serializer.py b/torchrec/ir/tests/test_serializer.py
index 56196dff3..f332ac994 100644
--- a/torchrec/ir/tests/test_serializer.py
+++ b/torchrec/ir/tests/test_serializer.py
@@ -17,7 +17,11 @@ from torch import nn
 from torchrec.ir.serializer import JsonSerializer
 
-from torchrec.ir.utils import deserialize_embedding_modules, serialize_embedding_modules
+from torchrec.ir.utils import (
+    deserialize_embedding_modules,
+    mark_dynamic_kjt,
+    serialize_embedding_modules,
+)
 from torchrec.modules.embedding_configs import EmbeddingBagConfig
 from torchrec.modules.embedding_modules import EmbeddingBagCollection
@@ -174,6 +178,52 @@ def test_serialize_deserialize_ebc(self) -> None:
             assert eager_out[i].shape == tensor.shape
             assert torch.allclose(eager_out[i], tensor)
 
+    def test_dynamic_shape_ebc(self) -> None:
+        model = self.generate_model()
+        feature1 = KeyedJaggedTensor.from_offsets_sync(
+            keys=["f1", "f2", "f3"],
+            values=torch.tensor([0, 1, 2, 3, 2, 3]),
+            offsets=torch.tensor([0, 2, 2, 3, 4, 5, 6]),
+        )
+
+        feature2 = KeyedJaggedTensor.from_offsets_sync(
+            keys=["f1", "f2", "f3"],
+            values=torch.tensor([0, 1, 2, 3, 2, 3, 4]),
+            offsets=torch.tensor([0, 2, 2, 3, 4, 5, 7]),
+        )
+        eager_out = model(feature2)
+
+        # Mark the input KJT dynamic and serialize the EBC
+        collection = mark_dynamic_kjt(feature1)
+        model, sparse_fqns = serialize_embedding_modules(model, JsonSerializer)
+        ep = torch.export.export(
+            model,
+            (feature1,),
+            {},
+            dynamic_shapes=collection.dynamic_shapes(model, (feature1,)),
+            strict=False,
+            # Preserve the module call signature so the KJT input is not
+            # flattened, allowing a forward pass on the unflattened EP
+            preserve_module_call_signature=tuple(sparse_fqns),
+        )
+
+        # Run forward on the ExportedProgram with a differently shaped KJT
+        ep_output = ep.module()(feature2)
+
+        # EP outputs should match the eager output shapes
+        for i, tensor in enumerate(ep_output):
+            self.assertEqual(eager_out[i].shape, tensor.shape)
+
+        # Deserialize EBC
+        deserialized_model = deserialize_embedding_modules(ep, JsonSerializer)
+
+        deserialized_model.load_state_dict(model.state_dict())
+        # Run forward on deserialized model
+        deserialized_out = deserialized_model(feature2)
+
+        for i, tensor in enumerate(deserialized_out):
+            self.assertEqual(eager_out[i].shape, tensor.shape)
+            assert torch.allclose(eager_out[i], tensor)
+
     def test_deserialized_device(self) -> None:
         model = self.generate_model()
         id_list_features = KeyedJaggedTensor.from_offsets_sync(
diff --git a/torchrec/ir/utils.py b/torchrec/ir/utils.py
index 0cf90d984..1676f166c 100644
--- a/torchrec/ir/utils.py
+++ b/torchrec/ir/utils.py
@@ -9,17 +9,21 @@
 #!/usr/bin/env python3
 
-from typing import List, Optional, Tuple, Type
+from collections import defaultdict
+from typing import Dict, List, Optional, Tuple, Type, Union
 
 import torch
 
 from torch import nn
-from torch.export.exported_program import ExportedProgram
+from torch.export import Dim, ExportedProgram, ShapesCollection
+from torch.export.dynamic_shapes import _Dim as DIM
+from torchrec import KeyedJaggedTensor
 from torchrec.ir.types import SerializerInterface
 
 # TODO: Replace the default interface with the python dataclass interface
 DEFAULT_SERIALIZER_CLS = SerializerInterface
+DYNAMIC_DIMS: Dict[str, int] = defaultdict(int)
 
 
 def serialize_embedding_modules(
@@ -88,3 +92,62 @@ def deserialize_embedding_modules(
         setattr(parent, attrs[-1], new_module)
 
     return model
+
+
+def _get_dim(x: Union[DIM, str, None], s: str, max: Optional[int] = None) -> DIM:
+    if isinstance(x, DIM):
+        return x
+    elif isinstance(x, str):
+        # Count uses of the name so a repeated name gets a unique suffix
+        DYNAMIC_DIMS[x] += 1
+        if DYNAMIC_DIMS[x] > 1:
+            x += str(DYNAMIC_DIMS[x])
+        dim = Dim(x, max=max)
+    else:
+        DYNAMIC_DIMS[s] += 1
+        dim = Dim(s + str(DYNAMIC_DIMS[s]), max=max)
+    return dim
+
+
+def mark_dynamic_kjt(
+    kjt: KeyedJaggedTensor,
+    shapes_collection: Optional[ShapesCollection] = None,
+    variable_length: bool = False,
+    vlen: Optional[Union[DIM, str]] = None,
+    batch_size: Optional[Union[DIM, str]] = None,
+) -> ShapesCollection:
+    """
+    Marks the given KJT as dynamic. If it is not variable length, the only
+    dynamic dimension is the length of the values (and weights). If it is
+    variable length, the lengths and offsets become dynamic as well.
+
+    If a shapes_collection is provided, it is updated with the new shapes;
+    otherwise a new one is created. Passing in a shapes_collection is useful
+    when there are multiple KJTs or other dynamic shapes to trace together.
+
+    If a dynamic dim/name is provided, it is used directly. Otherwise the
+    default name "vlen" is used for the values dim, and "batch_size" for the
+    batch-size dim in the variable-length case; the lengths and offsets dims
+    are derived from the batch-size dim. Passing in a dynamic dim is useful
+    when that dim is already used elsewhere.
+
+    Args:
+        kjt (KeyedJaggedTensor): The KJT to mark as dynamic.
+        shapes_collection (Optional[ShapesCollection]): The collection to update.
+        variable_length (bool): Whether the KJT is variable length.
+        vlen (Optional[Union[DIM, str]]): The dynamic dim for the values.
+        batch_size (Optional[Union[DIM, str]]): The dynamic dim for the batch size.
+
+    Returns:
+        ShapesCollection: The updated (or newly created) collection.
+    """
+    if shapes_collection is None:
+        shapes_collection = ShapesCollection()
+    vlen = _get_dim(vlen, "vlen")
+    shapes_collection[kjt._values] = (vlen,)
+    if kjt._weights is not None:
+        shapes_collection[kjt._weights] = (vlen,)
+    if variable_length:
+        # 4294967295 == 2**32 - 1, an effectively unbounded batch size
+        batch_size = _get_dim(batch_size, "batch_size", max=4294967295)
+        # lengths has one entry per (key, batch) pair; offsets has one more
+        llen = len(kjt.keys()) * batch_size
+        olen = llen + 1
+        if kjt._lengths is not None:
+            shapes_collection[kjt._lengths] = (llen,)
+        if kjt._offsets is not None:
+            shapes_collection[kjt._offsets] = (olen,)
+    return shapes_collection
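
For reviewers, here is a minimal end-to-end sketch of the flow the new test exercises. The `Wrapper` class, table `t1`, and feature `f1` are illustrative stand-ins for the test's `self.generate_model()`, not part of this diff; only `mark_dynamic_kjt`, `serialize_embedding_modules`, and `deserialize_embedding_modules` come from `torchrec.ir.utils`:

```python
import torch

from torchrec.ir.serializer import JsonSerializer
from torchrec.ir.utils import (
    deserialize_embedding_modules,
    mark_dynamic_kjt,
    serialize_embedding_modules,
)
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


class Wrapper(torch.nn.Module):
    """Toy module standing in for self.generate_model() in the test."""

    def __init__(self) -> None:
        super().__init__()
        self.ebc = EmbeddingBagCollection(
            tables=[
                EmbeddingBagConfig(
                    name="t1",
                    embedding_dim=4,
                    num_embeddings=10,
                    feature_names=["f1"],
                )
            ]
        )

    def forward(self, kjt: KeyedJaggedTensor) -> torch.Tensor:
        return self.ebc(kjt).values()


model = Wrapper()
kjt = KeyedJaggedTensor.from_offsets_sync(
    keys=["f1"],
    values=torch.tensor([0, 1, 2]),
    offsets=torch.tensor([0, 1, 3]),
)

# Values (and weights, if present) share one dynamic "vlen" dim
collection = mark_dynamic_kjt(kjt)

model, sparse_fqns = serialize_embedding_modules(model, JsonSerializer)
ep = torch.export.export(
    model,
    (kjt,),
    {},
    dynamic_shapes=collection.dynamic_shapes(model, (kjt,)),
    strict=False,
    preserve_module_call_signature=tuple(sparse_fqns),
)

# Deserialization restores the original EBC behind the same FQNs
deserialized = deserialize_embedding_modules(ep, JsonSerializer)
```

After export, `ep.module()` accepts KJTs whose values length differs from the example input, which is exactly what `test_dynamic_shape_ebc` asserts.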
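The variable-length and multi-KJT paths described in the `mark_dynamic_kjt` docstring can be sketched like this; the two KJTs and the dim names `user_vlen`, `item_vlen`, and `bs` are made up for illustration:

```python
import torch

from torchrec.ir.utils import mark_dynamic_kjt
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

user_kjt = KeyedJaggedTensor.from_offsets_sync(
    keys=["uf1"],
    values=torch.tensor([0, 1, 2]),
    offsets=torch.tensor([0, 1, 3]),
)
item_kjt = KeyedJaggedTensor.from_offsets_sync(
    keys=["if1", "if2"],
    values=torch.tensor([3, 4, 5, 6]),
    offsets=torch.tensor([0, 2, 3, 3, 4]),  # 2 keys * batch_size 2, plus 1
)

# First call creates the collection; user values get their own "user_vlen" dim
collection = mark_dynamic_kjt(user_kjt, vlen="user_vlen")

# Second call reuses the same collection so both KJTs can be traced together.
# variable_length=True additionally marks lengths (len(keys) * batch_size)
# and offsets (that product plus one) as dynamic, derived from the "bs" dim.
collection = mark_dynamic_kjt(
    item_kjt,
    shapes_collection=collection,
    variable_length=True,
    vlen="item_vlen",
    batch_size="bs",
)
```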