@@ -4014,6 +4014,161 @@ def test_local_sumsqr2dot():
40144014 )
40154015
40164016
4017+ def test_local_mul_exp_to_exp_add ():
4018+ # Default and FAST_RUN modes put a Composite op into the final graph,
4019+ # whereas FAST_COMPILE doesn't. To unify the graph the test cases analyze across runs,
4020+ # we'll avoid the insertion of Composite ops in each mode by skipping Fusion rewrites
4021+ mode = get_default_mode ().excluding ("fusion" ).including ("local_mul_exp_to_exp_add" )
4022+
4023+ x = scalar ("x" )
4024+ y = scalar ("y" )
4025+ z = scalar ("z" )
4026+ w = scalar ("w" )
4027+ expx = exp (x )
4028+ expy = exp (y )
4029+ expz = exp (z )
4030+ expw = exp (w )
4031+
4032+ # e^x * e^y * e^z * e^w = e^(x+y+z+w)
4033+ op = expx * expy * expz * expw
4034+ f = function ([x , y , z , w ], op , mode )
4035+ pytensor .dprint (f )
4036+ utt .assert_allclose (f (3 , 4 , 5 , 6 ), np .exp (3 + 4 + 5 + 6 ))
4037+ graph = f .maker .fgraph .toposort ()
4038+ assert all (isinstance (n .op , Elemwise ) for n in graph )
4039+ assert any (isinstance (n .op .scalar_op , aes .Add ) for n in graph )
4040+ assert not any (isinstance (n .op .scalar_op , aes .Mul ) for n in graph )
4041+
4042+ # e^x * e^y * e^z / e^w = e^(x+y+z-w)
4043+ op = expx * expy * expz / expw
4044+ f = function ([x , y , z , w ], op , mode )
4045+ utt .assert_allclose (f (3 , 4 , 5 , 6 ), np .exp (3 + 4 + 5 - 6 ))
4046+ graph = f .maker .fgraph .toposort ()
4047+ assert all (isinstance (n .op , Elemwise ) for n in graph )
4048+ assert any (isinstance (n .op .scalar_op , aes .Add ) for n in graph )
4049+ assert any (isinstance (n .op .scalar_op , aes .Sub ) for n in graph )
4050+ assert not any (isinstance (n .op .scalar_op , aes .Mul ) for n in graph )
4051+ assert not any (isinstance (n .op .scalar_op , aes .TrueDiv ) for n in graph )
4052+
4053+ # e^x * e^y / e^z * e^w = e^(x+y-z+w)
4054+ op = expx * expy / expz * expw
4055+ f = function ([x , y , z , w ], op , mode )
4056+ utt .assert_allclose (f (3 , 4 , 5 , 6 ), np .exp (3 + 4 - 5 + 6 ))
4057+ graph = f .maker .fgraph .toposort ()
4058+ assert all (isinstance (n .op , Elemwise ) for n in graph )
4059+ assert any (isinstance (n .op .scalar_op , aes .Add ) for n in graph )
4060+ assert any (isinstance (n .op .scalar_op , aes .Sub ) for n in graph )
4061+ assert not any (isinstance (n .op .scalar_op , aes .Mul ) for n in graph )
4062+ assert not any (isinstance (n .op .scalar_op , aes .TrueDiv ) for n in graph )
4063+
4064+ # e^x / e^y / e^z = (e^x / e^y) / e^z = e^(x-y-z)
4065+ op = expx / expy / expz
4066+ f = function ([x , y , z ], op , mode )
4067+ utt .assert_allclose (f (3 , 4 , 5 ), np .exp (3 - 4 - 5 ))
4068+ graph = f .maker .fgraph .toposort ()
4069+ assert all (isinstance (n .op , Elemwise ) for n in graph )
4070+ assert any (isinstance (n .op .scalar_op , aes .Sub ) for n in graph )
4071+ assert not any (isinstance (n .op .scalar_op , aes .TrueDiv ) for n in graph )
4072+
4073+ # e^x * y * e^z * w = e^(x+z) * y * w
4074+ op = expx * y * expz * w
4075+ f = function ([x , y , z , w ], op , mode )
4076+ utt .assert_allclose (f (3 , 4 , 5 , 6 ), np .exp (3 + 5 ) * 4 * 6 )
4077+ graph = f .maker .fgraph .toposort ()
4078+ assert all (isinstance (n .op , Elemwise ) for n in graph )
4079+ assert any (isinstance (n .op .scalar_op , aes .Add ) for n in graph )
4080+ assert any (isinstance (n .op .scalar_op , aes .Mul ) for n in graph )
4081+
4082+ # expect same for matrices as well
4083+ mx = matrix ("mx" )
4084+ my = matrix ("my" )
4085+ f = function ([mx , my ], exp (mx ) * exp (my ), mode , allow_input_downcast = True )
4086+ M1 = np .array ([[1.0 , 2.0 ], [3.0 , 4.0 ]])
4087+ M2 = np .array ([[5.0 , 6.0 ], [7.0 , 8.0 ]])
4088+ utt .assert_allclose (f (M1 , M2 ), np .exp (M1 + M2 ))
4089+ graph = f .maker .fgraph .toposort ()
4090+ assert all (isinstance (n .op , Elemwise ) for n in graph )
4091+ assert any (isinstance (n .op .scalar_op , aes .Add ) for n in graph )
4092+ assert not any (isinstance (n .op .scalar_op , aes .Mul ) for n in graph )
4093+
4094+ # checking whether further rewrites can proceed after this one as one would expect
4095+ # e^x * e^(-x) = e^(x-x) = e^0 = 1
4096+ f = function ([x ], expx * exp (neg (x )), mode )
4097+ utt .assert_allclose (f (42 ), 1 )
4098+ graph = f .maker .fgraph .toposort ()
4099+ assert isinstance (graph [0 ].inputs [0 ], TensorConstant )
4100+
4101+ # e^x / e^x = e^(x-x) = e^0 = 1
4102+ f = function ([x ], expx / expx , mode )
4103+ utt .assert_allclose (f (42 ), 1 )
4104+ graph = f .maker .fgraph .toposort ()
4105+ assert isinstance (graph [0 ].inputs [0 ], TensorConstant )
4106+
4107+
4108+ def test_local_mul_pow_to_pow_add ():
4109+ # Default and FAST_RUN modes put a Composite op into the final graph,
4110+ # whereas FAST_COMPILE doesn't. To unify the graph the test cases analyze across runs,
4111+ # we'll avoid the insertion of Composite ops in each mode by skipping Fusion rewrites
4112+ mode = (
4113+ get_default_mode ()
4114+ .excluding ("fusion" )
4115+ .including ("local_mul_exp_to_exp_add" )
4116+ .including ("local_mul_pow_to_pow_add" )
4117+ )
4118+
4119+ x = scalar ("x" )
4120+ y = scalar ("y" )
4121+ z = scalar ("z" )
4122+ w = scalar ("w" )
4123+ v = scalar ("v" )
4124+ u = scalar ("u" )
4125+ t = scalar ("t" )
4126+ s = scalar ("s" )
4127+ a = scalar ("a" )
4128+ b = scalar ("b" )
4129+ c = scalar ("c" )
4130+
4131+ # 2^x * 2^y * 2^z * 2^w = 2^(x+y+z+w)
4132+ op = 2 ** x * 2 ** y * 2 ** z * 2 ** w
4133+ f = function ([x , y , z , w ], op , mode )
4134+ utt .assert_allclose (f (3 , 4 , 5 , 6 ), 2 ** (3 + 4 + 5 + 6 ))
4135+ graph = f .maker .fgraph .toposort ()
4136+ assert all (isinstance (n .op , Elemwise ) for n in graph )
4137+ assert any (isinstance (n .op .scalar_op , aes .Add ) for n in graph )
4138+ assert not any (isinstance (n .op .scalar_op , aes .Mul ) for n in graph )
4139+
4140+ # 2^x * a^y * 2^z * b^w * c^v * a^u * s * b^t = 2^(x+z) * a^(y+u) * b^(w+t) * c^v * s
4141+ op = 2 ** x * a ** y * 2 ** z * b ** w * c ** v * a ** u * s * b ** t
4142+ f = function ([x , y , z , w , v , u , t , s , a , b , c ], op , mode )
4143+ utt .assert_allclose (
4144+ f (4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 2.5 , 3 , 3.5 ),
4145+ 2 ** (4 + 6 ) * 2.5 ** (5 + 9 ) * 3 ** (7 + 10 ) * 3.5 ** 8 * 11 ,
4146+ )
4147+ graph = f .maker .fgraph .toposort ()
4148+ assert all (isinstance (n .op , Elemwise ) for n in graph )
4149+ assert len ([True for n in graph if isinstance (n .op .scalar_op , aes .Add )]) == 3
4150+ assert len ([True for n in graph if isinstance (n .op .scalar_op , aes .Pow )]) == 4
4151+ assert any (isinstance (n .op .scalar_op , aes .Mul ) for n in graph )
4152+
4153+ # (2^x / 2^y) * (a^z / a^w) = 2^(x-y) * a^(z-w)
4154+ op = 2 ** x / 2 ** y * (a ** z / a ** w )
4155+ f = function ([x , y , z , w , a ], op , mode )
4156+ utt .assert_allclose (f (3 , 5 , 6 , 4 , 7 ), 2 ** (3 - 5 ) * 7 ** (6 - 4 ))
4157+ graph = f .maker .fgraph .toposort ()
4158+ assert all (isinstance (n .op , Elemwise ) for n in graph )
4159+ assert len ([True for n in graph if isinstance (n .op .scalar_op , aes .Sub )]) == 2
4160+ assert any (isinstance (n .op .scalar_op , aes .Mul ) for n in graph )
4161+
4162+ # a^x * a^y * exp(z) * exp(w) = a^(x+y) * exp(z+w)
4163+ op = a ** x * a ** y * exp (z ) * exp (w )
4164+ f = function ([x , y , z , w , a ], op , mode )
4165+ utt .assert_allclose (f (3 , 4 , 5 , 6 , 2 ), 2 ** (3 + 4 ) * np .exp (5 + 6 ))
4166+ graph = f .maker .fgraph .toposort ()
4167+ assert all (isinstance (n .op , Elemwise ) for n in graph )
4168+ assert len ([True for n in graph if isinstance (n .op .scalar_op , aes .Add )]) == 2
4169+ assert any (isinstance (n .op .scalar_op , aes .Mul ) for n in graph )
4170+
4171+
40174172def test_local_expm1 ():
40184173 x = matrix ("x" )
40194174 u = scalar ("u" )
0 commit comments