|
2 | 2 |
|
3 | 3 | import itertools |
4 | 4 | import operator |
| 5 | +from collections import defaultdict |
5 | 6 | from functools import partial, reduce |
6 | 7 |
|
7 | 8 | import numpy as np |
@@ -425,14 +426,12 @@ def local_sumsqr2dot(fgraph, node): |
425 | 426 |
|
426 | 427 | @register_specialize |
427 | 428 | @node_rewriter([mul, true_div]) |
428 | | -def local_mulexp2expadd(fgraph, node): |
| 429 | +def local_mul_exp_to_exp_add(fgraph, node): |
429 | 430 | """ |
430 | 431 | This rewrite detects e^x * e^y and converts it to e^(x+y). |
431 | 432 | Similarly, e^x / e^y becomes e^(x-y). |
432 | 433 | """ |
433 | | - if isinstance(node.op, Elemwise) and isinstance( |
434 | | - node.op.scalar_op, (aes.Mul, aes.TrueDiv) |
435 | | - ): |
| 434 | + if isinstance(node.op.scalar_op, (aes.Mul, aes.TrueDiv)): |
436 | 435 | exps = [ |
437 | 436 | n.owner.inputs[0] |
438 | 437 | for n in node.inputs |
@@ -468,16 +467,12 @@ def local_mulexp2expadd(fgraph, node): |
468 | 467 |
|
469 | 468 | @register_specialize |
470 | 469 | @node_rewriter([mul, true_div]) |
471 | | -def local_mulpow2powadd(fgraph, node): |
| 470 | +def local_mul_pow_to_pow_add(fgraph, node): |
472 | 471 | """ |
473 | 472 | This rewrite detects a^x * a^y and converts it to a^(x+y). |
474 | 473 | Similarly, a^x / a^y becomes a^(x-y). |
475 | 474 | """ |
476 | | - if isinstance(node.op, Elemwise) and isinstance( |
477 | | - node.op.scalar_op, (aes.Mul, aes.TrueDiv) |
478 | | - ): |
479 | | - from collections import defaultdict |
480 | | - |
| 475 | + if isinstance(node.op.scalar_op, (aes.Mul, aes.TrueDiv)): |
481 | 476 | # search for pow-s and group them by their bases |
482 | 477 | pow_nodes = defaultdict(list) |
483 | 478 | rest = [] |
|
0 commit comments