diff --git a/pytensor/configdefaults.py b/pytensor/configdefaults.py index 67ad98994a..a5958d7f4f 100644 --- a/pytensor/configdefaults.py +++ b/pytensor/configdefaults.py @@ -640,16 +640,6 @@ def add_tensor_configvars(): in_c_key=False, ) - config.add( - "tensor__local_elemwise_fusion", - ( - "Enable or not in fast_run mode(fast_run optimization) the elemwise " - "fusion optimization" - ), - BoolParam(True), - in_c_key=False, - ) - # http://developer.amd.com/CPU/LIBRARIES/LIBM/Pages/default.aspx config.add( "lib__amblibm", diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index ded188ba0c..e42543f9e6 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -1085,38 +1085,10 @@ def print_profile(stream, prof, level=0): print(blanc, " time_toposort", prof[7], file=stream) -if config.tensor__local_elemwise_fusion: - # Must be after gpu(48.5) and before AddDestroyHandler(49.5) - fuse_seqopt = SequenceDB() - fuse_seqopt.register( - "local_add_mul_fusion", - EquilibriumGraphRewriter(rewriters=[local_add_mul_fusion], max_use_ratio=1000), - "fast_run", - "fusion", - position=0, - ) - fuse_seqopt.register( - "composite_elemwise_fusion", - FusionOptimizer(), - "fast_run", - "fusion", - position=1, - ) - compile.optdb.register( - "elemwise_fusion", - fuse_seqopt, - "fast_run", - "fusion", - "local_elemwise_fusion", - "FusionOptimizer", - position=49, - ) - - @register_canonicalize @register_specialize @node_rewriter([Elemwise]) -def local_useless_composite(fgraph, node): +def local_useless_composite_outputs(fgraph, node): """Remove inputs and outputs of Composite Ops that are not used anywhere.""" if not isinstance(node.op, Elemwise) or not isinstance( node.op.scalar_op, aes.Composite @@ -1231,11 +1203,45 @@ def local_careduce_fusion(fgraph, node): return [new_car_op(*elm_inputs)] +# Register fusion database just before AddDestroyHandler(49.5) (inplace rewrites) +fuse_seqopt = SequenceDB() compile.optdb.register( + "elemwise_fusion", + fuse_seqopt, + "fast_run", + "fusion", + "local_elemwise_fusion", + "FusionOptimizer", + position=49, +) + +fuse_seqopt.register( + "local_add_mul_fusion", + EquilibriumGraphRewriter(rewriters=[local_add_mul_fusion], max_use_ratio=1000), + "fast_run", + "fusion", + position=0, +) +fuse_seqopt.register( + "composite_elemwise_fusion", + FusionOptimizer(), + "fast_run", + "fusion", + position=1, +) +fuse_seqopt.register( + "local_useless_composite_outputs", + in2out(local_useless_composite_outputs), + "fast_run", + "fusion", + position=2, +) +fuse_seqopt.register( "local_careduce_fusion", in2out(local_careduce_fusion), + "fast_run", "fusion", - position=49, + position=10, ) diff --git a/tests/tensor/rewriting/test_elemwise.py b/tests/tensor/rewriting/test_elemwise.py index 99c375aaa7..ead6575eb3 100644 --- a/tests/tensor/rewriting/test_elemwise.py +++ b/tests/tensor/rewriting/test_elemwise.py @@ -1425,39 +1425,40 @@ def test_nested_composite(self): fval = f([1, 2, 3]) assert np.all(fval == [6, 12, 18]) - def test_local_useless_composite(self): - x = aes.float32() - y = aes.float32() - z = aes.float32() - c = aes.Composite([x, y, z], [x + 1, y - 1]) - X = matrix("X") - Y = matrix("Y") - Z = matrix("Z") - o1, o2 = Elemwise(scalar_op=c)(X, Y, Z) - mode = get_default_mode().including("local_useless_composite") - - f = function([X, Y, Z], [o1, o2], mode=mode) - topo = f.maker.fgraph.toposort() - assert len(topo) == 1 - assert len(topo[0].inputs) == 2 - assert len(topo[0].outputs) == 2 - res1, res2 = f([[1.0]], [[1.0]], [[np.nan]]) - utt.assert_allclose(res1, [[2.0]]) - utt.assert_allclose(res2, [[0.0]]) - - f = function([X, Y, Z], o1, mode=mode) - topo = f.maker.fgraph.toposort() - assert len(topo) == 1 - assert len(topo[0].inputs) == 1 - assert len(topo[0].outputs) == 1 - utt.assert_allclose(f([[1.0]], [[np.nan]], [[np.nan]]), [[2.0]]) - f = function([X, Y, Z], o2, mode=mode) - topo = f.maker.fgraph.toposort() - assert len(topo) == 1 - assert len(topo[0].inputs) == 1 - assert len(topo[0].outputs) == 1 - utt.assert_allclose(f([[np.nan]], [[1.0]], [[np.nan]]), [[0.0]]) +def test_local_useless_composite_outputs(): + x = aes.float32() + y = aes.float32() + z = aes.float32() + c = aes.Composite([x, y, z], [x + 1, y - 1]) + X = matrix("X") + Y = matrix("Y") + Z = matrix("Z") + o1, o2 = Elemwise(scalar_op=c)(X, Y, Z) + mode = get_default_mode().including("local_useless_composite") + + f = function([X, Y, Z], [o1, o2], mode=mode) + topo = f.maker.fgraph.toposort() + assert len(topo) == 1 + assert len(topo[0].inputs) == 2 + assert len(topo[0].outputs) == 2 + res1, res2 = f([[1.0]], [[1.0]], [[np.nan]]) + utt.assert_allclose(res1, [[2.0]]) + utt.assert_allclose(res2, [[0.0]]) + + f = function([X, Y, Z], o1, mode=mode) + topo = f.maker.fgraph.toposort() + assert len(topo) == 1 + assert len(topo[0].inputs) == 1 + assert len(topo[0].outputs) == 1 + utt.assert_allclose(f([[1.0]], [[np.nan]], [[np.nan]]), [[2.0]]) + + f = function([X, Y, Z], o2, mode=mode) + topo = f.maker.fgraph.toposort() + assert len(topo) == 1 + assert len(topo[0].inputs) == 1 + assert len(topo[0].outputs) == 1 + utt.assert_allclose(f([[np.nan]], [[1.0]], [[np.nan]]), [[0.0]]) def test_local_useless_dimshuffle_makevector():