@@ -1113,6 +1113,78 @@ def test_test_values(self, test_value):
1113
1113
f .maker .fgraph .outputs [0 ].tag .test_value , np .c_ [[2.0 ]]
1114
1114
)
1115
1115
1116
+ @pytest .mark .parametrize ("linker" , ["cvm" , "py" ])
1117
+ @pytest .mark .parametrize ("axis" , [None , 0 , 1 , (0 , 1 ), (0 , 1 , 2 )])
1118
+ def test_CAReduce_single_input (self , linker , axis ):
1119
+ """Make sure that `CAReduce` and `Elemwise` fusions work with a single input."""
1120
+
1121
+ mode = Mode (linker = linker )
1122
+ mode ._optimizer = mode ._optimizer .including (
1123
+ "local_careduce_fusion" ,
1124
+ "canonicalize" ,
1125
+ "inplace" ,
1126
+ )
1127
+
1128
+ x = tensor ("floatX" , shape = (None , None , None ), name = "x" )
1129
+ out = exp (x ).sum (axis = axis )
1130
+
1131
+ out_fn = function ([x ], out , mode = mode )
1132
+ (out_node ,) = out_fn .maker .fgraph .toposort ()
1133
+
1134
+ assert isinstance (getattr (out_node .op , "scalar_op" ), aes .basic .Composite )
1135
+
1136
+ rng = np .random .default_rng (2320 )
1137
+ x_val = rng .random ((4 , 3 , 2 ), dtype = config .floatX )
1138
+
1139
+ exp_res = np .exp (x_val ).sum (axis = axis )
1140
+
1141
+ out_val = out_fn (x_val )
1142
+ assert out_val .shape == exp_res .shape
1143
+ assert np .allclose (out_val , exp_res )
1144
+
1145
+ # `Elemwise`s with more than one client shouldn't be rewritten
1146
+ x = tensor ("floatX" , shape = (None , None , None ), name = "x" )
1147
+ exp_x = exp (x )
1148
+ out = exp_x .sum (axis = axis ) + exp (x )
1149
+
1150
+ out_fn = function ([x ], out , mode = mode )
1151
+ out_nodes = out_fn .maker .fgraph .toposort ()
1152
+ assert not any (
1153
+ isinstance (out_node .op .scalar_op , aes .basic .Composite )
1154
+ for out_node in out_nodes
1155
+ if hasattr (out_node .op , "scalar_op" )
1156
+ )
1157
+
1158
+ @pytest .mark .xfail (reason = "Not implemented" )
1159
+ @pytest .mark .parametrize ("linker" , ["cvm" , "py" ])
1160
+ @pytest .mark .parametrize ("axis" , [None , 0 , 1 , (0 , 1 ), (0 , 1 , 2 )])
1161
+ def test_CAReduce_multiple_inputs (self , linker , axis ):
1162
+ """Make sure that `CAReduce` and `Elemwise` fusions work with multiple inputs."""
1163
+
1164
+ mode = Mode (linker = linker )
1165
+ mode ._optimizer = mode ._optimizer .including (
1166
+ "local_careduce_fusion" ,
1167
+ "canonicalize" ,
1168
+ "inplace" ,
1169
+ )
1170
+
1171
+ x = tensor ("floatX" , shape = (None , None , None ), name = "x" )
1172
+ y = tensor ("floatX" , shape = (None , None , None ), name = "y" )
1173
+ out = (x + y ).sum (axis = axis )
1174
+
1175
+ out_fn = function ([x , y ], out , mode = mode )
1176
+ (out_node ,) = out_fn .maker .fgraph .toposort ()
1177
+
1178
+ assert isinstance (getattr (out_node .op , "scalar_op" ), aes .basic .Composite )
1179
+
1180
+ rng = np .random .default_rng (2320 )
1181
+ x_val = rng .random ((4 , 3 , 2 ), dtype = config .floatX )
1182
+ y_val = rng .random ((4 , 3 , 2 ), dtype = config .floatX )
1183
+ exp_res = (x_val + y_val ).sum (axis = axis )
1184
+ out_val = out_fn (x_val , y_val )
1185
+ assert out_val .shape == exp_res .shape
1186
+ assert np .allclose (out_val , exp_res )
1187
+
1116
1188
1117
1189
class TimesN (aes .basic .UnaryScalarOp ):
1118
1190
"""
0 commit comments