@@ -176,6 +176,45 @@ def forward(
176176
177177 return model
178178
179+ def generate_model_for_vbe_kjt (self ) -> nn .Module :
180+ class Model (nn .Module ):
181+ def __init__ (self , ebc ):
182+ super ().__init__ ()
183+ self .ebc1 = ebc
184+
185+ def forward (
186+ self ,
187+ features : KeyedJaggedTensor ,
188+ ) -> List [torch .Tensor ]:
189+ kt1 = self .ebc1 (features )
190+ res : List [torch .Tensor ] = []
191+
192+ for kt in [kt1 ]:
193+ res .extend (KeyedTensor .regroup ([kt ], [[key ] for key in kt .keys ()]))
194+
195+ return res
196+
197+ config1 = EmbeddingBagConfig (
198+ name = "t1" ,
199+ embedding_dim = 3 ,
200+ num_embeddings = 10 ,
201+ feature_names = ["f1" ],
202+ )
203+ config2 = EmbeddingBagConfig (
204+ name = "t2" ,
205+ embedding_dim = 4 ,
206+ num_embeddings = 10 ,
207+ feature_names = ["f2" ],
208+ )
209+ ebc = EmbeddingBagCollection (
210+ tables = [config1 , config2 ],
211+ is_weighted = False ,
212+ )
213+
214+ model = Model (ebc )
215+
216+ return model
217+
179218 def test_serialize_deserialize_ebc (self ) -> None :
180219 model = self .generate_model ()
181220 id_list_features = KeyedJaggedTensor .from_offsets_sync (
@@ -253,6 +292,88 @@ def test_serialize_deserialize_ebc(self) -> None:
253292 self .assertEqual (deserialized .shape , orginal .shape )
254293 self .assertTrue (torch .allclose (deserialized , orginal ))
255294
295+ @unittest .skip ("Adding test for demonstrating VBE KJT issue for now." )
296+ def test_serialize_deserialize_ebc_with_vbe_kjt (self ) -> None :
297+ model = self .generate_model_for_vbe_kjt ()
298+ id_list_features = KeyedJaggedTensor (
299+ keys = ["f1" , "f2" ],
300+ values = torch .tensor ([5 , 6 , 7 , 1 , 2 , 3 , 0 , 1 ]),
301+ lengths = torch .tensor ([3 , 3 , 2 ]),
302+ stride_per_key_per_rank = [[2 ], [1 ]],
303+ inverse_indices = (["f1" , "f2" ], torch .tensor ([[0 , 1 , 0 ], [0 , 0 , 0 ]])),
304+ )
305+
306+ eager_out = model (id_list_features )
307+
308+ print ("eager_out: " , eager_out )
309+
310+ # Serialize EBC
311+ model , sparse_fqns = encapsulate_ir_modules (model , JsonSerializer )
312+ ep = torch .export .export (
313+ model ,
314+ (id_list_features ,),
315+ {},
316+ strict = False ,
317+ # Allows KJT to not be unflattened and run a forward on unflattened EP
318+ preserve_module_call_signature = (tuple (sparse_fqns )),
319+ )
320+
321+ # Run forward on ExportedProgram
322+ ep_output = ep .module ()(id_list_features )
323+
324+ for i , tensor in enumerate (ep_output ):
325+ self .assertEqual (eager_out [i ].shape , tensor .shape )
326+
327+ # Deserialize EBC
328+ unflatten_ep = torch .export .unflatten (ep )
329+ deserialized_model = decapsulate_ir_modules (unflatten_ep , JsonSerializer )
330+
331+ # check EBC config
332+ for i in range (5 ):
333+ ebc_name = f"ebc{ i + 1 } "
334+ self .assertIsInstance (
335+ getattr (deserialized_model , ebc_name ), EmbeddingBagCollection
336+ )
337+
338+ for deserialized , orginal in zip (
339+ getattr (deserialized_model , ebc_name ).embedding_bag_configs (),
340+ getattr (model , ebc_name ).embedding_bag_configs (),
341+ ):
342+ self .assertEqual (deserialized .name , orginal .name )
343+ self .assertEqual (deserialized .embedding_dim , orginal .embedding_dim )
344+ self .assertEqual (deserialized .num_embeddings , orginal .num_embeddings )
345+ self .assertEqual (deserialized .feature_names , orginal .feature_names )
346+
347+ # check FPEBC config
348+ for i in range (2 ):
349+ fpebc_name = f"fpebc{ i + 1 } "
350+ assert isinstance (
351+ getattr (deserialized_model , fpebc_name ),
352+ FeatureProcessedEmbeddingBagCollection ,
353+ )
354+
355+ for deserialized , orginal in zip (
356+ getattr (
357+ deserialized_model , fpebc_name
358+ )._embedding_bag_collection .embedding_bag_configs (),
359+ getattr (
360+ model , fpebc_name
361+ )._embedding_bag_collection .embedding_bag_configs (),
362+ ):
363+ self .assertEqual (deserialized .name , orginal .name )
364+ self .assertEqual (deserialized .embedding_dim , orginal .embedding_dim )
365+ self .assertEqual (deserialized .num_embeddings , orginal .num_embeddings )
366+ self .assertEqual (deserialized .feature_names , orginal .feature_names )
367+
368+ # Run forward on deserialized model and compare the output
369+ deserialized_model .load_state_dict (model .state_dict ())
370+ deserialized_out = deserialized_model (id_list_features )
371+
372+ self .assertEqual (len (deserialized_out ), len (eager_out ))
373+ for deserialized , orginal in zip (deserialized_out , eager_out ):
374+ self .assertEqual (deserialized .shape , orginal .shape )
375+ self .assertTrue (torch .allclose (deserialized , orginal ))
376+
256377 def test_dynamic_shape_ebc_disabled_in_oss_compatibility (self ) -> None :
257378 model = self .generate_model ()
258379 feature1 = KeyedJaggedTensor .from_offsets_sync (
0 commit comments