@@ -431,38 +431,37 @@ def local_mul_exp_to_exp_add(fgraph, node):
431431 This rewrite detects e^x * e^y and converts it to e^(x+y).
432432 Similarly, e^x / e^y becomes e^(x-y).
433433 """
434- if isinstance (node .op .scalar_op , (aes .Mul , aes .TrueDiv )):
435- exps = [
436- n .owner .inputs [0 ]
434+ exps = [
435+ n .owner .inputs [0 ]
436+ for n in node .inputs
437+ if n .owner
438+ and hasattr (n .owner .op , "scalar_op" )
439+ and isinstance (n .owner .op .scalar_op , aes .Exp )
440+ ]
441+ # Can only do any rewrite if there are at least two exp-s
442+ if len (exps ) >= 2 :
443+ # Mul -> add; TrueDiv -> sub
444+ orig_op , new_op = mul , add
445+ if isinstance (node .op .scalar_op , aes .TrueDiv ):
446+ orig_op , new_op = true_div , sub
447+ new_out = exp (new_op (* exps ))
448+ if new_out .dtype != node .outputs [0 ].dtype :
449+ new_out = cast (new_out , dtype = node .outputs [0 ].dtype )
450+ # The original Mul may have more than two factors, some of which may not be exp nodes.
451+ # If so, we keep multiplying them with the new exp(sum) node.
452+ # E.g.: e^x * y * e^z * w --> e^(x+z) * y * w
453+ rest = [
454+ n
437455 for n in node .inputs
438- if n .owner
439- and hasattr (n .owner .op , "scalar_op" )
440- and isinstance (n .owner .op .scalar_op , aes .Exp )
456+ if not n .owner
457+ or not hasattr (n .owner .op , "scalar_op" )
458+ or not isinstance (n .owner .op .scalar_op , aes .Exp )
441459 ]
442- # Can only do any rewrite if there are at least two exp-s
443- if len (exps ) >= 2 :
444- # Mul -> add; TrueDiv -> sub
445- orig_op , new_op = mul , add
446- if isinstance (node .op .scalar_op , aes .TrueDiv ):
447- orig_op , new_op = true_div , sub
448- new_out = exp (new_op (* exps ))
460+ if len (rest ) > 0 :
461+ new_out = orig_op (new_out , * rest )
449462 if new_out .dtype != node .outputs [0 ].dtype :
450463 new_out = cast (new_out , dtype = node .outputs [0 ].dtype )
451- # The original Mul may have more than two factors, some of which may not be exp nodes.
452- # If so, we keep multiplying them with the new exp(sum) node.
453- # E.g.: e^x * y * e^z * w --> e^(x+z) * y * w
454- rest = [
455- n
456- for n in node .inputs
457- if not n .owner
458- or not hasattr (n .owner .op , "scalar_op" )
459- or not isinstance (n .owner .op .scalar_op , aes .Exp )
460- ]
461- if len (rest ) > 0 :
462- new_out = orig_op (new_out , * rest )
463- if new_out .dtype != node .outputs [0 ].dtype :
464- new_out = cast (new_out , dtype = node .outputs [0 ].dtype )
465- return [new_out ]
464+ return [new_out ]
466465
467466
468467@register_specialize
@@ -472,52 +471,51 @@ def local_mul_pow_to_pow_add(fgraph, node):
472471 This rewrite detects a^x * a^y and converts it to a^(x+y).
473472 Similarly, a^x / a^y becomes a^(x-y).
474473 """
475- if isinstance (node .op .scalar_op , (aes .Mul , aes .TrueDiv )):
476- # search for pow-s and group them by their bases
477- pow_nodes = defaultdict (list )
478- rest = []
479- for n in node .inputs :
480- if (
481- n .owner
482- and hasattr (n .owner .op , "scalar_op" )
483- and isinstance (n .owner .op .scalar_op , aes .Pow )
484- ):
485- base_node = n .owner .inputs [0 ]
486- # exponent is at n.owner.inputs[1], but we need to store the full node
487- # in case this particular power node remains alone and can't be rewritten
488- pow_nodes [base_node ].append (n )
489- else :
490- rest .append (n )
491-
492- # Can only do any rewrite if there are at least two pow-s with the same base
493- can_rewrite = [k for k , v in pow_nodes .items () if len (v ) >= 2 ]
494- if len (can_rewrite ) >= 1 :
495- # Mul -> add; TrueDiv -> sub
496- orig_op , new_op = mul , add
497- if isinstance (node .op .scalar_op , aes .TrueDiv ):
498- orig_op , new_op = true_div , sub
499- pow_factors = []
500- # Rewrite pow-s having the same base for each different base
501- # E.g.: a^x * a^y --> a^(x+y)
502- for base in can_rewrite :
503- exponents = [n .owner .inputs [1 ] for n in pow_nodes [base ]]
504- new_node = base ** new_op (* exponents )
505- if new_node .dtype != node .outputs [0 ].dtype :
506- new_node = cast (new_node , dtype = node .outputs [0 ].dtype )
507- pow_factors .append (new_node )
508- # Don't forget about those sole pow-s that couldn't be rewriten
509- sole_pows = [v [0 ] for k , v in pow_nodes .items () if k not in can_rewrite ]
510- # Combine the rewritten pow-s and other, non-pow factors of the original Mul
511- # E.g.: a^x * y * b^z * a^w * v * b^t --> a^(x+z) * b^(z+t) * y * v
512- if len (pow_factors ) > 1 or len (sole_pows ) > 0 or len (rest ) > 0 :
513- new_out = orig_op (* pow_factors , * sole_pows , * rest )
514- if new_out .dtype != node .outputs [0 ].dtype :
515- new_out = cast (new_out , dtype = node .outputs [0 ].dtype )
516- else :
517- # if all factors of the original mul were pows-s with the same base,
518- # we can get rid of the mul completely.
519- new_out = pow_factors [0 ]
520- return [new_out ]
474+ # search for pow-s and group them by their bases
475+ pow_nodes = defaultdict (list )
476+ rest = []
477+ for n in node .inputs :
478+ if (
479+ n .owner
480+ and hasattr (n .owner .op , "scalar_op" )
481+ and isinstance (n .owner .op .scalar_op , aes .Pow )
482+ ):
483+ base_node = n .owner .inputs [0 ]
484+ # exponent is at n.owner.inputs[1], but we need to store the full node
485+ # in case this particular power node remains alone and can't be rewritten
486+ pow_nodes [base_node ].append (n )
487+ else :
488+ rest .append (n )
489+
490+ # Can only do any rewrite if there are at least two pow-s with the same base
491+ can_rewrite = [k for k , v in pow_nodes .items () if len (v ) >= 2 ]
492+ if len (can_rewrite ) >= 1 :
493+ # Mul -> add; TrueDiv -> sub
494+ orig_op , new_op = mul , add
495+ if isinstance (node .op .scalar_op , aes .TrueDiv ):
496+ orig_op , new_op = true_div , sub
497+ pow_factors = []
498+ # Rewrite pow-s having the same base for each different base
499+ # E.g.: a^x * a^y --> a^(x+y)
500+ for base in can_rewrite :
501+ exponents = [n .owner .inputs [1 ] for n in pow_nodes [base ]]
502+ new_node = base ** new_op (* exponents )
503+ if new_node .dtype != node .outputs [0 ].dtype :
504+ new_node = cast (new_node , dtype = node .outputs [0 ].dtype )
505+ pow_factors .append (new_node )
506+ # Don't forget about those sole pow-s that couldn't be rewriten
507+ sole_pows = [v [0 ] for k , v in pow_nodes .items () if k not in can_rewrite ]
508+ # Combine the rewritten pow-s and other, non-pow factors of the original Mul
509+ # E.g.: a^x * y * b^z * a^w * v * b^t --> a^(x+z) * b^(z+t) * y * v
510+ if len (pow_factors ) > 1 or len (sole_pows ) > 0 or len (rest ) > 0 :
511+ new_out = orig_op (* pow_factors , * sole_pows , * rest )
512+ if new_out .dtype != node .outputs [0 ].dtype :
513+ new_out = cast (new_out , dtype = node .outputs [0 ].dtype )
514+ else :
515+ # if all factors of the original mul were pows-s with the same base,
516+ # we can get rid of the mul completely.
517+ new_out = pow_factors [0 ]
518+ return [new_out ]
521519
522520
523521@register_stabilize
0 commit comments