@@ -33,12 +33,12 @@ def main(
3333 ) -> R .Tensor ((1 , 64 , 54 , 54 ), dtype = "float32" ):
3434 cls = Conv2dReLUx2
3535 with R .dataflow ():
36- lv : R .Tensor (
37- ( 1 , 64 , 56 , 56 ), dtype = "float32"
38- ) = cls . fused_relax_nn_conv2d_relax_nn_relu ( data , weight1 )
39- gv : R .Tensor (
40- ( 1 , 64 , 54 , 54 ), dtype = "float32"
41- ) = cls . fused_relax_nn_conv2d_relax_nn_relu1 ( lv , weight2 )
36+ lv : R .Tensor (( 1 , 64 , 56 , 56 ), dtype = "float32" ) = (
37+ cls . fused_relax_nn_conv2d_relax_nn_relu ( data , weight1 )
38+ )
39+ gv : R .Tensor (( 1 , 64 , 54 , 54 ), dtype = "float32" ) = (
40+ cls . fused_relax_nn_conv2d_relax_nn_relu1 ( lv , weight2 )
41+ )
4242 R .output (gv )
4343 return gv
4444
@@ -85,10 +85,10 @@ def main(
8585 ) -> R .Tensor ((1 , 64 , 54 , 54 ), dtype = "float32" ):
8686 cls = Conv2dReLUx2_merged
8787 with R .dataflow ():
88- gv : R .Tensor (
89- ( 1 , 64 , 54 , 54 ), dtype = "float32"
90- ) = cls . fused_relax_nn_conv2d_relax_nn_relu_relax_nn_conv2d_relax_nn_relu1_dnnl (
91- data , weight1 , weight2
88+ gv : R .Tensor (( 1 , 64 , 54 , 54 ), dtype = "float32" ) = (
89+ cls . fused_relax_nn_conv2d_relax_nn_relu_relax_nn_conv2d_relax_nn_relu1_dnnl (
90+ data , weight1 , weight2
91+ )
9292 )
9393 R .output (gv )
9494 return gv
@@ -159,7 +159,7 @@ def main(
159159
160160 @R .function (private = True )
161161 def fused_relax_nn_gelu (
162- lv : R .Tensor ((1 , 64 , 54 , 54 ), dtype = "float32" )
162+ lv : R .Tensor ((1 , 64 , 54 , 54 ), dtype = "float32" ),
163163 ) -> R .Tensor ((1 , 64 , 54 , 54 ), dtype = "float32" ):
164164 R .func_attr ({"Primitive" : 1 , "Composite" : "compiler_A.gelu" })
165165 with R .dataflow ():
@@ -169,7 +169,7 @@ def fused_relax_nn_gelu(
169169
170170 @R .function (private = True )
171171 def fused_relax_nn_relu (
172- lv1 : R .Tensor ((1 , 64 , 54 , 54 ), dtype = "float32" )
172+ lv1 : R .Tensor ((1 , 64 , 54 , 54 ), dtype = "float32" ),
173173 ) -> R .Tensor ((1 , 64 , 54 , 54 ), dtype = "float32" ):
174174 R .func_attr ({"Primitive" : 1 , "Composite" : "compiler_A.relu" })
175175 with R .dataflow ():
@@ -243,7 +243,7 @@ def lv(
243243
244244 @R .function
245245 def lv1 (
246- lv11 : R .Tensor ((1 , 64 , 54 , 54 ), dtype = "float32" )
246+ lv11 : R .Tensor ((1 , 64 , 54 , 54 ), dtype = "float32" ),
247247 ) -> R .Tensor ((1 , 64 , 54 , 54 ), dtype = "float32" ):
248248 # function attr dict
249249 R .func_attr ({"Composite" : "compiler_A.relu" })
@@ -257,7 +257,7 @@ def lv1(
257257
258258 @R .function
259259 def lv21 (
260- lv4 : R .Tensor ((1 , 64 , 54 , 54 ), dtype = "float32" )
260+ lv4 : R .Tensor ((1 , 64 , 54 , 54 ), dtype = "float32" ),
261261 ) -> R .Tensor ((1 , 64 , 54 , 54 ), dtype = "float32" ):
262262 # function attr dict
263263 R .func_attr ({"Composite" : "compiler_A.gelu" })
@@ -292,10 +292,10 @@ def main(
292292 ) -> R .Tensor ((1 , 64 , 54 , 54 ), dtype = "float32" ):
293293 cls = Diamond_merged
294294 with R .dataflow ():
295- gv5 : R .Tensor (
296- ( 1 , 64 , 54 , 54 ), dtype = "float32"
297- ) = cls . fused_relax_nn_conv2d_relax_nn_relu_relax_nn_gelu_relax_add_compiler_A (
298- data2 , weight2
295+ gv5 : R .Tensor (( 1 , 64 , 54 , 54 ), dtype = "float32" ) = (
296+ cls . fused_relax_nn_conv2d_relax_nn_relu_relax_nn_gelu_relax_add_compiler_A (
297+ data2 , weight2
298+ )
299299 )
300300 R .output (gv5 )
301301 return gv5
@@ -321,7 +321,7 @@ def main(
321321
322322 @R .function (private = True )
323323 def fused_relax_nn_gelu (
324- lv : R .Tensor ((1 , 64 , 54 , 54 ), dtype = "float32" )
324+ lv : R .Tensor ((1 , 64 , 54 , 54 ), dtype = "float32" ),
325325 ) -> R .Tensor ((1 , 64 , 54 , 54 ), dtype = "float32" ):
326326 R .func_attr ({"Primitive" : 1 , "Composite" : "compiler_B.gelu" })
327327 with R .dataflow ():
@@ -331,7 +331,7 @@ def fused_relax_nn_gelu(
331331
332332 @R .function (private = True )
333333 def fused_relax_nn_relu (
334- lv1 : R .Tensor ((1 , 64 , 54 , 54 ), dtype = "float32" )
334+ lv1 : R .Tensor ((1 , 64 , 54 , 54 ), dtype = "float32" ),
335335 ) -> R .Tensor ((1 , 64 , 54 , 54 ), dtype = "float32" ):
336336 R .func_attr ({"Primitive" : 1 , "Composite" : "compiler_A.relu" })
337337 with R .dataflow ():
@@ -418,7 +418,7 @@ def lv(
418418
419419 @R .function
420420 def lv1 (
421- lv11 : R .Tensor ((1 , 64 , 54 , 54 ), dtype = "float32" )
421+ lv11 : R .Tensor ((1 , 64 , 54 , 54 ), dtype = "float32" ),
422422 ) -> R .Tensor ((1 , 64 , 54 , 54 ), dtype = "float32" ):
423423 R .func_attr ({"Composite" : "compiler_A.relu" })
424424 with R .dataflow ():
@@ -432,13 +432,13 @@ def lv1(
432432
433433 @R .function
434434 def fused_relax_nn_gelu1_compiler_B (
435- lv2 : R .Tensor ((1 , 64 , 54 , 54 ), dtype = "float32" )
435+ lv2 : R .Tensor ((1 , 64 , 54 , 54 ), dtype = "float32" ),
436436 ) -> R .Tensor ((1 , 64 , 54 , 54 ), dtype = "float32" ):
437437 R .func_attr ({"Codegen" : "compiler_B" })
438438
439439 @R .function
440440 def lv21 (
441- lv3 : R .Tensor ((1 , 64 , 54 , 54 ), dtype = "float32" )
441+ lv3 : R .Tensor ((1 , 64 , 54 , 54 ), dtype = "float32" ),
442442 ) -> R .Tensor ((1 , 64 , 54 , 54 ), dtype = "float32" ):
443443 R .func_attr ({"Composite" : "compiler_B.gelu" })
444444 with R .dataflow ():
@@ -489,7 +489,7 @@ def main(
489489
490490 @R .function (private = True )
491491 def fused_relax_nn_relu (
492- x11 : R .Tensor ((10 ,), dtype = "float32" )
492+ x11 : R .Tensor ((10 ,), dtype = "float32" ),
493493 ) -> R .Tensor ((10 ,), dtype = "float32" ):
494494 R .func_attr ({"Primitive" : 1 , "Composite" : "compiler_A.relu" })
495495 with R .dataflow ():
@@ -499,7 +499,7 @@ def fused_relax_nn_relu(
499499
500500 @R .function (private = True )
501501 def fused_relax_nn_gelu (
502- x21 : R .Tensor ((10 ,), dtype = "float32" )
502+ x21 : R .Tensor ((10 ,), dtype = "float32" ),
503503 ) -> R .Tensor ((10 ,), dtype = "float32" ):
504504 R .func_attr ({"Primitive" : 1 , "Composite" : "compiler_A.gelu" })
505505 with R .dataflow ():
@@ -575,10 +575,10 @@ def main(
575575 ) -> R .Tensor ((10 ,), dtype = "float32" ):
576576 cls = MultipleProducers_merged
577577 with R .dataflow ():
578- gv4 : R .Tensor (
579- ( 10 ,), dtype = "float32"
580- ) = cls . fused_relax_nn_relu_relax_nn_gelu_relax_nn_relu_relax_nn_gelu_relax_add_compiler_A (
581- x12 , x22
578+ gv4 : R .Tensor (( 10 ,), dtype = "float32" ) = (
579+ cls . fused_relax_nn_relu_relax_nn_gelu_relax_nn_relu_relax_nn_gelu_relax_add_compiler_A (
580+ x12 , x22
581+ )
582582 )
583583 R .output (gv4 )
584584 return gv4
@@ -599,7 +599,7 @@ def main(x1: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32
599599
600600 @R .function (private = True )
601601 def fused_relax_nn_relu (
602- x11 : R .Tensor ((10 ,), dtype = "float32" )
602+ x11 : R .Tensor ((10 ,), dtype = "float32" ),
603603 ) -> R .Tensor ((10 ,), dtype = "float32" ):
604604 R .func_attr ({"Primitive" : 1 , "Composite" : "compiler_A.relu" })
605605 with R .dataflow ():
@@ -609,7 +609,7 @@ def fused_relax_nn_relu(
609609
610610 @R .function (private = True )
611611 def fused_relax_nn_gelu (
612- x21 : R .Tensor ((10 ,), dtype = "float32" )
612+ x21 : R .Tensor ((10 ,), dtype = "float32" ),
613613 ) -> R .Tensor ((10 ,), dtype = "float32" ):
614614 R .func_attr ({"Primitive" : 1 , "Composite" : "compiler_A.gelu" })
615615 with R .dataflow ():
@@ -644,7 +644,7 @@ def main(x1: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32
644644
645645 @R .function
646646 def fused_relax_nn_relu1_compiler_A (
647- x11 : R .Tensor ((10 ,), dtype = "float32" )
647+ x11 : R .Tensor ((10 ,), dtype = "float32" ),
648648 ) -> R .Tensor ((10 ,), dtype = "float32" ):
649649 # function attr dict
650650 R .func_attr ({"Codegen" : "compiler_A" })
@@ -722,7 +722,7 @@ def main(
722722
723723 @R .function (private = True )
724724 def fused_relax_nn_relu (
725- add2 : R .Tensor ((10 ,), dtype = "float32" )
725+ add2 : R .Tensor ((10 ,), dtype = "float32" ),
726726 ) -> R .Tensor ((10 ,), dtype = "float32" ):
727727 R .func_attr ({"Primitive" : 1 , "Composite" : "compiler_A.relu" })
728728 with R .dataflow ():
@@ -742,7 +742,7 @@ def fused_relax_add(
742742
743743 @R .function (private = True )
744744 def fused_relax_nn_gelu (
745- x31 : R .Tensor ((10 ,), dtype = "float32" )
745+ x31 : R .Tensor ((10 ,), dtype = "float32" ),
746746 ) -> R .Tensor ((10 ,), dtype = "float32" ):
747747 R .func_attr ({"Primitive" : 1 , "Composite" : "compiler_B.gelu" })
748748 with R .dataflow ():
@@ -817,7 +817,7 @@ def lv31(add2: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float
817817
818818 @R .function
819819 def fused_relax_nn_gelu1_compiler_B (
820- x3 : R .Tensor ((10 ,), dtype = "float32" )
820+ x3 : R .Tensor ((10 ,), dtype = "float32" ),
821821 ) -> R .Tensor ((10 ,), dtype = "float32" ):
822822 R .func_attr ({"Codegen" : "compiler_B" })
823823
@@ -841,9 +841,9 @@ def main(
841841 cls = MergeCompilerRegionsExampleRef
842842 with R .dataflow ():
843843 lv5 : R .Tensor ((10 ,), dtype = "float32" ) = cls .fused_relax_nn_gelu1_compiler_B (x32 )
844- lv13 : R .Tuple (
845- R . Tensor (( 10 ,), dtype = "float32" ), R . Tensor (( 10 ,), dtype = "float32" )
846- ) = cls . fused_relax_add_relax_add_relax_nn_relu_compiler_A ( x12 , x22 , lv5 )
844+ lv13 : R .Tuple (R . Tensor (( 10 ,), dtype = "float32" ), R . Tensor (( 10 ,), dtype = "float32" )) = (
845+ cls . fused_relax_add_relax_add_relax_nn_relu_compiler_A ( x12 , x22 , lv5 )
846+ )
847847 lv23 : R .Tensor ((10 ,), dtype = "float32" ) = lv13 [0 ]
848848 lv32 : R .Tensor ((10 ,), dtype = "float32" ) = lv13 [1 ]
849849 lv41 : R .Tensor ((10 ,), dtype = "float32" ) = cls .fused_relax_nn_gelu1_compiler_B (lv23 )
@@ -1097,9 +1097,9 @@ def main(
10971097 lv1 : R .Tensor ((784 , 512 ), dtype = "float32" ) = R .permute_dims (
10981098 linear_relu_stack_0_weight , axes = None
10991099 )
1100- gv : R .Tensor (
1101- ( 1 , 512 ), dtype = "float32"
1102- ) = cls . fused_relax_reshape_relax_matmul_tensorrt ( inp_0 , lv1 )
1100+ gv : R .Tensor (( 1 , 512 ), dtype = "float32" ) = (
1101+ cls . fused_relax_reshape_relax_matmul_tensorrt ( inp_0 , lv1 )
1102+ )
11031103 R .output (gv )
11041104 return gv
11051105
@@ -1130,7 +1130,7 @@ def main(A: R.Tensor([10], dtype="float32")) -> R.Tensor([10], dtype="float32"):
11301130
11311131 @R .function (private = True )
11321132 def fused_relax_nn_relu (
1133- Input : R .Tensor ([10 ], dtype = "float32" )
1133+ Input : R .Tensor ([10 ], dtype = "float32" ),
11341134 ) -> R .Tensor ([10 ], dtype = "float32" ):
11351135 R .func_attr ({"Composite" : "compiler_A.relu" , "Primitive" : 1 })
11361136 with R .dataflow ():
@@ -1151,7 +1151,7 @@ def relu(
11511151
11521152 @R .function (private = True )
11531153 def fused_relax_nn_gelu (
1154- Input : R .Tensor ([10 ], dtype = "float32" )
1154+ Input : R .Tensor ([10 ], dtype = "float32" ),
11551155 ) -> R .Tensor ([10 ], dtype = "float32" ):
11561156 R .func_attr ({"Composite" : "compiler_A.gelu" , "Primitive" : 1 })
11571157 with R .dataflow ():
@@ -1173,13 +1173,13 @@ def main(A: R.Tensor([10], dtype="float32")) -> R.Tensor([10], dtype="float32"):
11731173
11741174 @R .function
11751175 def fused_relax_nn_relu1_compiler_A (
1176- Input : R .Tensor ([10 ], dtype = "float32" )
1176+ Input : R .Tensor ([10 ], dtype = "float32" ),
11771177 ) -> R .Tensor ([10 ], dtype = "float32" ):
11781178 R .func_attr ({"Codegen" : "compiler_A" })
11791179
11801180 @R .function
11811181 def composite_lambda (
1182- Input : R .Tensor ([10 ], dtype = "float32" )
1182+ Input : R .Tensor ([10 ], dtype = "float32" ),
11831183 ) -> R .Tensor ([10 ], dtype = "float32" ):
11841184 R .func_attr ({"Composite" : "compiler_A.relu" })
11851185 with R .dataflow ():
@@ -1203,13 +1203,13 @@ def relu(
12031203
12041204 @R .function
12051205 def fused_relax_nn_gelu1_compiler_A (
1206- Input : R .Tensor ([10 ], dtype = "float32" )
1206+ Input : R .Tensor ([10 ], dtype = "float32" ),
12071207 ) -> R .Tensor ([10 ], dtype = "float32" ):
12081208 R .func_attr ({"Codegen" : "compiler_A" })
12091209
12101210 @R .function
12111211 def composite_lambda (
1212- Input : R .Tensor ([10 ], dtype = "float32" )
1212+ Input : R .Tensor ([10 ], dtype = "float32" ),
12131213 ) -> R .Tensor ([10 ], dtype = "float32" ):
12141214 R .func_attr ({"Composite" : "compiler_A.gelu" })
12151215 with R .dataflow ():
0 commit comments