@@ -251,5 +251,87 @@ def test_not_lifting_constants_to_initializers_when_it_is_output(self):
251
251
self .assertEqual (len (result .model .graph .initializers ), 0 )
252
252
253
253
254
+ class TestLiftSubgraphInitializersToMainGraphPass (unittest .TestCase ):
255
+ @parameterized .parameterized .expand (
256
+ [
257
+ ("then_initializer" , "else_initializer" ),
258
+ ("initializer" , "initializer" ),
259
+ ]
260
+ )
261
+ def test_pass_with_lifting_constants_to_initializers_within_subgraph (
262
+ self , then_initializer_name , else_initializer_name
263
+ ):
264
+ input_value = ir .Value (
265
+ name = "input" , type = ir .TensorType (ir .DataType .FLOAT ), shape = ir .Shape ((2 , 3 ))
266
+ )
267
+
268
+ then_initializer_tensor = ir .tensor (np .random .rand (2 , 3 ).astype (np .float32 ))
269
+ then_initializer_value = ir .Value (
270
+ name = then_initializer_name ,
271
+ shape = then_initializer_tensor .shape ,
272
+ type = ir .TensorType (ir .DataType .FLOAT ),
273
+ const_value = then_initializer_tensor ,
274
+ )
275
+
276
+ # then branch adds the constant to the input
277
+ # else branch multiplies the input by the constant
278
+ add_node = ir .node ("Add" , inputs = [input_value , then_initializer_value ])
279
+ then_graph = ir .Graph (
280
+ inputs = [input_value , then_initializer_value ],
281
+ outputs = [add_node .outputs [0 ]],
282
+ nodes = [add_node ],
283
+ opset_imports = {"" : 20 },
284
+ initializers = [then_initializer_value ],
285
+ )
286
+ else_initializer_tensor = ir .tensor (np .random .rand (2 , 3 ).astype (np .float32 ))
287
+ else_initializer_value = ir .Value (
288
+ name = else_initializer_name ,
289
+ shape = else_initializer_tensor .shape ,
290
+ type = ir .TensorType (ir .DataType .FLOAT ),
291
+ const_value = else_initializer_tensor ,
292
+ )
293
+ mul_node = ir .node ("Mul" , inputs = [input_value , else_initializer_value ])
294
+ else_graph = ir .Graph (
295
+ inputs = [input_value ],
296
+ outputs = [mul_node .outputs [0 ]],
297
+ nodes = [mul_node ],
298
+ opset_imports = {"" : 20 },
299
+ initializers = [else_initializer_value ],
300
+ )
301
+ # create a conditional node that uses the then and else graphs
302
+ cond_node = ir .node (
303
+ "If" ,
304
+ inputs = [input_value ],
305
+ attributes = {"then_branch" : then_graph , "else_branch" : else_graph },
306
+ num_outputs = 1 ,
307
+ )
308
+ # construnct the model
309
+ main_graph = ir .Graph (
310
+ inputs = [input_value ],
311
+ outputs = cond_node .outputs ,
312
+ nodes = [cond_node ],
313
+ opset_imports = {"" : 20 },
314
+ )
315
+ main_graph .sort ()
316
+ model = ir .Model (
317
+ graph = main_graph ,
318
+ ir_version = 10 ,
319
+ )
320
+ result = constant_manipulation .LiftSubgraphInitializersToMainGraphPass ()(model )
321
+ self .assertTrue (result .modified )
322
+
323
+ self .assertEqual (len (else_graph .initializers ), 0 )
324
+ self .assertEqual (len (then_graph .initializers ), 0 )
325
+ self .assertEqual (len (main_graph .initializers ), 2 )
326
+ for value , tensor in zip (
327
+ main_graph .initializers .values (),
328
+ [then_initializer_tensor , else_initializer_tensor ],
329
+ ):
330
+ self .assertIs (
331
+ value .const_value ,
332
+ tensor ,
333
+ )
334
+
335
+
254
336
if __name__ == "__main__" :
255
337
unittest .main ()
0 commit comments