Skip to content

Commit 1b7050f

Browse files
jd7-trfacebook-github-bot
authored andcommitted
Add test case for exporting EBC with VBE KJT
Summary: astitled Differential Revision: D73454558
1 parent 9f0bd7e commit 1b7050f

File tree

1 file changed

+121
-0
lines changed

1 file changed

+121
-0
lines changed

torchrec/ir/tests/test_serializer.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,45 @@ def forward(
176176

177177
return model
178178

179+
def generate_model_for_vbe_kjt(self) -> nn.Module:
180+
class Model(nn.Module):
181+
def __init__(self, ebc):
182+
super().__init__()
183+
self.ebc1 = ebc
184+
185+
def forward(
186+
self,
187+
features: KeyedJaggedTensor,
188+
) -> List[torch.Tensor]:
189+
kt1 = self.ebc1(features)
190+
res: List[torch.Tensor] = []
191+
192+
for kt in [kt1]:
193+
res.extend(KeyedTensor.regroup([kt], [[key] for key in kt.keys()]))
194+
195+
return res
196+
197+
config1 = EmbeddingBagConfig(
198+
name="t1",
199+
embedding_dim=3,
200+
num_embeddings=10,
201+
feature_names=["f1"],
202+
)
203+
config2 = EmbeddingBagConfig(
204+
name="t2",
205+
embedding_dim=4,
206+
num_embeddings=10,
207+
feature_names=["f2"],
208+
)
209+
ebc = EmbeddingBagCollection(
210+
tables=[config1, config2],
211+
is_weighted=False,
212+
)
213+
214+
model = Model(ebc)
215+
216+
return model
217+
179218
def test_serialize_deserialize_ebc(self) -> None:
180219
model = self.generate_model()
181220
id_list_features = KeyedJaggedTensor.from_offsets_sync(
@@ -253,6 +292,88 @@ def test_serialize_deserialize_ebc(self) -> None:
253292
self.assertEqual(deserialized.shape, orginal.shape)
254293
self.assertTrue(torch.allclose(deserialized, orginal))
255294

295+
@unittest.skip("Adding test for demonstrating VBE KJT issue for now.")
296+
def test_serialize_deserialize_ebc_with_vbe_kjt(self) -> None:
297+
model = self.generate_model_for_vbe_kjt()
298+
id_list_features = KeyedJaggedTensor(
299+
keys=["f1", "f2"],
300+
values=torch.tensor([5, 6, 7, 1, 2, 3, 0, 1]),
301+
lengths=torch.tensor([3, 3, 2]),
302+
stride_per_key_per_rank=[[2], [1]],
303+
inverse_indices=(["f1", "f2"], torch.tensor([[0, 1, 0], [0, 0, 0]])),
304+
)
305+
306+
eager_out = model(id_list_features)
307+
308+
print("eager_out: ", eager_out)
309+
310+
# Serialize EBC
311+
model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer)
312+
ep = torch.export.export(
313+
model,
314+
(id_list_features,),
315+
{},
316+
strict=False,
317+
# Allows KJT to not be unflattened and run a forward on unflattened EP
318+
preserve_module_call_signature=(tuple(sparse_fqns)),
319+
)
320+
321+
# Run forward on ExportedProgram
322+
ep_output = ep.module()(id_list_features)
323+
324+
for i, tensor in enumerate(ep_output):
325+
self.assertEqual(eager_out[i].shape, tensor.shape)
326+
327+
# Deserialize EBC
328+
unflatten_ep = torch.export.unflatten(ep)
329+
deserialized_model = decapsulate_ir_modules(unflatten_ep, JsonSerializer)
330+
331+
# check EBC config
332+
for i in range(5):
333+
ebc_name = f"ebc{i + 1}"
334+
self.assertIsInstance(
335+
getattr(deserialized_model, ebc_name), EmbeddingBagCollection
336+
)
337+
338+
for deserialized, orginal in zip(
339+
getattr(deserialized_model, ebc_name).embedding_bag_configs(),
340+
getattr(model, ebc_name).embedding_bag_configs(),
341+
):
342+
self.assertEqual(deserialized.name, orginal.name)
343+
self.assertEqual(deserialized.embedding_dim, orginal.embedding_dim)
344+
self.assertEqual(deserialized.num_embeddings, orginal.num_embeddings)
345+
self.assertEqual(deserialized.feature_names, orginal.feature_names)
346+
347+
# check FPEBC config
348+
for i in range(2):
349+
fpebc_name = f"fpebc{i + 1}"
350+
assert isinstance(
351+
getattr(deserialized_model, fpebc_name),
352+
FeatureProcessedEmbeddingBagCollection,
353+
)
354+
355+
for deserialized, orginal in zip(
356+
getattr(
357+
deserialized_model, fpebc_name
358+
)._embedding_bag_collection.embedding_bag_configs(),
359+
getattr(
360+
model, fpebc_name
361+
)._embedding_bag_collection.embedding_bag_configs(),
362+
):
363+
self.assertEqual(deserialized.name, orginal.name)
364+
self.assertEqual(deserialized.embedding_dim, orginal.embedding_dim)
365+
self.assertEqual(deserialized.num_embeddings, orginal.num_embeddings)
366+
self.assertEqual(deserialized.feature_names, orginal.feature_names)
367+
368+
# Run forward on deserialized model and compare the output
369+
deserialized_model.load_state_dict(model.state_dict())
370+
deserialized_out = deserialized_model(id_list_features)
371+
372+
self.assertEqual(len(deserialized_out), len(eager_out))
373+
for deserialized, orginal in zip(deserialized_out, eager_out):
374+
self.assertEqual(deserialized.shape, orginal.shape)
375+
self.assertTrue(torch.allclose(deserialized, orginal))
376+
256377
def test_dynamic_shape_ebc_disabled_in_oss_compatibility(self) -> None:
257378
model = self.generate_model()
258379
feature1 = KeyedJaggedTensor.from_offsets_sync(

0 commit comments

Comments
 (0)