Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 48 additions & 53 deletions torchrec/ir/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import json
import logging
from typing import Any, Dict, Optional, Type
from typing import Any, Dict, List, Optional, Tuple, Type

import torch

Expand Down Expand Up @@ -69,8 +69,18 @@ def get_deserialized_device(
return device


class JsonSerializerBase(SerializerInterface):
class JsonSerializer(SerializerInterface):
"""
Serializer for torch.export IR using json.
"""

module_to_serializer_cls: Dict[str, Type["JsonSerializer"]] = {}
_module_cls: Optional[Type[nn.Module]] = None
_children: Optional[List[str]] = None

@classmethod
def children(cls, module: nn.Module) -> List[str]:
return [] if not cls._children else cls._children

@classmethod
def serialize_to_dict(cls, module: nn.Module) -> Dict[str, Any]:
Expand All @@ -81,47 +91,67 @@ def deserialize_from_dict(
cls,
metadata_dict: Dict[str, Any],
device: Optional[torch.device] = None,
unflatten: Optional[nn.Module] = None,
) -> nn.Module:
raise NotImplementedError()

@classmethod
def serialize(
cls,
module: nn.Module,
) -> torch.Tensor:
if cls._module_cls is None:
) -> Tuple[torch.Tensor, List[str]]:
typename = type(module).__name__
serializer = cls.module_to_serializer_cls.get(typename)
if serializer is None:
raise ValueError(
"Must assign a nn.Module to class static variable _module_cls"
f"Expected typename to be one of {list(cls.module_to_serializer_cls.keys())}, got {typename}"
)
if not isinstance(module, cls._module_cls):
assert issubclass(serializer, JsonSerializer)
assert serializer._module_cls is not None
if not isinstance(module, serializer._module_cls):
raise ValueError(
f"Expected module to be of type {cls._module_cls.__name__}, got {type(module)}"
f"Expected module to be of type {serializer._module_cls.__name__}, "
f"got {type(module)}"
)
metadata_dict = cls.serialize_to_dict(module)
return torch.frombuffer(json.dumps(metadata_dict).encode(), dtype=torch.uint8)
metadata_dict = serializer.serialize_to_dict(module)
raw_dict = {"typename": typename, "metadata_dict": metadata_dict}
serialized_tensor = torch.frombuffer(
json.dumps(raw_dict).encode(), dtype=torch.uint8
)
return serialized_tensor, serializer.children(module)

@classmethod
def deserialize(
cls,
input: torch.Tensor,
typename: str,
device: Optional[torch.device] = None,
unflatten: Optional[nn.Module] = None,
) -> nn.Module:
raw_bytes = input.numpy().tobytes()
metadata_dict = json.loads(raw_bytes.decode())
module = cls.deserialize_from_dict(metadata_dict, device)
if cls._module_cls is None:
raw_dict = json.loads(raw_bytes.decode())
typename = raw_dict["typename"]
if typename not in cls.module_to_serializer_cls:
raise ValueError(
f"Expected typename to be one of {list(cls.module_to_serializer_cls.keys())}, got {typename}"
)
serializer = cls.module_to_serializer_cls[typename]
assert issubclass(serializer, JsonSerializer)
module = serializer.deserialize_from_dict(
raw_dict["metadata_dict"], device, unflatten
)

if serializer._module_cls is None:
raise ValueError(
"Must assign a nn.Module to class static variable _module_cls"
)
if not isinstance(module, cls._module_cls):
if not isinstance(module, serializer._module_cls):
raise ValueError(
f"Expected module to be of type {cls._module_cls.__name__}, got {type(module)}"
f"Expected module to be of type {serializer._module_cls.__name__}, got {type(module)}"
)
return module


class EBCJsonSerializer(JsonSerializerBase):
class EBCJsonSerializer(JsonSerializer):
_module_cls = EmbeddingBagCollection

@classmethod
Expand All @@ -148,6 +178,7 @@ def deserialize_from_dict(
cls,
metadata_dict: Dict[str, Any],
device: Optional[torch.device] = None,
unflatten: Optional[nn.Module] = None,
) -> nn.Module:
tables = [
EmbeddingBagConfigMetadata(**table_config)
Expand All @@ -164,40 +195,4 @@ def deserialize_from_dict(
)


class JsonSerializer(SerializerInterface):
"""
Serializer for torch.export IR using json.
"""

module_to_serializer_cls: Dict[str, Type[SerializerInterface]] = {
"EmbeddingBagCollection": EBCJsonSerializer,
}

@classmethod
def serialize(
cls,
module: nn.Module,
) -> torch.Tensor:
typename = type(module).__name__
if typename not in cls.module_to_serializer_cls:
raise ValueError(
f"Expected typename to be one of {list(cls.module_to_serializer_cls.keys())}, got {typename}"
)

return cls.module_to_serializer_cls[typename].serialize(module)

@classmethod
def deserialize(
cls,
input: torch.Tensor,
typename: str,
device: Optional[torch.device] = None,
) -> nn.Module:
if typename not in cls.module_to_serializer_cls:
raise ValueError(
f"Expected typename to be one of {list(cls.module_to_serializer_cls.keys())}, got {typename}"
)

return cls.module_to_serializer_cls[typename].deserialize(
input, typename, device
)
JsonSerializer.module_to_serializer_cls["EmbeddingBagCollection"] = EBCJsonSerializer
48 changes: 47 additions & 1 deletion torchrec/ir/tests/test_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import copy
import unittest
from typing import Callable, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from torch import nn
Expand Down Expand Up @@ -54,6 +54,41 @@ def forward(self, features: KeyedJaggedTensor) -> List[torch.Tensor]:
return res


class CompoundModuleSerializer(JsonSerializer):
_module_cls = CompoundModule

@classmethod
def children(cls, module: nn.Module) -> List[str]:
children = ["ebc", "list"]
if module.comp is not None:
children += ["comp"]
return children

@classmethod
def serialize_to_dict(
cls,
module: nn.Module,
) -> Dict[str, Any]:
return {}

@classmethod
def deserialize_from_dict(
cls,
metadata_dict: Dict[str, Any],
device: Optional[torch.device] = None,
unflatten: Optional[nn.Module] = None,
) -> nn.Module:
assert unflatten is not None
ebc = unflatten.ebc
comp = getattr(unflatten, "comp", None)
i = 0
mlist = []
while hasattr(unflatten.list, str(i)):
mlist.append(getattr(unflatten.list, str(i)))
i += 1
return CompoundModule(ebc, comp, mlist)


class TestJsonSerializer(unittest.TestCase):
def generate_model(self) -> nn.Module:
class Model(nn.Module):
Expand Down Expand Up @@ -328,6 +363,9 @@ def forward(self, features: KeyedJaggedTensor) -> List[torch.Tensor]:

eager_out = model(id_list_features)

JsonSerializer.module_to_serializer_cls["CompoundModule"] = (
CompoundModuleSerializer
)
# Serialize
model, sparse_fqns = serialize_embedding_modules(model, JsonSerializer)
ep = torch.export.export(
Expand All @@ -346,6 +384,14 @@ def forward(self, features: KeyedJaggedTensor) -> List[torch.Tensor]:

# Deserialize
deserialized_model = deserialize_embedding_modules(ep, JsonSerializer)

# Check if Compound Module is deserialized correctly
self.assertIsInstance(deserialized_model.comp, CompoundModule)
self.assertIsInstance(deserialized_model.comp.comp, CompoundModule)
self.assertIsInstance(deserialized_model.comp.comp.comp, CompoundModule)
self.assertIsInstance(deserialized_model.comp.list[1], CompoundModule)
self.assertIsInstance(deserialized_model.comp.list[1].comp, CompoundModule)

deserialized_model.load_state_dict(model.state_dict())
# Run forward on deserialized model
deserialized_out = deserialized_model(id_list_features)
Expand Down
17 changes: 7 additions & 10 deletions torchrec/ir/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
#!/usr/bin/env python3

import abc
from typing import Any, Dict, Optional, Type
from typing import Any, Dict, List, Optional, Tuple

import torch

Expand All @@ -24,28 +24,25 @@ class SerializerInterface(abc.ABC):

@classmethod
@property
# pyre-ignore [3]: Returning `None` but type `Any` is specified.
def module_to_serializer_cls(cls) -> Dict[str, Type[Any]]:
def module_to_serializer_cls(cls) -> Dict[str, Any]:
raise NotImplementedError

@classmethod
@abc.abstractmethod
# pyre-ignore [3]: Returning `None` but type `Any` is specified.
def serialize(
cls,
module: nn.Module,
) -> Any:
) -> Tuple[torch.Tensor, List[str]]:
# Take the eager embedding module and generate bytes in buffer
pass
raise NotImplementedError

@classmethod
@abc.abstractmethod
def deserialize(
cls,
# pyre-ignore [2]: Parameter `input` must have a type other than `Any`.
input: Any,
typename: str,
input: torch.Tensor,
device: Optional[torch.device] = None,
unflatten: Optional[nn.Module] = None,
) -> nn.Module:
# Take the bytes in the buffer and regenerate the eager embedding module
pass
raise NotImplementedError
82 changes: 42 additions & 40 deletions torchrec/ir/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@


def serialize_embedding_modules(
model: nn.Module,
module: nn.Module,
serializer_cls: Type[SerializerInterface] = DEFAULT_SERIALIZER_CLS,
fqn: str = "",
) -> Tuple[nn.Module, List[str]]:
"""
Takes all the modules that are of type `serializer_cls` and serializes them
Expand All @@ -37,13 +38,46 @@ def serialize_embedding_modules(
Returns the modified module and the list of fqns that had the buffer added.
"""
preserve_fqns = []
for fqn, module in model.named_modules():
if type(module).__name__ in serializer_cls.module_to_serializer_cls:
serialized_module = serializer_cls.serialize(module)
module.register_buffer("ir_metadata", serialized_module, persistent=False)
preserve_fqns.append(fqn)

return model, preserve_fqns
# handle current module
if type(module).__name__ in serializer_cls.module_to_serializer_cls:
serialized_tensor, children = serializer_cls.serialize(module)
module.register_buffer("ir_metadata", serialized_tensor, persistent=False)
preserve_fqns.append(fqn)
else:
children = [child for child, _ in module.named_children()]

# handle child modules
for child in children:
submodule = module.get_submodule(child)
child_fqn = f"{fqn}.{child}" if len(fqn) > 0 else child
preserve_fqns.extend(
serialize_embedding_modules(submodule, serializer_cls, child_fqn)[1]
)
return module, preserve_fqns


def _deserialize_embedding_modules(
module: nn.Module,
serializer_cls: Type[SerializerInterface],
device: Optional[torch.device] = None,
) -> nn.Module:
"""
returns:
1. the children of the parent_fqn Dict[relative_fqn -> module]
2. the next node Optional[fqn], Optional[module], which is not a child of the parent_fqn
"""

for child_fqn, child in module.named_children():
child = _deserialize_embedding_modules(
module=child, serializer_cls=serializer_cls, device=device
)
setattr(module, child_fqn, child)

if "ir_metadata" in dict(module.named_buffers()):
serialized_tensor = module.get_buffer("ir_metadata")
module = serializer_cls.deserialize(serialized_tensor, device, module)
return module


def deserialize_embedding_modules(
Expand All @@ -59,39 +93,7 @@ def deserialize_embedding_modules(
Returns the unflattened ExportedProgram with the deserialized modules.
"""
model = torch.export.unflatten(ep)
module_type_dict = {}
for node in ep.graph.nodes:
if "nn_module_stack" in node.meta:
for fqn, type_name in node.meta["nn_module_stack"].values():
# Only get the module type name, not the full type name
module_type_dict[fqn] = type_name.split(".")[-1]

fqn_to_new_module = {}
for fqn, module in model.named_modules():
if "ir_metadata" in dict(module.named_buffers()):
serialized_module = dict(module.named_buffers())["ir_metadata"]

if fqn not in module_type_dict:
raise RuntimeError(
f"Cannot find the type of module {fqn} in the exported program"
)

deserialized_module = serializer_cls.deserialize(
serialized_module,
module_type_dict[fqn],
device,
)
fqn_to_new_module[fqn] = deserialized_module

for fqn, new_module in fqn_to_new_module.items():
# handle nested attribute like "x.y.z"
attrs = fqn.split(".")
parent = model
for a in attrs[:-1]:
parent = getattr(parent, a)
setattr(parent, attrs[-1], new_module)

return model
return _deserialize_embedding_modules(model, serializer_cls, device)


def _get_dim(x: Union[DIM, str, None], s: str, max: Optional[int] = None) -> DIM:
Expand Down