diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 020f0d8d04..06ee041a38 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -54,6 +54,8 @@ jobs: steps: - uses: actions/checkout@v2 - uses: actions/setup-python@v2 + with: + python-version: '3.9' - uses: pre-commit/action@v2.0.0 test: @@ -72,9 +74,9 @@ jobs: install-numba: [1] part: - "tests --ignore=tests/tensor --ignore=tests/sparse --ignore=tests/tensor/nnet" - - "tests/tensor tests/sparse --ignore=tests/tensor/test_basic.py --ignore=tests/tensor/test_math.py --ignore=tests/tensor/test_math_scipy.py --ignore=tests/tensor/test_inplace.py --ignore=tests/tensor/test_elemwise.py --ignore=tests/tensor/test_basic_opt.py --ignore=tests/tensor/test_math_opt.py --ignore=tests/tensor/nnet" + - "tests/tensor tests/sparse --ignore=tests/tensor/test_basic.py --ignore=tests/tensor/test_math.py --ignore=tests/tensor/test_math_scipy.py --ignore=tests/tensor/test_inplace.py --ignore=tests/tensor/test_elemwise.py --ignore=tests/tensor/rewriting/test_basic.py --ignore=tests/tensor/rewriting/test_math.py --ignore=tests/tensor/nnet" - "tests/tensor/test_basic.py tests/tensor/test_math.py tests/tensor/test_math_scipy.py tests/tensor/test_inplace.py" - - "tests/tensor/test_elemwise.py tests/tensor/test_basic_opt.py tests/tensor/test_math_opt.py" + - "tests/tensor/test_elemwise.py tests/tensor/rewriting/test_basic.py tests/tensor/rewriting/test_math.py" - "tests/tensor/nnet --ignore-glob='*/test_abstract_conv.py'" - "tests/tensor/nnet/test_abstract_conv.py" include: @@ -143,7 +145,7 @@ jobs: run: | mamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov sympy if [[ $INSTALL_NUMBA == "1" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.55" numba-scipy; fi - mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib + mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax "jaxlib!=0.3.15" pip install -e ./ mamba list && pip freeze python -c 'import aesara; print(aesara.config.__str__(print_doc=False))' diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8073e0f67b..7a16f40178 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,7 +7,7 @@ exclude: | )$ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v2.1.0 + rev: v4.3.0 hooks: - id: debug-statements exclude: | @@ -15,21 +15,21 @@ repos: aesara/breakpoint\.py| aesara/graph/op\.py| aesara/compile/nanguardmode\.py| - aesara/graph/opt\.py| + aesara/graph/rewriting/basic\.py| aesara/tensor/var\.py| )$ - id: check-merge-conflict - repo: https://github.com/psf/black - rev: 22.3.0 + rev: 22.8.0 hooks: - id: black language_version: python3 - repo: https://gitlab.com/pycqa/flake8 - rev: 3.8.4 + rev: 3.9.2 hooks: - id: flake8 - repo: https://github.com/pycqa/isort - rev: 5.6.4 + rev: 5.10.1 hooks: - id: isort - repo: https://github.com/humitos/mirrors-autoflake.git @@ -47,7 +47,7 @@ repos: )$ args: ['--in-place', '--remove-all-unused-imports', '--remove-unused-variable'] - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.931 + rev: v0.971 hooks: - id: mypy additional_dependencies: diff --git a/aesara/__init__.py b/aesara/__init__.py index 39eef9a041..0d81ab1087 100644 --- a/aesara/__init__.py +++ b/aesara/__init__.py @@ -62,12 +62,6 @@ def disable_log_handler(logger=aesara_logger, handler=logging_default_handler): raise RuntimeError("You have the aesara directory in your Python path.") from aesara.configdefaults import config -from aesara.utils import deprecated - - -change_flags = deprecated("Use aesara.config.change_flags instead!")( - config.change_flags -) # This is the api version for ops that generate C code. External ops @@ -178,3 +172,27 @@ def get_scalar_constant_value(v): # imports were executed, we can warn about remaining flags provided by the user # through AESARA_FLAGS. config.warn_unused_flags() + +DEPRECATED_NAMES = [ + ( + "change_flags", + "`aesara.change_flags` is deprecated: use `aesara.config.change_flags` instead.", + config.change_flags, + ), +] + + +def __getattr__(name): + """Intercept module-level attribute access of deprecated symbols. + + Adapted from https://stackoverflow.com/a/55139609/3006474. + + """ + from warnings import warn + + for old_name, msg, old_object in DEPRECATED_NAMES: + if name == old_name: + warn(msg, DeprecationWarning, stacklevel=2) + return old_object + + raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/aesara/compile/builders.py b/aesara/compile/builders.py index 738d724ac6..b751172c0a 100644 --- a/aesara/compile/builders.py +++ b/aesara/compile/builders.py @@ -24,9 +24,9 @@ from aesara.graph.fg import FunctionGraph from aesara.graph.null_type import NullType from aesara.graph.op import HasInnerGraph, Op -from aesara.graph.opt import in2out, local_optimizer +from aesara.graph.rewriting.basic import in2out, node_rewriter from aesara.graph.utils import MissingInputError -from aesara.tensor.basic_opt import ShapeFeature +from aesara.tensor.rewriting.shape import ShapeFeature def infer_shape(outs, inputs, input_shapes): @@ -928,7 +928,7 @@ def perform(self, node, inputs, outputs): output[0] = variable -@local_optimizer([OpFromGraph]) +@node_rewriter([OpFromGraph]) def inline_ofg_expansion(fgraph, node): """ This optimization expands internal graph of OpFromGraph. diff --git a/aesara/compile/function/pfunc.py b/aesara/compile/function/pfunc.py index 94be2db2c6..dfe5ff6e16 100644 --- a/aesara/compile/function/pfunc.py +++ b/aesara/compile/function/pfunc.py @@ -189,11 +189,8 @@ def clone_inputs(i): (store_into, update_d[store_into]), ) - # filter_variable ensure smooth conversion of cpu Types try: - update_val = store_into.type.filter_variable( - update_val, allow_convert=False - ) + update_val = store_into.type.filter_variable(update_val, allow_convert=True) except TypeError: err_msg = ( "An update must have the same type as the" diff --git a/aesara/compile/function/types.py b/aesara/compile/function/types.py index d1f9eae2fc..2444145e29 100644 --- a/aesara/compile/function/types.py +++ b/aesara/compile/function/types.py @@ -1,7 +1,4 @@ -""" -Driver of graph construction, optimization, and linking. - -""" +"""Objects that orchestrate graph construction, rewriting, and linking.""" import copy import copyreg @@ -753,9 +750,8 @@ def checkSV(sv_ori, sv_rpl): # cause problems. on_unused_input="ignore", function_builder=maker.function_builder, - # As this is an optimized graph, it - # can contain inplace. DebugMode check - # that. + # As this is an rewritten graph, it can contain inplace. DebugMode + # check that. accept_inplace=True, no_fgraph_prep=True, ).create(input_storage, storage_map=new_storage_map) @@ -1182,7 +1178,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs): This loop was inserted to remove aliasing between outputs when they all evaluate to the same value. Originally it was OK for outputs to be aliased, but some of the outputs can be shared variables, and is not good for shared - variables to be aliased. It might be possible to optimize this by making + variables to be aliased. It might be possible to rewrite this by making sure there is no aliasing only between shared variables. If some outputs are constant, we add deep copy to respect the memory @@ -1279,7 +1275,7 @@ class FunctionMaker: """ `FunctionMaker` is the class to `create` `Function` instances. - This class has the fgraph, the optimizer, and the linker. When + This class has the fgraph, the rewriter, and the linker. When copying a `Function`, there is no need to duplicate the `FunctionMaker` instance. Deepcopy still copies both, which can variable in re-compilation. @@ -1292,7 +1288,7 @@ class FunctionMaker: functions produced by FunctionMaker will return their output value directly. mode : Mode instance - Telling FunctionMaker how to optimize and link. None means to use the + Telling FunctionMaker how to rewrite and link. None means to use the `config.mode`. accept_inplace : bool True iff it is acceptable to have inplace operations in the graph from @@ -1395,44 +1391,44 @@ def check_unused_inputs(inputs, outputs, on_unused_input): @staticmethod def prepare_fgraph( - inputs, outputs, additional_outputs, fgraph, optimizer, linker, profile + inputs, outputs, additional_outputs, fgraph, rewriter, linker, profile ): try: - start_optimizer = time.time() + start_rewriter = time.time() - optimizer_profile = None - opt_time = None + rewriter_profile = None + rewrite_time = None with config.change_flags( compute_test_value=config.compute_test_value_opt, traceback__limit=config.traceback__compile_limit, ): - optimizer_profile = optimizer(fgraph) + rewriter_profile = rewriter(fgraph) - end_optimizer = time.time() - opt_time = end_optimizer - start_optimizer - _logger.debug(f"Optimizing took {opt_time:f} seconds") + end_rewriter = time.time() + rewrite_time = end_rewriter - start_rewriter + _logger.debug(f"Rewriting took {rewrite_time:f} seconds") # Add deep copy to respect the memory interface insert_deepcopy(fgraph, inputs, outputs + additional_outputs) finally: - # If the optimizer got interrupted - if opt_time is None: - end_optimizer = time.time() - opt_time = end_optimizer - start_optimizer + # If the rewriter got interrupted + if rewrite_time is None: + end_rewriter = time.time() + rewrite_time = end_rewriter - start_rewriter - aesara.compile.profiling.total_graph_opt_time += opt_time + aesara.compile.profiling.total_graph_rewrite_time += rewrite_time if profile: - if optimizer_profile is None and hasattr(optimizer, "pre_profile"): - optimizer_profile = optimizer.pre_profile + if rewriter_profile is None and hasattr(rewriter, "pre_profile"): + rewriter_profile = rewriter.pre_profile - profile.optimizer_time += opt_time + profile.rewriting_time += rewrite_time if config.profile_optimizer: - profile.optimizer_profile = (optimizer, optimizer_profile) + profile.rewriter_profile = (rewriter, rewriter_profile) elif config.profile_optimizer and profile is not False: # If False, it means the profiling for that function was # explicitly disabled @@ -1466,8 +1462,8 @@ def __init__( ): # Save the provided mode, not the instantiated mode. # The instantiated mode don't pickle and if we unpickle an Aesara - # function and it get re-compiled, we want the current optimizer to be - # used, not the optimizer when it was saved. + # function and it get re-compiled, we want the current rewriter to be + # used, not the rewriter when it was saved. self.mode = mode mode = aesara.compile.mode.get_mode(mode) @@ -1478,7 +1474,7 @@ def __init__( if profile: # This is very important: # 1) We preload the cache here to not have its timing - # included in optimization that compile function. + # included with the rewrites. # 2) Do not refresh the cache here by default. It cause # too much execution time during testing as we compile # much more functions then the number of compile c @@ -1515,11 +1511,11 @@ def __init__( self.fgraph = fgraph - optimizer, linker = mode.optimizer, copy.copy(mode.linker) + rewriter, linker = mode.optimizer, copy.copy(mode.linker) if not no_fgraph_prep: self.prepare_fgraph( - inputs, outputs, found_updates, fgraph, optimizer, linker, profile + inputs, outputs, found_updates, fgraph, rewriter, linker, profile ) assert len(fgraph.outputs) == len(outputs + found_updates) @@ -1715,7 +1711,7 @@ def orig_function( time spent in this function. accept_inplace : bool True iff the graph can contain inplace operations prior to the - optimization phase (default is False). + rewrite phase (default is False). profile : None or ProfileStats instance on_unused_input : {'raise', 'warn', 'ignore', None} What to do if a variable in the 'inputs' list is not used in the graph. diff --git a/aesara/compile/mode.py b/aesara/compile/mode.py index ff5048eb9d..5232163ba2 100644 --- a/aesara/compile/mode.py +++ b/aesara/compile/mode.py @@ -10,17 +10,17 @@ from aesara.compile.function.types import Supervisor from aesara.configdefaults import config from aesara.graph.destroyhandler import DestroyHandler -from aesara.graph.opt import ( - CheckStackTraceOptimization, - GlobalOptimizer, +from aesara.graph.rewriting.basic import ( + CheckStackTraceRewriter, + GraphRewriter, MergeOptimizer, - NavigatorOptimizer, + NodeProcessingGraphRewriter, ) -from aesara.graph.optdb import ( +from aesara.graph.rewriting.db import ( EquilibriumDB, LocalGroupDB, - OptimizationDatabase, - OptimizationQuery, + RewriteDatabase, + RewriteDatabaseQuery, SequenceDB, TopoDB, ) @@ -64,15 +64,15 @@ def register_linker(name, linker): exclude = [] if not config.cxx: exclude = ["cxx_only"] -OPT_NONE = OptimizationQuery(include=[], exclude=exclude) +OPT_NONE = RewriteDatabaseQuery(include=[], exclude=exclude) # Even if multiple merge optimizer call will be there, this shouldn't # impact performance. -OPT_MERGE = OptimizationQuery(include=["merge"], exclude=exclude) -OPT_FAST_RUN = OptimizationQuery(include=["fast_run"], exclude=exclude) +OPT_MERGE = RewriteDatabaseQuery(include=["merge"], exclude=exclude) +OPT_FAST_RUN = RewriteDatabaseQuery(include=["fast_run"], exclude=exclude) OPT_FAST_RUN_STABLE = OPT_FAST_RUN.requiring("stable") -OPT_FAST_COMPILE = OptimizationQuery(include=["fast_compile"], exclude=exclude) -OPT_STABILIZE = OptimizationQuery(include=["fast_run"], exclude=exclude) +OPT_FAST_COMPILE = RewriteDatabaseQuery(include=["fast_compile"], exclude=exclude) +OPT_STABILIZE = RewriteDatabaseQuery(include=["fast_run"], exclude=exclude) OPT_STABILIZE.position_cutoff = 1.5000001 OPT_NONE.name = "OPT_NONE" OPT_MERGE.name = "OPT_MERGE" @@ -106,13 +106,13 @@ def register_linker(name, linker): def register_optimizer(name, opt): - """Add a `GlobalOptimizer` which can be referred to by `name` in `Mode`.""" + """Add a `GraphRewriter` which can be referred to by `name` in `Mode`.""" if name in predefined_optimizers: raise ValueError(f"Optimizer name already taken: {name}") predefined_optimizers[name] = opt -class AddDestroyHandler(GlobalOptimizer): +class AddDestroyHandler(GraphRewriter): """ This optimizer performs two important functions: @@ -145,7 +145,7 @@ def add_requirements(self, fgraph): fgraph.attach_feature(DestroyHandler()) -class AddFeatureOptimizer(GlobalOptimizer): +class AddFeatureOptimizer(GraphRewriter): """ This optimizer adds a provided feature to the function graph. """ @@ -161,7 +161,7 @@ def apply(self, fgraph): pass -class PrintCurrentFunctionGraph(GlobalOptimizer): +class PrintCurrentFunctionGraph(GraphRewriter): """ This optimizer is for debugging. @@ -190,10 +190,10 @@ def apply(self, fgraph): # The opt should not do anything that need shape inference. # New nodes that don't have infer_shape need that the original node # also don't have infer_shape -local_useless = LocalGroupDB(apply_all_opts=True, profile=True) +local_useless = LocalGroupDB(apply_all_rewrites=True, profile=True) optdb.register( "useless", - TopoDB(local_useless, failure_callback=NavigatorOptimizer.warn_inplace), + TopoDB(local_useless, failure_callback=NodeProcessingGraphRewriter.warn_inplace), "fast_run", "fast_compile", position=0.6, @@ -212,10 +212,10 @@ def apply(self, fgraph): "canonicalize_db", position=1, ) -# Register in the canonizer Equilibrium as a clean up opt the merge opt. +# Register in the canonizer Equilibrium as a clean-up rewrite the merge rewrite. # Without this, as the equilibrium have ignore_newtrees=False, we -# won't merge all nodes if it is set as a global optimizer with -# final_opt=True. +# won't merge all nodes if it is set as a global rewriter with +# final_rewriter=True. # We need a new instance of MergeOptimizer to don't have its name # changed by other usage of it. @@ -271,25 +271,24 @@ def apply(self, fgraph): if config.check_stack_trace == "off": _tags = () -optdb.register("CheckStackTrace", CheckStackTraceOptimization(), *_tags, position=-1) +optdb.register("CheckStackTrace", CheckStackTraceRewriter(), *_tags, position=-1) del _tags class Mode: - """ - The Mode represents a way to optimize and then link a computation graph. + """A class that specifies the rewrites/optimizations used during function compilation. Parameters ---------- - optimizer: a structure of type Optimizer + optimizer An Optimizer may simplify the math, put similar computations together, improve numerical stability and various other improvements. - linker: a structure of type Linker + linker A Linker decides which implementations to use (C or Python, for example) and how to string them together to perform the computation. - db: - The ``OptimizationDatabase`` used by this ``Mode``. Note: This value - is *not* part of a ``Mode`` instance's pickled state. + db + The `RewriteDatabase` used by this `Mode`. Note: This value + is *not* part of a `Mode` instance's pickled state. See Also -------- @@ -302,8 +301,8 @@ class Mode: def __init__( self, linker: Optional[Union[str, Linker]] = None, - optimizer: Union[str, OptimizationQuery] = "default", - db: OptimizationDatabase = None, + optimizer: Union[str, RewriteDatabaseQuery] = "default", + db: RewriteDatabase = None, ): if linker is None: linker = config.linker @@ -320,7 +319,7 @@ def __init__( # self.provided_optimizer - typically the `optimizer` arg. # But if the `optimizer` arg is keyword corresponding to a predefined - # OptimizationQuery, then this stores the query + # RewriteDatabaseQuery, then this stores the query # self._optimizer - typically same as provided_optimizer?? # self.__get_optimizer - returns self._optimizer (possibly querying @@ -342,7 +341,7 @@ def __setstate__(self, state): self.linker = linker if isinstance(optimizer, str) or optimizer is None: optimizer = predefined_optimizers[optimizer] - if isinstance(optimizer, OptimizationQuery): + if isinstance(optimizer, RewriteDatabaseQuery): self.provided_optimizer = optimizer self._optimizer = optimizer self.call_time = 0 @@ -357,7 +356,7 @@ def __str__(self): ) def __get_optimizer(self): - if isinstance(self._optimizer, OptimizationQuery): + if isinstance(self._optimizer, RewriteDatabaseQuery): return self.optdb.query(self._optimizer) else: return self._optimizer @@ -375,7 +374,7 @@ def including(self, *tags): link, opt = self.get_linker_optimizer( self.provided_linker, self.provided_optimizer ) - # N.B. opt might be a OptimizationQuery instance, not sure what else it might be... + # N.B. opt might be a RewriteDatabaseQuery instance, not sure what else it might be... # string? Optimizer? OptDB? who knows??? return self.clone(optimizer=opt.including(*tags), linker=link) @@ -448,11 +447,11 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs): JAX = Mode( JAXLinker(), - OptimizationQuery(include=["fast_run"], exclude=["cxx_only", "BlasOpt"]), + RewriteDatabaseQuery(include=["fast_run"], exclude=["cxx_only", "BlasOpt"]), ) NUMBA = Mode( NumbaLinker(), - OptimizationQuery(include=["fast_run"], exclude=["cxx_only", "BlasOpt"]), + RewriteDatabaseQuery(include=["fast_run"], exclude=["cxx_only", "BlasOpt"]), ) diff --git a/aesara/compile/profiling.py b/aesara/compile/profiling.py index ec2c74eec2..15c57fdf74 100644 --- a/aesara/compile/profiling.py +++ b/aesara/compile/profiling.py @@ -45,7 +45,7 @@ def extended_open(filename, mode="r"): aesara_imported_time: float = time.time() total_fct_exec_time: float = 0.0 -total_graph_opt_time: float = 0.0 +total_graph_rewrite_time: float = 0.0 total_time_linker: float = 0.0 _atexit_print_list: List["ProfileStats"] = [] @@ -97,7 +97,7 @@ def _atexit_print_fn(): "fct_call_time", "fct_callcount", "vm_call_time", - "optimizer_time", + "rewriter_time", "linker_time", "validate_time", "import_time", @@ -120,18 +120,18 @@ def _atexit_print_fn(): assert key not in cum_attr, (key, cum_attr) cum_attr[key] = val - if cum.optimizer_profile and ps.optimizer_profile: + if cum.rewriter_profile and ps.rewriter_profile: try: - merge = cum.optimizer_profile[0].merge_profile( - cum.optimizer_profile[1], ps.optimizer_profile[1] + merge = cum.rewriter_profile[0].merge_profile( + cum.rewriter_profile[1], ps.rewriter_profile[1] ) - assert len(merge) == len(cum.optimizer_profile[1]) - cum.optimizer_profile = (cum.optimizer_profile[0], merge) + assert len(merge) == len(cum.rewriter_profile[1]) + cum.rewriter_profile = (cum.rewriter_profile[0], merge) except Exception as e: print(e) - cum.optimizer_profile = None + cum.rewriter_profile = None else: - cum.optimizer_profile = None + cum.rewriter_profile = None cum.summary( file=destination_file, @@ -149,7 +149,7 @@ def print_global_stats(): -- Time elapsed since Aesara was imported -- Time spent inside Aesara functions -- Time spent in compiling Aesara functions - -- on graph optimization + -- on graph rewriters -- on linker """ @@ -168,7 +168,7 @@ def print_global_stats(): f"Time elasped since Aesara import = {time.time() - aesara_imported_time:6.3f}s, " f"Time spent in Aesara functions = {total_fct_exec_time:6.3f}s, " "Time spent compiling Aesara functions: " - f" optimization = {total_graph_opt_time:6.3f}s, linker = {total_time_linker:6.3f}s ", + f"rewriting = {total_graph_rewrite_time:6.3f}s, linking = {total_time_linker:6.3f}s ", ), file=destination_file, ) @@ -186,7 +186,7 @@ def register_profiler_printer(fct): class ProfileStats: """ Object to store runtime and memory profiling information for all of - Aesara's operations: compilation, optimization, execution. + Aesara's operations: compilation, rewriting, execution. Parameters ---------- @@ -220,7 +220,7 @@ def reset(self): compile_time: float = 0.0 # Total time spent in body of orig_function, - # dominated by graph optimization and compilation of C + # dominated by graph rewriting and compilation of C # fct_call_time: float = 0.0 @@ -259,12 +259,12 @@ def reset(self): # Variable -> offset # - optimizer_time: float = 0.0 - # time spent optimizing graph (FunctionMaker.__init__) + rewriting_time: float = 0.0 + # time spent rewriting graph (FunctionMaker.__init__) validate_time: float = 0.0 # time spent in fgraph.validate - # This is a subset of optimizer_time that is dominated by toposort() + # This is a subset of rewriting_time that is dominated by toposort() # when the destorymap feature is included. linker_time: float = 0.0 @@ -284,8 +284,8 @@ def reset(self): # case we print the profile when the function wasn't executed, or if there # is a lazy operation in the graph. - optimizer_profile = None - # None or tuple (the optimizer, the profile it returned) + rewriter_profile = None + # None or tuple (the rewriter, the profile it returned) # param is called flag_time_thunks because most other attributes with time # in the name are times *of* something, rather than configuration flags. @@ -801,9 +801,9 @@ def summary_function(self, file): f" Time in thunks: {local_time}s ({100 * local_time / self.fct_call_time:.3f}%)", file=file, ) - print(f" Total compile time: {self.compile_time:e}s", file=file) + print(f" Total compilation time: {self.compile_time:e}s", file=file) print(f" Number of Apply nodes: {int(self.nb_nodes)}", file=file) - print(f" Aesara Optimizer time: {self.optimizer_time:e}s", file=file) + print(f" Aesara rewrite time: {self.rewriting_time:e}s", file=file) print(f" Aesara validate time: {self.validate_time:e}s", file=file) print( ( @@ -823,9 +823,8 @@ def summary_function(self, file): print(f" Node {node} time {t:e}s", file=file) print("", file=file) - # The validation time is a subset of optimizer_time - if self.optimizer_time > 0: - assert self.validate_time < self.optimizer_time + if self.rewriting_time > 0: + assert self.validate_time < self.rewriting_time def summary_globals(self, file): print( @@ -1468,10 +1467,10 @@ def summary(self, file=sys.stderr, n_ops_to_print=20, n_apply_to_print=20): aesara.printing.debugprint(fcts, print_type=True) if self.variable_shape or self.variable_strides: self.summary_memory(file, n_apply_to_print) - if self.optimizer_profile: - print("Optimizer Profile", file=file) - print("-----------------", file=file) - self.optimizer_profile[0].print_profile(file, self.optimizer_profile[1]) + if self.rewriter_profile: + print("Rewriter Profile", file=file) + print("----------------", file=file) + self.rewriter_profile[0].print_profile(file, self.rewriter_profile[1]) self.print_extra(file) self.print_tips(file) @@ -1619,7 +1618,7 @@ def exp_float32_op(op): ): print( ( - " - You have a dot operation that was not optimized to" + " - You have a dot operation that was not rewritten to" " dot22 (which is faster). Make sure the inputs are " "float32 or float64, and are the same for both inputs. " f"Currently they are: {[i.type for i in node.inputs]}" diff --git a/aesara/configdefaults.py b/aesara/configdefaults.py index 5bce7b067a..b1e914b2a9 100644 --- a/aesara/configdefaults.py +++ b/aesara/configdefaults.py @@ -1107,7 +1107,7 @@ def add_optimizer_configvars(): config.add( "optdb__max_use_ratio", - "A ratio that prevent infinite loop in EquilibriumOptimizer.", + "A ratio that prevent infinite loop in EquilibriumGraphRewriter.", FloatParam(8), in_c_key=False, ) diff --git a/aesara/configparser.py b/aesara/configparser.py index 12023f4de1..99269cfc35 100644 --- a/aesara/configparser.py +++ b/aesara/configparser.py @@ -14,7 +14,7 @@ from io import StringIO from typing import Callable, Dict, Optional, Sequence, Union -from aesara.utils import deprecated, hash_from_code +from aesara.utils import hash_from_code _logger = logging.getLogger("aesara.configparser") @@ -582,8 +582,7 @@ def __getattr__(self, attr): if attr == "_actual": return _ConfigProxy._actual warnings.warn( - "Accessing config through `aesara.configparser.config` is deprecated. " - "Use `aesara.config` instead.", + "`aesara.configparser.config` is deprecated; use `aesara.config` instead.", DeprecationWarning, stacklevel=2, ) @@ -593,8 +592,7 @@ def __setattr__(self, attr, value): if attr == "_actual": return setattr(_ConfigProxy._actual, attr, value) warnings.warn( - "Accessing config through `aesara.configparser.config` is deprecated. " - "Use `aesara.config` instead.", + "`aesara.configparser.config` is deprecated; use `aesara.config` instead.", DeprecationWarning, stacklevel=2, ) @@ -609,12 +607,37 @@ def __setattr__(self, attr, value): # These imports/accesses should be replaced with `aesara.config`, so this wraps # it with warnings: config = _ConfigProxy(_config) -# We can't alias the methods of the `config` variable above without already -# triggering the warning. Instead, we wrap the methods of the actual instance -# with warnings: -change_flags = deprecated("Use aesara.config.change_flags instead!")( - _config.change_flags -) -_config_print = deprecated("Use aesara.config.config_print instead!")( - _config.config_print -) + +DEPRECATED_NAMES = [ + ( + "change_flags", + "`change_flags` is deprecated; use `aesara.config.change_flags` instead.", + _config.change_flags, + ), + ( + "_change_flags", + "`_change_flags` is deprecated; use `aesara.config.change_flags` instead.", + _config.change_flags, + ), + ( + "_config_print", + "`_config_print` is deprecated; use `aesara.config.config_print` instead.", + _config.config_print, + ), +] + + +def __getattr__(name): + """Intercept module-level attribute access of deprecated symbols. + + Adapted from https://stackoverflow.com/a/55139609/3006474. + + """ + from warnings import warn + + for old_name, msg, old_object in DEPRECATED_NAMES: + if name == old_name: + warn(msg, DeprecationWarning, stacklevel=2) + return old_object + + raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/aesara/gradient.py b/aesara/gradient.py index da3a13220f..51b2bb77ad 100644 --- a/aesara/gradient.py +++ b/aesara/gradient.py @@ -2129,10 +2129,9 @@ def grad(self, args, g_outs): def consider_constant(x): - """ - DEPRECATED: use zero_grad() or disconnected_grad() instead. + """Consider an expression constant when computing gradients. - Consider an expression constant when computing gradients. + DEPRECATED: use `zero_grad` or `disconnected_grad` instead. The expression itself is unaffected, but when its gradient is computed, or the gradient of another expression that this @@ -2149,14 +2148,14 @@ def consider_constant(x): """ warnings.warn( ( - "consider_constant() is deprecated, use zero_grad() or " - "disconnected_grad() instead." + "`ConsiderConstant` is deprecated; use `zero_grad` or " + "`disconnected_grad` instead." ), category=DeprecationWarning, stacklevel=3, ) - return consider_constant_(x) + return ConsiderConstant()(x) class ZeroGrad(ViewOp): @@ -2365,3 +2364,28 @@ def grad_scale(x, multiplier): 0.416... """ return GradScale(multiplier)(x) + + +DEPRECATED_NAMES = [ + ( + "consider_constant_", + "`consider_constant_` is deprecated; use `zero_grad` or `disconnected_grad` instead.", + ConsiderConstant(), + ), +] + + +def __getattr__(name): + """Intercept module-level attribute access of deprecated symbols. + + Adapted from https://stackoverflow.com/a/55139609/3006474. + + """ + from warnings import warn + + for old_name, msg, old_object in DEPRECATED_NAMES: + if name == old_name: + warn(msg, DeprecationWarning, stacklevel=2) + return old_object + + raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/aesara/graph/__init__.py b/aesara/graph/__init__.py index 492e328b85..2c93c8112e 100644 --- a/aesara/graph/__init__.py +++ b/aesara/graph/__init__.py @@ -13,8 +13,8 @@ from aesara.graph.op import Op from aesara.graph.type import Type from aesara.graph.fg import FunctionGraph -from aesara.graph.opt import local_optimizer, optimizer -from aesara.graph.opt_utils import optimize_graph -from aesara.graph.optdb import OptimizationQuery +from aesara.graph.rewriting.basic import node_rewriter, graph_rewriter +from aesara.graph.rewriting.utils import rewrite_graph +from aesara.graph.rewriting.db import RewriteDatabaseQuery # isort: on diff --git a/aesara/graph/callcache.py b/aesara/graph/callcache.py deleted file mode 100644 index 0aefc842df..0000000000 --- a/aesara/graph/callcache.py +++ /dev/null @@ -1,54 +0,0 @@ -import logging -import pickle - - -_logger = logging.getLogger("aesara.graph.callcache") - - -class CallCache: - def __init__(self, filename=None): - self.filename = filename - try: - if filename is None: - raise OSError("bad filename") # just goes to except - with open(filename) as f: - self.cache = pickle.load(f) - except OSError: - self.cache = {} - - def persist(self, filename=None): - """ - Cache "filename" as a pickle file - """ - if filename is None: - filename = self.filename - with open(filename, "w") as f: - pickle.dump(self.cache, f) - - def call(self, fn, args=(), key=None): - """ - Retrieve item from the cache(if available) - based on a key - - Parameters: - ---------- - key - parameter to retrieve cache item - fn,args - key to retrieve if "key" is None - """ - if key is None: - key = (fn, tuple(args)) - if key not in self.cache: - _logger.debug("cache miss %i", len(self.cache)) - self.cache[key] = fn(*args) - else: - _logger.debug("cache hit %i", len(self.cache)) - return self.cache[key] - - def __del__(self): - try: - if self.filename: - self.persist() - except Exception as e: - _logger.error("persist failed %s %s", self.filename, e) diff --git a/aesara/graph/features.py b/aesara/graph/features.py index 56ada7faf8..4e69a654bb 100644 --- a/aesara/graph/features.py +++ b/aesara/graph/features.py @@ -603,13 +603,13 @@ def replace_all_validate( fgraph.revert(chk) if verbose: print( - f"optimizer: validate failed on node {r}.\n Reason: {reason}, {e}" + f"rewriting: validate failed on node {r}.\n Reason: {reason}, {e}" ) raise if verbose: print( - f"optimizer: rewrite {reason} replaces {r} of {r.owner} with {new_r} of {new_r.owner}" + f"rewriting: rewrite {reason} replaces {r} of {r.owner} with {new_r} of {new_r.owner}" ) # The return is needed by replace_all_validate_remove diff --git a/aesara/graph/fg.py b/aesara/graph/fg.py index 4a793614d6..26fb74bd7f 100644 --- a/aesara/graph/fg.py +++ b/aesara/graph/fg.py @@ -481,7 +481,7 @@ def replace( verbose = config.optimizer_verbose if verbose: print( - f"optimizer: rewrite {reason} replaces {var} of {var.owner} with {new_var} of {new_var.owner}" + f"rewriting: rewrite {reason} replaces {var} of {var.owner} with {new_var} of {new_var.owner}" ) new_var = var.type.filter_variable(new_var, allow_convert=True) @@ -909,7 +909,9 @@ def __getstate__(self): for feature in self._features: for attr in getattr(feature, "pickle_rm_attr", []): del d[attr] - # The class Updater take fct as parameter and they are lambda function, so unpicklable. + + # XXX: The `Feature` `DispatchingFeature` takes functions as parameter + # and they can be lambda functions, making them unpicklable. # execute_callbacks_times have reference to optimizer, and they can't # be pickled as the decorators with parameters aren't pickable. diff --git a/aesara/graph/kanren.py b/aesara/graph/kanren.py index ba8939b44d..d0b6dba190 100644 --- a/aesara/graph/kanren.py +++ b/aesara/graph/kanren.py @@ -1,100 +1,10 @@ -from typing import Callable, Iterator, List, Optional, Union +import warnings -from etuples.core import ExpressionTuple -from kanren import run -from unification import var -from unification.variable import Var -from aesara.graph.basic import Apply, Variable -from aesara.graph.opt import LocalOptimizer -from aesara.graph.unify import eval_if_etuple +warnings.warn( + "The module `aesara.graph.kanren` is deprecated; use `aesara.graph.rewriting.kanren` instead.", + DeprecationWarning, + stacklevel=2, +) - -class KanrenRelationSub(LocalOptimizer): - r"""A local optimizer that uses `kanren` to match and replace terms. - - See `kanren `__ for more information - miniKanren and the API for constructing `kanren` goals. - - Example - ------- - - ..code-block:: python - - from kanren import eq, conso, var - - import aesara.tensor as at - from aesara.graph.kanren import KanrenRelationSub - - - def relation(in_lv, out_lv): - # A `kanren` goal that changes `at.log` terms to `at.exp` - cdr_lv = var() - return eq(conso(at.log, cdr_lv, in_lv), - conso(at.exp, cdr_lv, out_lv)) - - - kanren_sub_opt = KanrenRelationSub(relation) - - """ - - reentrant = True - - def __init__( - self, - kanren_relation: Callable[[Variable, Var], Callable], - results_filter: Optional[ - Callable[[Iterator], Optional[List[Union[ExpressionTuple, Variable]]]] - ] = None, - node_filter: Callable[[Apply], bool] = lambda x: True, - ): - r"""Create a `KanrenRelationSub`. - - Parameters - ---------- - kanren_relation - A function that takes an input graph and an output logic variable and - returns a `kanren` goal. - results_filter - A function that takes the direct output of `kanren.run(None, ...)` - and returns a single result. The default implementation returns - the first result. - node_filter - A function taking a single node and returns ``True`` when the node - should be processed. - """ - if results_filter is None: - - def results_filter( - x: Iterator, - ) -> Optional[List[Union[ExpressionTuple, Variable]]]: - return next(x, None) - - self.kanren_relation = kanren_relation - self.results_filter = results_filter - self.node_filter = node_filter - super().__init__() - - def transform(self, fgraph, node): - if self.node_filter(node) is False: - return False - - try: - input_expr = node.default_output() - except ValueError: - input_expr = node.outputs - - q = var() - kanren_results = run(None, q, self.kanren_relation(input_expr, q)) - - chosen_res = self.results_filter(kanren_results) - - if chosen_res: - if isinstance(chosen_res, list): - new_outputs = [eval_if_etuple(v) for v in chosen_res] - else: - new_outputs = [eval_if_etuple(chosen_res)] - - return new_outputs - else: - return False +from aesara.graph.rewriting.kanren import * # noqa: F401 E402 F403 diff --git a/aesara/graph/opt.py b/aesara/graph/opt.py index 2d1a2ac444..851c721815 100644 --- a/aesara/graph/opt.py +++ b/aesara/graph/opt.py @@ -1,3032 +1,29 @@ -""" -Defines the base class for optimizations as well as a certain -amount of useful generic optimization tools. - -""" -import abc -import copy -import functools -import inspect -import logging -import pdb -import sys -import time -import traceback import warnings -from collections import UserList, defaultdict, deque -from collections.abc import Iterable -from functools import _compose_mro, partial, reduce # type: ignore -from itertools import chain -from typing import Dict, List, Optional, Sequence, Tuple, Union - -import aesara -from aesara.configdefaults import config -from aesara.graph import destroyhandler as dh -from aesara.graph.basic import ( - Apply, - AtomicVariable, - Constant, - Variable, - applys_between, - io_toposort, - vars_between, -) -from aesara.graph.features import AlreadyThere, Feature, NodeFinder -from aesara.graph.fg import FunctionGraph -from aesara.graph.op import Op -from aesara.graph.utils import AssocList, InconsistencyError -from aesara.misc.ordered_set import OrderedSet -from aesara.utils import flatten - - -_logger = logging.getLogger("aesara.graph.opt") - - -class LocalMetaOptimizerSkipAssertionError(AssertionError): - """This is an AssertionError, but instead of having the - LocalMetaOptimizer print the error, it just skip that - compilation. - - """ - - -class Rewriter(abc.ABC): - """Abstract base class for graph/term rewriters.""" - - name: Optional[str] = None - - @abc.abstractmethod - def add_requirements(self, fgraph: FunctionGraph): - r"""Add `Feature`\s and other requirements to a `FunctionGraph`.""" - - @abc.abstractmethod - def print_summary(self, stream=sys.stdout, level=0, depth=-1): - """Print a single-line, indented representation of the rewriter.""" - - def __eq__(self, other): - return self is other - - def __hash__(self): - return id(self) - - -class GlobalOptimizer(Rewriter): - """A optimizer that can be applied to a `FunctionGraph` in order to transform it. - - It can represent an optimization or, in general, any kind of transformation - one could apply to a `FunctionGraph`. - - """ - - @abc.abstractmethod - def apply(self, fgraph): - """Apply the optimization to a `FunctionGraph`. - - It may use all the methods defined by the `FunctionGraph`. If the - `GlobalOptimizer` needs to use a certain tool, such as an - `InstanceFinder`, it can do so in its `add_requirements` method. - - """ - raise NotImplementedError() - - def optimize(self, fgraph, *args, **kwargs): - """ - - This is meant as a shortcut for the following:: - - opt.add_requirements(fgraph) - opt.apply(fgraph) - - """ - self.add_requirements(fgraph) - ret = self.apply(fgraph, *args, **kwargs) - return ret - - def __call__(self, fgraph): - """Optimize a `FunctionGraph`. - - This is the same as ``self.optimize(fgraph)``. - - """ - return self.optimize(fgraph) - - def add_requirements(self, fgraph): - ... - - def print_summary(self, stream=sys.stdout, level=0, depth=-1): - name = getattr(self, "name", None) - print( - f"{' ' * level}{self.__class__.__name__} {name} id={id(self)}", - file=stream, - ) - - @staticmethod - def print_profile(stream, prof, level=0): - if prof is not None: - raise NotImplementedError( - "The function print_profile must be overridden if the" - " optimizer return profiling information." - ) - - -class LocalOptimizer(Rewriter): - """A node-based optimizer.""" - - def tracks(self): - """Return the list of `Op` classes to which this optimization applies. - - Returns ``None`` when the optimization applies to all nodes. - - """ - return None - - @abc.abstractmethod - def transform( - self, fgraph: FunctionGraph, node: Apply, *args, **kwargs - ) -> Union[bool, List[Variable], Dict[Variable, Variable]]: - r"""Transform a subgraph whose output is `node`. - - Subclasses should implement this function so that it returns one of the - following: - - - ``False`` to indicate that no optimization can be applied to this `node`; - - A list of `Variable`\s to use in place of the `node`'s current outputs. - - A ``dict`` mapping old `Variable`\s to `Variable`\s. - - - Parameters - ---------- - fgraph : - A `FunctionGraph` containing `node`. - node : - An `Apply` node to be transformed. - - """ - - raise NotImplementedError() - - def add_requirements(self, fgraph): - r"""Add required `Feature`\s to `fgraph`.""" - - def print_summary(self, stream=sys.stdout, level=0, depth=-1): - print(f"{' ' * level}{self.__class__.__name__} id={id(self)}", file=stream) - - -class FromFunctionOptimizer(GlobalOptimizer): - """A `GlobalOptimizer` constructed from a given function.""" - - def __init__(self, fn, requirements=()): - self.fn = fn - self.requirements = requirements - - def apply(self, *args, **kwargs): - return self.fn(*args, **kwargs) - - def add_requirements(self, fgraph): - for req in self.requirements: - req(fgraph) - - def print_summary(self, stream=sys.stdout, level=0, depth=-1): - print(f"{' ' * level}{self.apply} id={id(self)}", file=stream) - - def __call__(self, *args, **kwargs): - return self.fn(*args, **kwargs) - - def __str__(self): - return self.__name__ - - -def optimizer(f): - """Decorator for `FromFunctionOptimizer`.""" - rval = FromFunctionOptimizer(f) - rval.__name__ = f.__name__ - return rval - - -def inplace_optimizer(f): - """Decorator for `FromFunctionOptimizer` that also adds the `DestroyHandler` features.""" - dh_handler = dh.DestroyHandler - requirements = (lambda fgraph: fgraph.attach_feature(dh_handler()),) - rval = FromFunctionOptimizer(f, requirements) - rval.__name__ = f.__name__ - return rval - - -class SeqOptimizer(GlobalOptimizer, UserList): - """A `GlobalOptimizer` that applies a list of optimizers sequentially.""" - - @staticmethod - def warn(exc, self, optimizer): - """Default ``failure_callback`` for `SeqOptimizer`.""" - _logger.error(f"SeqOptimizer apply {optimizer}") - _logger.error("Traceback:") - _logger.error(traceback.format_exc()) - if config.on_opt_error == "raise": - raise exc - elif config.on_opt_error == "pdb": - pdb.post_mortem(sys.exc_info()[2]) - - def __init__(self, *opts, failure_callback=None): - """ - Parameters - ---------- - *opts : - The List of optimizers to be applied to a node - failure_callback : callable or None - Keyword only argument. A callback used when a failure - happen during optimization. - - """ - if len(opts) == 1 and isinstance(opts[0], (list, tuple)): - opts = opts[0] - - super().__init__(opts) - - self.failure_callback = failure_callback - - def apply(self, fgraph): - """Applies each `GlobalOptimizer` in ``self.data`` to `fgraph`.""" - l = [] - if fgraph.profile: - validate_before = fgraph.profile.validate_time - sub_validate_time = [validate_before] - callbacks_before = fgraph.execute_callbacks_times.copy() - else: - sub_validate_time = [] - callbacks_before = [] - callback_before = fgraph.execute_callbacks_time - nb_node_before = len(fgraph.apply_nodes) - sub_profs = [] - nb_nodes = [] - - self.pre_profile = ( - self, - l, - -1, - -1, - nb_node_before, - -1, - sub_profs, - sub_validate_time, - nb_nodes, - {}, - ) - try: - for optimizer in self.data: - try: - nb_nodes_before = len(fgraph.apply_nodes) - t0 = time.time() - sub_prof = optimizer.apply(fgraph) - l.append(float(time.time() - t0)) - sub_profs.append(sub_prof) - nb_nodes.append((nb_nodes_before, len(fgraph.apply_nodes))) - if fgraph.profile: - sub_validate_time.append(fgraph.profile.validate_time) - except AssertionError: - # do not catch Assertion failures - raise - except Exception as e: - if self.failure_callback: - self.failure_callback(e, self, optimizer) - continue - else: - raise - finally: - - if fgraph.profile: - validate_time = fgraph.profile.validate_time - validate_before - callbacks_time = {} - for k, v in fgraph.execute_callbacks_times.items(): - if k in callbacks_before: - t = v - callbacks_before[k] - if t > 0: - callbacks_time[k] = t - else: - callbacks_time[k] = v - else: - validate_time = None - callbacks_time = {} - callback_time = fgraph.execute_callbacks_time - callback_before - self.pre_profile = ( - self, - l, - validate_time, - callback_time, - nb_node_before, - len(fgraph.apply_nodes), - sub_profs, - sub_validate_time, - nb_nodes, - callbacks_time, - ) - return self.pre_profile - - def __repr__(self): - return f"SeqOpt({self.data})" - - def add_requirements(self, fgraph): - for opt in self.data: - opt.add_requirements(fgraph) - - def print_summary(self, stream=sys.stdout, level=0, depth=-1): - name = getattr(self, "name", None) - print( - f"{' ' * level}{self.__class__.__name__} {name} id={id(self)}", file=stream - ) - # This way, -1 will do all depth - if depth != 0: - depth -= 1 - for opt in self.data: - opt.print_summary(stream, level=(level + 2), depth=depth) - - @staticmethod - def print_profile(stream, prof, level=0): - ( - opts, - prof, - validate_time, - callback_time, - nb_node_before, - nb_node_after, - sub_profs, - sub_validate_time, - nb_nodes, - callbacks_time, - ) = prof - - validate_time = validate_time or float("nan") - callback_time = callback_time or float("nan") - - blanc = " " * level - - print(blanc, "SeqOptimizer", end=" ", file=stream) - if hasattr(opts, "name"): - print(blanc, opts.name, end=" ", file=stream) - elif hasattr(opts, "__name__"): - print(blanc, opts.__name__, end=" ", file=stream) - print( - ( - f" time {sum(prof):.3f}s for {int(nb_node_before)}/{int(nb_node_after)} nodes" - " before/after optimization" - ), - file=stream, - ) - print(blanc, f" {callback_time:.3f}s for callback", file=stream) - print(blanc, f" {validate_time:.3f}s for fgraph.validate()", file=stream) - if callback_time > 1: - print(blanc, " callbacks_time", file=stream) - for i in sorted(callbacks_time.items(), key=lambda a: -a[1]): - if i[1] > 0: - # We want to have the __str__ called, so we can't - # just print i. - print(blanc, " ", i[0], ",", i[1], file=stream) - - if level == 0: - print( - blanc, - " time - (name, class, index, nodes before, nodes after) - validate time", - file=stream, - ) - ll = [] - for (opt, nb_n) in zip(opts, nb_nodes): - if hasattr(opt, "__name__"): - name = opt.__name__ - else: - name = opt.name - idx = opts.index(opt) - ll.append((name, opt.__class__.__name__, idx) + nb_n) - lll = sorted(zip(prof, ll), key=lambda a: a[0]) - - for (t, opt) in lll[::-1]: - i = opt[2] - if sub_validate_time: - val_time = sub_validate_time[i + 1] - sub_validate_time[i] - print( - blanc, - f" {t:.6f}s - {opt} - {val_time:.3f}s", - file=stream, - ) - else: - print(blanc, f" {t:.6f}s - {opt}", file=stream) - - if sub_profs[i]: - opts[i].print_profile(stream, sub_profs[i], level=level + 1) - print(file=stream) - - @staticmethod - def merge_profile(prof1, prof2): - """Merge two profiles.""" - new_t = [] # the time for the optimization - new_l = [] # the optimization - new_sub_profile = [] - # merge common(same object) opt - for l in set(prof1[0]).intersection(set(prof2[0])): - idx1 = prof1[0].index(l) - idx2 = prof2[0].index(l) - new_t.append(prof1[1][idx1] + prof2[1][idx2]) - new_l.append(l) - if hasattr(l, "merge_profile"): - assert len(prof1[6][idx1]) == len(prof2[6][idx2]) - new_sub_profile.append(l.merge_profile(prof1[6][idx1], prof2[6][idx2])) - else: - new_sub_profile.append(None) - - # merge not common opt - from io import StringIO - - for l in set(prof1[0]).symmetric_difference(set(prof2[0])): - # The set trick above only work for the same object optimization - # It don't work for equivalent optimization. - # So we try to merge equivalent optimization here. - new_l_names = [o.name for o in new_l] - if l.name in new_l_names: - idx = new_l_names.index(l.name) - io1 = StringIO() - io2 = StringIO() - l.print_summary(io1) - new_l[idx].print_summary(io2) - if io1.read() == io2.read(): - if l in prof1[0]: - p = prof1 - else: - p = prof2 - new_t[idx] += p[1][p[0].index(l)] - if hasattr(l, "merge_profile"): - assert len(p[6][p[0].index(l)]) == len(new_sub_profile[idx]) - new_sub_profile[idx] = l.merge_profile( - new_sub_profile[idx], p[6][p[0].index(l)] - ) - else: - new_sub_profile[idx] = None - continue - if l in prof1[0]: - p = prof1 - else: - p = prof2 - new_t.append(p[1][p[0].index(l)]) - idx = p[0].index(l) - new_l.append(l) - new_sub_profile.append(p[6][idx]) - - new_opt = SeqOptimizer(*new_l) - new_nb_nodes = [] - for p1, p2 in zip(prof1[8], prof2[8]): - new_nb_nodes.append((p1[0] + p2[0], p1[1] + p2[1])) - new_nb_nodes.extend(prof1[8][len(new_nb_nodes) :]) - new_nb_nodes.extend(prof2[8][len(new_nb_nodes) :]) - - new_callbacks_times = merge_dict(prof1[9], prof2[9]) - # We need to assert based on the name as we merge also based on - # the name. - assert {l.name for l in prof1[0]}.issubset({l.name for l in new_l}) - assert {l.name for l in prof2[0]}.issubset({l.name for l in new_l}) - assert len(new_t) == len(new_opt) == len(new_sub_profile) - return ( - new_opt, - new_t, - prof1[2] + prof2[2], - prof1[3] + prof2[3], - -1, - -1, - new_sub_profile, - [], - new_nb_nodes, - new_callbacks_times, - ) - - -class MergeFeature(Feature): - """Keeps track of variables in a `FunctionGraph` that cannot be merged together. - - That way, the `MergeOptimizer` can remember the result of the last - merge-pass on the `FunctionGraph`. - - """ - - def on_attach(self, fgraph): - if hasattr(fgraph, "merge_feature"): - raise AlreadyThere() - - fgraph.merge_feature = self - - self.seen_atomics = set() - self.atomic_sig = AssocList() - self.atomic_sig_inv = AssocList() - - # For all Apply nodes - # Set of distinct (not mergeable) nodes - self.nodes_seen = set() - # Ordered set of distinct (not mergeable) nodes without any input - self.noinput_nodes = OrderedSet() - - # Each element of scheduled is a list of list of (out, new_out) pairs. - # Each list of pairs represent the substitution needed to replace all - # the outputs of a node with the outputs of a replacement candidate. - # Each node can have several candidates. For instance, if "node" has - # 2 outputs, and there are 3 replacement candidates, we will have: - # shelf.scheduled = [ - # [[(node.out1, cand1.out1), (node.out2, cand1.out2)], - # [(node.out1, cand2.out1), (node.out2, cand2.out2)], - # [(node.out1, cand3.out1), (node.out2, cand3.out2)]]] - self.scheduled = [] - - # List of (node, candidate) pairs, where we tried to replace node by - # candidate, but it failed. This is used to avoid infinite loops - # during the replacement phase. - self.blacklist = [] - - for node in fgraph.toposort(): - self.on_import(fgraph, node, "on_attach") - - def clone(self): - return type(self)() - - def on_change_input(self, fgraph, node, i, r, new_r, reason): - if node in self.nodes_seen: - # If inputs to a node change, it's not guaranteed that the node is - # distinct from the other nodes in `self.nodes_seen`. - self.nodes_seen.discard(node) - self.process_node(fgraph, node) - - if isinstance(new_r, AtomicVariable): - self.process_atomic(fgraph, new_r) - - def on_import(self, fgraph, node, reason): - for c in node.inputs: - if isinstance(c, AtomicVariable): - self.process_atomic(fgraph, c) - - self.process_node(fgraph, node) - - def on_prune(self, fgraph, node, reason): - self.nodes_seen.discard(node) - if not node.inputs: - self.noinput_nodes.discard(node) - for c in node.inputs: - if isinstance(c, AtomicVariable) and len(fgraph.clients[c]) <= 1: - # This was the last node using this constant - sig = self.atomic_sig[c] - self.atomic_sig.discard(c) - self.atomic_sig_inv.discard(sig) - self.seen_atomics.discard(id(c)) - - def process_atomic(self, fgraph, c): - """Check if an atomic `c` can be merged, and queue that replacement.""" - if id(c) in self.seen_atomics: - return - sig = c.merge_signature() - other_c = self.atomic_sig_inv.get(sig, None) - if other_c is not None: - # multiple names will clobber each other.. - # we adopt convention to keep the last name - if c.name: - other_c.name = c.name - self.scheduled.append([[(c, other_c, "merge")]]) - else: - # this is a new constant - self.atomic_sig[c] = sig - self.atomic_sig_inv[sig] = c - self.seen_atomics.add(id(c)) - - def process_node(self, fgraph, node): - r"""Check if a `node` can be merged, and queue that replacement. - - When `node` is changed we check for other nodes (via the clients map) - that depend on the same inputs. If any of those other nodes have the - same inputs and `Op` as `node`, they are queued to be merged. - - """ - - if node in self.nodes_seen: - return - - if node.inputs: - # We use the smallest clients list. Some `Op`s like `Elemwise` - # have optimizations that put constants as the first inputs. Since - # constants generally have more clients than other types of nodes, - # using `node.inputs[0]` will make us look at more nodes on - # average, so by picking the smallest clients list, we might speed - # things up? - - clients = sorted( - (fgraph.clients[inp] for inp in node.inputs), key=lambda x: len(x) - )[0] - assert len(clients) > 0 - - merge_candidates = [c for c, i in clients if c in self.nodes_seen] - else: - # If two nodes have no input, but perform the same operation, - # they are not always constant-folded, so we want to merge them. - # In that case, the candidates are all the nodes without inputs. - merge_candidates = self.noinput_nodes - - replacement_candidates = [] - for candidate in merge_candidates: - - if candidate is node: - continue - if len(node.inputs) != len(candidate.inputs): - continue - - inputs_match = all( - node_in is cand_in - for node_in, cand_in in zip(node.inputs, candidate.inputs) - ) - - if inputs_match and node.op == candidate.op: - if (node, candidate) in self.blacklist: - # They were already tried, and there was an error - continue - - # Schedule transfer of clients from node to candidate - pairs = list( - zip( - node.outputs, - candidate.outputs, - ["merge"] * len(node.outputs), - ) - ) - - replacement_candidates.append(pairs) - - if replacement_candidates: - self.scheduled.append(replacement_candidates) - else: - self.nodes_seen.add(node) - if not node.inputs: - self.noinput_nodes.add(node) - - -class MergeOptimizer(GlobalOptimizer): - r"""Merges parts of the graph that are identical and redundant. - - The basic principle is that if two `Apply`\s have `Op`\s that compare equal, and - identical inputs, then they do not both need to be computed. The clients of - one are transferred to the other and one of them is removed from the graph. - This procedure is carried out in input-to-output order throughout the graph. - - The first step of merging is atomic variable-merging, so that all clients of a - :class:`Constant` like ``int(1)``, are transferred to just one particular - instance of ``int(1)``. :class:`NominalVariable`\s are not merged individually - like this; only the nodes that use them are. - - """ - - def add_requirements(self, fgraph): - if not hasattr(fgraph, "merge_feature"): - fgraph.attach_feature(MergeFeature()) - - def apply(self, fgraph): - sched = fgraph.merge_feature.scheduled - nb_fail = 0 - t0 = time.time() - if fgraph.profile: - validate_before = fgraph.profile.validate_time - callback_before = fgraph.execute_callbacks_time - callbacks_before = fgraph.execute_callbacks_times.copy() - - nb_merged = 0 - nb_atomic = 0 - while sched: - pairs_list = sched.pop() - success = True - for pairs_ in pairs_list: - # We must check again the equivalence, as the graph could've - # changed. If so, doing the replacement can introduce a node - # that depends on itself. Doing the full check of such cycles - # every time is very time consuming. I think this double check - # is faster than doing the full cycle check. The full cycle - # check is skipped by `Validator.validate` if the graph doesn't - # contain destroyers. - var, candidate_var, merge_mode = pairs_[0] - if merge_mode == "new_node" and var in fgraph.variables: - pass - elif ( - var not in fgraph.variables or candidate_var not in fgraph.variables - ): - continue - - # Keep len(item) == 2 for item in pairs - pairs = [pair[:2] for pair in pairs_] - - if var.owner and candidate_var.owner: - if merge_mode == "new_node": - inputs_match = True - else: - inputs_match = all( - node_in is cand_in - for node_in, cand_in in zip( - var.owner.inputs, candidate_var.owner.inputs - ) - ) - - # No need to compare the op again, as it don't change. - if not inputs_match: - continue - - if hasattr(fgraph, "destroy_handler"): - # If both nodes have clients that destroy them, we - # can't merge them. - clients = ( - fgraph.clients[pairs[0][0]] + fgraph.clients[pairs[0][1]] - ) - if any( - i in flatten(c.op.destroy_map.values()) - for c, i in clients - if c != "output" and c.op.destroy_map - ): - continue - - if len(pairs) == 1 and pairs[0][0].type != pairs[0][1].type: - res = pairs[0][0].type.convert_variable(pairs[0][1]) - - # Since the fgraph.replace only checks the convert_variable - # in one way, we change the order in the case that - # convert_variable will not be successful. - if not res: - pairs = [(pairs[0][1], pairs[0][0])] - - try: - # If they're all `AtomicVariable`s, there's no need to call validate. - if all(isinstance(old, AtomicVariable) for old, _ in pairs): - fgraph.replace_all(pairs, reason="MergeOptimizer") - else: - fgraph.replace_all_validate(pairs, reason="MergeOptimizer") - except InconsistencyError: - success = False - nb_fail += 1 - fgraph.merge_feature.blacklist.append( - (pairs[0][0].owner, pairs[0][1].owner) - ) - - if success: - nb_merged += len(pairs) - if isinstance(pairs[0][0], AtomicVariable): - nb_atomic += 1 - break - - if fgraph.profile: - validate_time = fgraph.profile.validate_time - validate_before - callback_time = fgraph.execute_callbacks_time - callback_before - callbacks_time = {} - for k, v in fgraph.execute_callbacks_times.items(): - if k in callbacks_before: - t = v - callbacks_before[k] - if t > 0: - callbacks_time[k] = t - else: - callbacks_time[k] = v - else: - validate_time = None - callback_time = None - callbacks_time = {} - - fgraph.merge_feature.blacklist = [] - - return ( - nb_fail, - time.time() - t0, - validate_time, - callback_time, - callbacks_time, - nb_merged, - nb_atomic, - ) - - def __str__(self): - return self.__class__.__name__ - - @staticmethod - def print_profile(stream, prof, level=0): - - ( - nb_fail, - replace_time, - validate_time, - callback_time, - callbacks_time, - nb_merged, - nb_atomic, - ) = prof - - validate_time = validate_time or float("nan") - callback_time = callback_time or float("nan") - - blanc = " " * level - print(blanc, "MergeOptimizer", file=stream) - print( - blanc, - f" nb fail={nb_fail:5d} merged={nb_merged:5d} atomic={nb_atomic:5d}", - file=stream, - ) - print( - blanc, - f" time replace={replace_time:2.2f} validate={validate_time:2.2f} callback={callback_time:2.2f}", - file=stream, - ) - if callback_time > 1: - print(blanc, " callbacks_time", file=stream) - for i in sorted(callbacks_time.items(), key=lambda a: a[1]): - if i[1] > 0: - # We want to have the __str__ called, so we can't - # just print i. - print(blanc, " ", i[0], ",", i[1], file=stream) - - @staticmethod - def merge_profile(prof1, prof2): - def merge_none_number(v1, v2): - if v1 is None: - return v2 - if v2 is None: - return v1 - return v1 + v2 - - nb_fail = prof1[0] + prof2[0] - replace_time = prof1[1] + prof2[1] - validate_time = merge_none_number(prof1[2], prof2[2]) - callback_time = merge_none_number(prof1[3], prof2[3]) - callbacks_time = merge_dict(prof1[4], prof2[4]) - nb_merged = prof1[5] + prof2[5] - nb_atomic = prof1[6] + prof2[6] - return ( - nb_fail, - replace_time, - validate_time, - callback_time, - callbacks_time, - nb_merged, - nb_atomic, - ) - - -def pre_constant_merge(fgraph, variables): - """Merge constants in the graphs given by `variables`. - - .. warning:: - - This changes the nodes in a graph in-place! - - Parameters - ---------- - fgraph - A `FunctionGraph` instance in which some of these `variables` may - reside. - - We want to avoid terms in `variables` that are contained in `fgraph`. - The reason for that: it will break consistency of `fgraph` and its - features (e.g. `ShapeFeature`). - - variables - A list of nodes for which we want to merge constant inputs. - - Notes - ----- - It is used to pre-merge nodes generated inside an optimization. It is - useful if there are many such replacements to make, so that `DebugMode` - will not check each of them. - - """ - seen_var = set() - # signature -> variable (for constants) - const_sig_inv = {} - if isinstance(variables, Variable): - variables = [variables] - - def recursive_merge(var): - - if var in seen_var: - return var - - if not hasattr(var, "owner"): - return var - - # We don't want to merge constants that are *within* the - # `FunctionGraph` - if var.owner in fgraph.apply_nodes: - return var - - seen_var.add(var) - - if isinstance(var, Constant): - sig = var.signature() - - if sig in const_sig_inv: - return const_sig_inv[sig] - - const_sig_inv[sig] = var - - return var - - if var.owner: - for idx, inp in enumerate(var.owner.inputs): - # XXX: This is changing the graph in place! - var.owner.inputs[idx] = recursive_merge(inp) - return var - - return [recursive_merge(v) for v in variables] - - -class LocalMetaOptimizer(LocalOptimizer): - r""" - Base class for meta-optimizers that try a set of `LocalOptimizer`\s - to replace a node and choose the one that executes the fastest. - - If the error ``LocalMetaOptimizerSkipAssertionError`` is raised during - compilation, we will skip that function compilation and not print - the error. - - """ - - def __init__(self): - self.verbose = config.metaopt__verbose - self.track_dict = defaultdict(lambda: []) - self.tag_dict = defaultdict(lambda: []) - self._tracks = [] - self.optimizers = [] - - def register(self, optimizer, tag_list): - self.optimizers.append(optimizer) - for c in optimizer.tracks(): - self.track_dict[c].append(optimizer) - self._tracks.append(c) - for tag in tag_list: - self.tag_dict[tag].append(optimizer) - - def tracks(self): - return self._tracks - - def transform(self, fgraph, node, *args, **kwargs): - # safety check: depending on registration, tracks may have been ignored - if self._tracks is not None: - if not isinstance(node.op, tuple(self._tracks)): - return - # first, we need to provide dummy values for all inputs - # to the node that are not shared variables anyway - givens = {} - missing = set() - for input in node.inputs: - if isinstance(input, aesara.compile.SharedVariable): - pass - elif hasattr(input.tag, "test_value"): - givens[input] = aesara.shared( - input.type.filter(input.tag.test_value), - input.name, - shape=input.broadcastable, - borrow=True, - ) - else: - missing.add(input) - if missing: - givens.update(self.provide_inputs(node, missing)) - missing.difference_update(givens.keys()) - # ensure we have data for all input variables that need it - if missing: - if self.verbose > 0: - print( - f"{self.__class__.__name__} cannot meta-optimize {node}, " - f"{len(missing)} of {int(node.nin)} input shapes unknown" - ) - return - # now we can apply the different optimizations in turn, - # compile the resulting subgraphs and time their execution - if self.verbose > 1: - print( - f"{self.__class__.__name__} meta-optimizing {node} ({len(self.get_opts(node))} choices):" - ) - timings = [] - for opt in self.get_opts(node): - outputs = opt.transform(fgraph, node, *args, **kwargs) - if outputs: - try: - fn = aesara.function( - [], outputs, givens=givens, on_unused_input="ignore" - ) - fn.trust_input = True - timing = min(self.time_call(fn) for _ in range(2)) - except LocalMetaOptimizerSkipAssertionError: - continue - except Exception as e: - if self.verbose > 0: - print(f"* {opt}: exception", e) - continue - else: - if self.verbose > 1: - print(f"* {opt}: {timing:.5g} sec") - timings.append((timing, outputs, opt)) - else: - if self.verbose > 0: - print(f"* {opt}: not applicable") - # finally, we choose the fastest one - if timings: - timings.sort() - if self.verbose > 1: - print(f"= {timings[0][2]}") - return timings[0][1] - return - - def provide_inputs(self, node, inputs): - """Return a dictionary mapping some `inputs` to `SharedVariable` instances of with dummy values. - - The `node` argument can be inspected to infer required input shapes. - - """ - raise NotImplementedError() - - def get_opts(self, node): - """Return the optimizations that apply to `node`. - - This uses ``self.track_dict[type(node.op)]`` by default. - """ - return self.track_dict[type(node.op)] - - def time_call(self, fn): - start = time.time() - fn() - return time.time() - start - - -class FromFunctionLocalOptimizer(LocalOptimizer): - """A `LocalOptimizer` constructed from a function.""" - - def __init__(self, fn, tracks=None, requirements=()): - self.fn = fn - self._tracks = tracks - self._tracked_types = ( - tuple(t for t in tracks if isinstance(t, type)) if tracks else () - ) - self.requirements = requirements - - def transform(self, fgraph, node): - if self._tracks: - if not ( - node.op in self._tracks or isinstance(node.op, self._tracked_types) - ): - return False - - return self.fn(fgraph, node) - - def add_requirements(self, fgraph): - for req in self.requirements: - req(fgraph) - - def tracks(self): - return self._tracks - - def __str__(self): - return getattr(self, "__name__", repr(self)) - - def __repr__(self): - return f"FromFunctionLocalOptimizer({repr(self.fn)}, {repr(self._tracks)}, {repr(self.requirements)})" - - def print_summary(self, stream=sys.stdout, level=0, depth=-1): - print(f"{' ' * level}{self.transform} id={id(self)}", file=stream) - - -def local_optimizer( - tracks: Optional[Sequence[Union[Op, type]]], - inplace: bool = False, - requirements: Optional[Tuple[type, ...]] = (), -): - r"""A decorator used to construct `FromFunctionLocalOptimizer` instances. - - Parameters - ---------- - tracks - The `Op` types or instances to which this optimization applies. - Use ``None`` instead of an empty list to have the optimization apply to - all `Op`s`. - inplace - A boolean indicating whether or not the optimization works in-place. - If ``True``, a `DestroyHandler` `Feature` is added automatically added - to the `FunctionGraph`\s applied to this optimization. - requirements - `Feature` types required by this optimization. - - """ - - if requirements is None: - requirements = () - - def decorator(f): - if tracks is not None: - if len(tracks) == 0: - raise ValueError( - "Use `None` instead of an empty list to make an optimization apply to all nodes." - ) - for t in tracks: - if not ( - isinstance(t, Op) or (isinstance(t, type) and issubclass(t, Op)) - ): - raise TypeError( - "`tracks` must consist of `Op` classes or instances." - ) - req = requirements - if inplace: - dh_handler = dh.DestroyHandler - req = tuple(requirements) + ( - lambda fgraph: fgraph.attach_feature(dh_handler()), - ) - rval = FromFunctionLocalOptimizer(f, tracks, req) - rval.__name__ = f.__name__ - return rval - - return decorator - - -class LocalOptTracker: - r"""A container that maps rewrites to `Op` instances and `Op`-type inheritance.""" - - def __init__(self): - self.tracked_instances: Dict[Op, List[LocalOptimizer]] = {} - self.tracked_types: Dict[type, List[LocalOptimizer]] = {} - self.untracked_opts: List[LocalOptimizer] = [] - - def add_tracker(self, rw: LocalOptimizer): - """Add a `LocalOptimizer` to be keyed by its `LocalOptimizer.tracks` or applied generally.""" - tracks = rw.tracks() - - if tracks is None: - self.untracked_opts.append(rw) - else: - for c in tracks: - if isinstance(c, type): - self.tracked_types.setdefault(c, []).append(rw) - else: - self.tracked_instances.setdefault(c, []).append(rw) - - def _find_impl(self, cls) -> List[LocalOptimizer]: - r"""Returns the `LocalOptimizer`\s that apply to `cls` based on inheritance. - - This based on `functools._find_impl`. - """ - mro = _compose_mro(cls, self.tracked_types.keys()) - matches = [] - for t in mro: - match = self.tracked_types.get(t, None) - if match: - matches.extend(match) - return matches - - @functools.lru_cache() - def get_trackers(self, op: Op) -> List[LocalOptimizer]: - """Get all the rewrites applicable to `op`.""" - return ( - self._find_impl(type(op)) - + self.tracked_instances.get(op, []) - + self.untracked_opts - ) - - def get_rewriters(self): - return chain( - chain.from_iterable( - chain(self.tracked_types.values(), self.tracked_instances.values()) - ), - self.untracked_opts, - ) - - -class LocalOptGroup(LocalOptimizer): - r"""An optimizer that applies a list of `LocalOptimizer`\s to a node. - - Attributes - ---------- - reentrant : bool - Some global optimizers, like `NavigatorOptimizer`, use this value to - determine if they should ignore new nodes. - retains_inputs : bool - States whether or not the inputs of a transformed node are transferred - to the outputs. - """ - - def __init__( - self, - *optimizers: Rewriter, - apply_all_opts: bool = False, - profile: bool = False, - ): - """ - - Parameters - ---------- - optimizers - A list of optimizers to be applied to nodes. - apply_all_opts - If ``False``, it will return after the first successfully applied - rewrite; otherwise, it will apply every applicable rewrite - incrementally. - profile - Whether or not to profile the optimizations. - - """ - super().__init__() - - self.opts: Sequence[Rewriter] = optimizers - assert isinstance(self.opts, tuple) - - self.reentrant = any(getattr(opt, "reentrant", True) for opt in optimizers) - self.retains_inputs = all( - getattr(opt, "retains_inputs", False) for opt in optimizers - ) - - self.apply_all_opts = apply_all_opts - - self.profile = profile - if self.profile: - self.time_opts: Dict[Rewriter, float] = {} - self.process_count: Dict[Rewriter, int] = {} - self.applied_true: Dict[Rewriter, int] = {} - self.node_created: Dict[Rewriter, int] = {} - - self.tracker = LocalOptTracker() - - for o in self.opts: - - self.tracker.add_tracker(o) - - if self.profile: - self.time_opts.setdefault(o, 0.0) - self.process_count.setdefault(o, 0) - self.applied_true.setdefault(o, 0) - self.node_created.setdefault(o, 0) - - def __str__(self): - return getattr( - self, - "__name__", - f"LocalOptGroup({','.join([str(o) for o in self.opts])})", - ) - - def tracks(self): - t = [] - for l in self.opts: - at = l.tracks() - if at: - t.extend(at) - return t - - def transform(self, fgraph, node): - if len(self.opts) == 0: - return - - repl = None - - while True: - opts = self.tracker.get_trackers(node.op) - - new_repl = None - for opt in opts: - opt_start = time.time() - new_repl = opt.transform(fgraph, node) - opt_finish = time.time() - if self.profile: - self.time_opts[opt] += opt_start - opt_finish - self.process_count[opt] += 1 - if not new_repl: - continue - if isinstance(new_repl, (tuple, list)): - new_vars = new_repl - else: # It must be a dict - new_vars = list(new_repl.values()) - - if config.optimizer_verbose: - print( - f"optimizer: rewrite {opt} replaces node {node} with {new_repl}" - ) - - if self.profile: - self.node_created[opt] += len( - list(applys_between(fgraph.variables, new_vars)) - ) - self.applied_true[opt] += 1 - break # break from the for loop over optimization. - if not new_repl: # No optimization applied in the last iteration - return repl - # only 1 iteration - if not self.apply_all_opts: - return new_repl - if not new_vars[0].owner: - # We are at the start of the graph. - return new_repl - if len(new_repl) > 1: - s = {v.owner for v in new_repl} - assert len(s) == 1 - repl = new_repl - node = new_vars[0].owner - - @staticmethod - def print_profile(stream, prof, level=0): - (time_opts, process_count, applied_true, node_created, profile) = prof - - if not profile: - return - - blanc = " " * int(level) - print(blanc, "LocalOptGroup", file=stream) - print(blanc, "---------------------", file=stream) - count_opt = [] - not_used = [] - not_used_time = 0 - for o, count in process_count.items(): - if count > 0: - count_opt.append( - (time_opts[o], applied_true[o], count, o, node_created[o]) - ) - else: - not_used.append((time_opts[o], o)) - not_used_time += time_opts[o] - if count_opt: - print( - blanc, - " time taken - times applied - times tried - name - node_created:", - file=stream, - ) - count_opt.sort() - for (t, a_t, count, o, n_c) in count_opt[::-1]: - print( - blanc, - f" {t:.3f}s - {int(a_t)} - {int(count)} - {o} - {int(n_c)}", - file=stream, - ) - print( - blanc, - f" {not_used_time:.3f}s - in {len(not_used)} optimization that were not used (display those with runtime greater than 0)", - file=stream, - ) - not_used.sort(key=lambda nu: (nu[0], str(nu[1]))) - for (t, o) in not_used[::-1]: - if t > 0: - # Skip opt that have 0 times, they probably wasn't even tried. - print(blanc + " ", f" {t:.3f}s - {o}", file=stream) - else: - print(blanc, " The optimizer wasn't successful ", file=stream) - - print(file=stream) - - def merge_profile(prof1, prof2): - raise NotImplementedError - - def print_summary(self, stream=sys.stdout, level=0, depth=-1): - print(f"{' ' * level}{self.__class__.__name__} id={id(self)}", file=stream) - if depth != 0: - depth -= 1 - for lopt in self.opts: - lopt.print_summary(stream, level=(level + 2), depth=depth) - - def add_requirements(self, fgraph): - for opt in self.opts: - opt.add_requirements(fgraph) - - -class OpSub(LocalOptimizer): - """ - - Replaces the application of a certain `Op` by the application of - another `Op` that takes the same inputs as what it is replacing. - - Parameters - ---------- - op1, op2 - ``op1.make_node`` and ``op2.make_node`` must take the same number of - inputs and have the same number of outputs. - - Examples - -------- - OpSub(add, sub) ==> - add(div(x, y), add(y, x)) -> sub(div(x, y), sub(y, x)) - """ - - # an OpSub does not apply to the nodes it produces - reentrant = False - # all the inputs of the original node are transferred to the outputs - retains_inputs = True - - def __init__(self, op1, op2, transfer_tags=True): - self.op1 = op1 - self.op2 = op2 - self.transfer_tags = transfer_tags - - def op_key(self): - return self.op1 - - def tracks(self): - return [self.op1] - - def transform(self, fgraph, node): - if node.op != self.op1: - return False - repl = self.op2.make_node(*node.inputs) - if self.transfer_tags: - repl.tag = copy.copy(node.tag) - for output, new_output in zip(node.outputs, repl.outputs): - new_output.tag = copy.copy(output.tag) - return repl.outputs - - def __str__(self): - return f"{self.op1} -> {self.op2}" - - -class OpRemove(LocalOptimizer): - """ - Removes all applications of an `Op` by transferring each of its - outputs to the corresponding input. - - """ - - reentrant = False # no nodes are added at all - - def __init__(self, op): - self.op = op - - def op_key(self): - return self.op - - def tracks(self): - return [self.op] - - def transform(self, fgraph, node): - if node.op != self.op: - return False - return node.inputs - - def __str__(self): - return f"{self.op}(x) -> x" - - def print_summary(self, stream=sys.stdout, level=0, depth=-1): - print( - f"{' ' * level}{self.__class__.__name__}(self.op) id={id(self)}", - file=stream, - ) - - -class PatternSub(LocalOptimizer): - """Replace all occurrences of an input pattern with an output pattern. - - The input and output patterns have the following syntax: - - input_pattern ::= (op, , , ...) - input_pattern ::= dict(pattern = , - constraint = ) - sub_pattern ::= input_pattern - sub_pattern ::= string - sub_pattern ::= a Constant instance - sub_pattern ::= int - sub_pattern ::= float - constraint ::= lambda fgraph, expr: additional matching condition - - output_pattern ::= (op, , , ...) - output_pattern ::= string - output_pattern ::= int - output_pattern ::= float - - Each string in the input pattern is a variable that will be set to - whatever expression is found in its place. If the same string is - used more than once, the same expression must be found in those - places. If a string used in the input pattern is used in the - output pattern, the matching expression will be inserted in its - place. The input pattern cannot just be a string but the output - pattern can. - - If you put a constant variable in the input pattern, there will be a - match iff a constant variable with the same value and the same type - is found in its place. - - You can add a constraint to the match by using the ``dict(...)`` form - described above with a ``'constraint'`` key. The constraint must be a - function that takes the fgraph and the current Variable that we are - trying to match and returns True or False according to an - arbitrary criterion. - - The constructor creates a `PatternSub` that replaces occurrences of - `in_pattern` by occurrences of `out_pattern`. - - Parameters - ---------- - in_pattern : - The input pattern that we want to replace. - out_pattern : - The replacement pattern. - allow_multiple_clients : bool - If False, the pattern matching will fail if one of the subpatterns has - more than one client. - skip_identities_fn : TODO - name : - Allows to override this optimizer name. - tracks : optional - The values that :meth:`self.tracks` will return. Useful to speed up - optimization sometimes. - get_nodes : optional - If you provide `tracks`, you must provide this parameter. It must be a - function that takes the tracked node and returns a list of nodes on - which we will try this optimizer. - - Notes - ----- - `tracks` and `get_nodes` can be used to make this optimizer track a less - frequent `Op`, so this will make this optimizer tried less frequently. - - Examples - -------- - - PatternSub((add, 'x', 'y'), (add, 'y', 'x')) - PatternSub((multiply, 'x', 'x'), (square, 'x')) - PatternSub((subtract, (add, 'x', 'y'), 'y'), 'x') - PatternSub((power, 'x', Constant(double, 2.0)), (square, 'x')) - PatternSub((boggle, {'pattern': 'x', - 'constraint': lambda expr: expr.type == scrabble}), - (scrabble, 'x')) - - """ - - def __init__( - self, - in_pattern, - out_pattern, - allow_multiple_clients=False, - skip_identities_fn=None, - name=None, - tracks=(), - get_nodes=None, - values_eq_approx=None, - ): - from aesara.graph.unify import convert_strs_to_vars - - var_map = {} - self.in_pattern = convert_strs_to_vars(in_pattern, var_map=var_map) - self.out_pattern = convert_strs_to_vars(out_pattern, var_map=var_map) - self.values_eq_approx = values_eq_approx - if isinstance(in_pattern, (list, tuple)): - self.op = self.in_pattern[0] - elif isinstance(in_pattern, dict): - self.op = self.in_pattern["pattern"][0] - else: - raise TypeError( - "The pattern to search for must start with a specific Op instance." - ) - self.__doc__ = ( - self.__class__.__doc__ + "\n\nThis instance does: " + str(self) + "\n" - ) - self.allow_multiple_clients = allow_multiple_clients - self.skip_identities_fn = skip_identities_fn - if name: - self.__name__ = name - self._tracks = tracks - self.get_nodes = get_nodes - if tracks != (): - assert get_nodes - - def op_key(self): - return self.op - - def tracks(self): - if self._tracks != (): - return self._tracks - return [self.op] - - def transform(self, fgraph, node, get_nodes=True): - """Check if the graph from node corresponds to ``in_pattern``. - - If it does, it constructs ``out_pattern`` and performs the replacement. - - """ - from etuples.core import ExpressionTuple - from unification import reify, unify - - # TODO: We shouldn't need to iterate like this. - if not self.allow_multiple_clients and any( - len(fgraph.clients.get(v)) > 1 - for v in vars_between(fgraph.inputs, node.outputs) - if v not in fgraph.inputs - ): - return False - - if get_nodes and self.get_nodes is not None: - for real_node in self.get_nodes(fgraph, node): - if real_node == "output": - continue - ret = self.transform(fgraph, real_node, get_nodes=False) - if ret is not False and ret is not None: - return dict(zip(real_node.outputs, ret)) - - if node.op != self.op: - return False - - s = unify(self.in_pattern, node.out) - - if s is False: - return False - - ret = reify(self.out_pattern, s) - - if isinstance(ret, ExpressionTuple): - ret = ret.evaled_obj - - if self.values_eq_approx: - ret.tag.values_eq_approx = self.values_eq_approx - - if ret.owner: - if not ( - len(node.outputs) == len(ret.owner.outputs) - and all( - o.type.is_super(new_o.type) - for o, new_o in zip(node.outputs, ret.owner.outputs) - ) - ): - return False - else: - # ret is just an input variable - assert len(node.outputs) == 1 - if not node.outputs[0].type.is_super(ret.type): - return False - - return [ret] - - def __str__(self): - if getattr(self, "__name__", None): - return self.__name__ - - def pattern_to_str(pattern): - if isinstance(pattern, (list, tuple)): - return "{}({})".format( - str(pattern[0]), - ", ".join([pattern_to_str(p) for p in pattern[1:]]), - ) - elif isinstance(pattern, dict): - return "{} subject to {}".format( - pattern_to_str(pattern["pattern"]), - str(pattern.get("constraint", "no conditions")), - ) - else: - return str(pattern) - - return "{} -> {}".format( - pattern_to_str(self.in_pattern), - pattern_to_str(self.out_pattern), - ) - - def __repr__(self): - return str(self) - - def print_summary(self, stream=sys.stdout, level=0, depth=-1): - name = getattr(self, "__name__", getattr(self, "name", None)) - print( - f"{' ' * level}{self.__class__.__name__} {name}({self.in_pattern}, {self.out_pattern}) id={id(self)}", - file=stream, - ) - - -class Updater(Feature): - def __init__(self, importer, pruner, chin, name=None): - self.importer = importer - self.pruner = pruner - self.chin = chin - self.name = name - - def __str__(self): - return f"Updater{{{self.name}}}" - - def on_import(self, fgraph, node, reason): - if self.importer: - self.importer(node) - - def on_prune(self, fgraph, node, reason): - if self.pruner: - self.pruner(node) - - def on_change_input(self, fgraph, node, i, r, new_r, reason): - if self.chin: - self.chin(node, i, r, new_r, reason) - - def on_detach(self, fgraph): - # To allow pickling this object - self.importer = None - self.pruner = None - self.chin = None - - -class NavigatorOptimizer(GlobalOptimizer): - r"""An optimizer that applies a `LocalOptimizer` with considerations for the new nodes it creates. - - - This optimizer also allows the `LocalOptimizer` to use a special ``"remove"`` value - in the ``dict``\s returned by :meth:`LocalOptimizer`. `Variable`\s mapped to this - value are removed from the `FunctionGraph`. - - Parameters - ---------- - local_opt : - A `LocalOptimizer` to apply over a `FunctionGraph` (or ``None``). - ignore_newtrees : - - ``True``: new subgraphs returned by an optimization are not a - candidate for optimization. - - ``False``: new subgraphs returned by an optimization is a candidate - for optimization. - - ``'auto'``: let the `local_opt` set this parameter via its :attr:`reentrant` - attribute. - failure_callback - A function with the signature ``(exception, navigator, [(old, new), - (old,new),...])`` that is called when there's an exception. - - If the exception is raised in ``local_opt.transform``, the ``new`` variables - will be ``None``. - - If the exception is raised during validation (e.g. the new types don't - match) then the new variables will be the ones created by ``self.transform``. - - If this parameter is ``None``, then exceptions are not caught here and - are raised normally. - - """ - - @staticmethod - def warn(exc, nav, repl_pairs, local_opt, node): - """A failure callback that prints a traceback.""" - if config.on_opt_error != "ignore": - _logger.error(f"Optimization failure due to: {local_opt}") - _logger.error(f"node: {node}") - _logger.error("TRACEBACK:") - _logger.error(traceback.format_exc()) - if config.on_opt_error == "pdb": - pdb.post_mortem(sys.exc_info()[2]) - elif isinstance(exc, AssertionError) or config.on_opt_error == "raise": - # We always crash on AssertionError because something may be - # seriously wrong if such an exception is raised. - raise exc - - @staticmethod - def warn_inplace(exc, nav, repl_pairs, local_opt, node): - r"""A failure callback that ignores ``InconsistencyError``\s and prints a traceback. - - If the error occurred during replacement, ``repl_pairs`` is set; - otherwise, its value is ``None``. - - """ - if isinstance(exc, InconsistencyError): - return - return NavigatorOptimizer.warn(exc, nav, repl_pairs, local_opt, node) - - @staticmethod - def warn_ignore(exc, nav, repl_pairs, local_opt, node): - """A failure callback that ignores all errors.""" - - def __init__(self, local_opt, ignore_newtrees="auto", failure_callback=None): - self.local_opt = local_opt - if ignore_newtrees == "auto": - self.ignore_newtrees = not getattr(local_opt, "reentrant", True) - else: - self.ignore_newtrees = ignore_newtrees - self.failure_callback = failure_callback - - def attach_updater(self, fgraph, importer, pruner, chin=None, name=None): - r"""Install `FunctionGraph` listeners to help the navigator deal with the ``ignore_trees``-related functionality. - - Parameters - ---------- - importer : - Function that will be called whenever optimizations add stuff - to the graph. - pruner : - Function to be called when optimizations remove stuff - from the graph. - chin : - "on change input" called whenever a node's inputs change. - name : - name of the ``Updater`` to attach. - - Returns - ------- - The `FunctionGraph` plugin that handles the three tasks. - Keep this around so that `Feature`\s can be detached later. - - """ - if self.ignore_newtrees: - importer = None - - if importer is None and pruner is None: - return None - - u = Updater(importer, pruner, chin, name=name) - fgraph.attach_feature(u) - return u - - def detach_updater(self, fgraph, u): - """Undo the work of ``attach_updater``. - - Parameters - ---------- - fgraph - The `FunctionGraph`. - u - A return-value of ``attach_updater``. - - Returns - ------- - None - - """ - if u is not None: - fgraph.remove_feature(u) - - def process_node(self, fgraph, node, lopt=None): - r"""Apply `lopt` to `node`. - - The :meth:`lopt.transform` method will return either ``False`` or a - list of `Variable`\s that are intended to replace :attr:`node.outputs`. - - If the `fgraph` accepts the replacement, then the optimization is - successful, and this function returns ``True``. - - If there are no replacement candidates or the `fgraph` rejects the - replacements, this function returns ``False``. - - Parameters - ---------- - fgraph : - A `FunctionGraph`. - node : - An `Apply` instance in `fgraph` - lopt : - A `LocalOptimizer` instance that may have a better idea for - how to compute node's outputs. - - Returns - ------- - bool - ``True`` iff the `node`'s outputs were replaced in the `fgraph`. - - """ - lopt = lopt or self.local_opt - try: - replacements = lopt.transform(fgraph, node) - except Exception as e: - if self.failure_callback is not None: - self.failure_callback( - e, self, [(x, None) for x in node.outputs], lopt, node - ) - return False - else: - raise - if replacements is False or replacements is None: - return False - old_vars = node.outputs - remove = [] - if isinstance(replacements, dict): - if "remove" in replacements: - remove = replacements.pop("remove") - old_vars = list(replacements.keys()) - replacements = list(replacements.values()) - elif not isinstance(replacements, (tuple, list)): - raise TypeError( - f"Local optimizer {lopt} gave wrong type of replacement. " - f"Expected list or tuple; got {replacements}" - ) - if len(old_vars) != len(replacements): - raise ValueError( - f"Local optimizer {lopt} gave wrong number of replacements" - ) - # None in the replacement mean that this variable isn't used - # and we want to remove it - for r, rnew in zip(old_vars, replacements): - if rnew is None and len(fgraph.clients[r]) > 0: - raise ValueError( - f"Local optimizer {lopt} tried to remove a variable" - f" that is being used: {r}" - ) - # If an output would be replaced by itself, no need to perform - # the replacement - repl_pairs = [ - (r, rnew) - for r, rnew in zip(old_vars, replacements) - if rnew is not r and rnew is not None - ] - - if len(repl_pairs) == 0: - return False - try: - fgraph.replace_all_validate_remove(repl_pairs, reason=lopt, remove=remove) - return True - except Exception as e: - # This means the replacements were rejected by the fgraph. - # - # This is not supposed to happen. The default failure_callback - # will print a traceback as a warning. - if self.failure_callback is not None: - self.failure_callback(e, self, repl_pairs, lopt, node) - return False - else: - raise - - def add_requirements(self, fgraph): - super().add_requirements(fgraph) - # Added by default - # fgraph.attach_feature(ReplaceValidate()) - if self.local_opt: - self.local_opt.add_requirements(fgraph) - - def print_summary(self, stream=sys.stdout, level=0, depth=-1): - print(f"{' ' * level}{self.__class__.__name__} id={id(self)}", file=stream) - if depth != 0: - self.local_opt.print_summary(stream, level=(level + 2), depth=(depth - 1)) - - -class TopoOptimizer(NavigatorOptimizer): - """An optimizer that applies a single `LocalOptimizer` to each node in topological order (or reverse).""" - - def __init__( - self, local_opt, order="in_to_out", ignore_newtrees=False, failure_callback=None - ): - if order not in ("out_to_in", "in_to_out"): - raise ValueError("order must be 'out_to_in' or 'in_to_out'") - self.order = order - super().__init__(local_opt, ignore_newtrees, failure_callback) - - def apply(self, fgraph, start_from=None): - if start_from is None: - start_from = fgraph.outputs - callback_before = fgraph.execute_callbacks_time - nb_nodes_start = len(fgraph.apply_nodes) - t0 = time.time() - q = deque(io_toposort(fgraph.inputs, start_from)) - io_t = time.time() - t0 - - def importer(node): - if node is not current_node: - q.append(node) - - u = self.attach_updater( - fgraph, importer, None, name=getattr(self, "name", None) - ) - nb = 0 - try: - t0 = time.time() - while q: - if self.order == "out_to_in": - node = q.pop() - else: - node = q.popleft() - if node not in fgraph.apply_nodes: - continue - current_node = node - nb += self.process_node(fgraph, node) - loop_t = time.time() - t0 - finally: - self.detach_updater(fgraph, u) - - callback_time = fgraph.execute_callbacks_time - callback_before - nb_nodes_end = len(fgraph.apply_nodes) - return ( - self, - nb, - nb_nodes_start, - nb_nodes_end, - io_t, - loop_t, - callback_time, - self.local_opt, - ) - - @staticmethod - def print_profile(stream, prof, level=0): - blanc = " " * level - if prof is None: # Happen as merge_profile() isn't implemented - print(blanc, "TopoOptimizer merge_profile not implemented", file=stream) - return - - ( - opt, - nb, - nb_nodes_start, - nb_nodes_end, - io_t, - loop_t, - callback_time, - lopt, - ) = prof - - print( - blanc, - "TopoOptimizer ", - getattr(opt, "name", getattr(opt, "__name__", "")), - file=stream, - ) - - print( - blanc, - " nb_node (start, end, changed)", - (nb_nodes_start, nb_nodes_end, nb), - file=stream, - ) - print(blanc, " init io_toposort", io_t, file=stream) - print(blanc, " loop time", loop_t, file=stream) - print(blanc, " callback_time", callback_time, file=stream) - if isinstance(lopt, LocalOptGroup): - if lopt.profile: - lopt.print_profile( - stream, - ( - lopt.time_opts, - lopt.process_count, - lopt.applied_true, - lopt.node_created, - lopt.profile, - ), - level=level + 1, - ) - - def __str__(self): - return getattr(self, "__name__", "") - - -def topogroup_optimizer( - order, *local_opts, name=None, failure_callback=TopoOptimizer.warn_inplace, **kwargs -): - """Apply `local_opts` from the input/output nodes to the output/input nodes of a graph. - - This constructs `TopoOptimizer`s, and uses a `LocalOptGroup` when there's - more than one entry in `local_opts`. - """ - if len(local_opts) > 1: - # Don't wrap it uselessly if their is only 1 optimization. - local_opts = LocalOptGroup(*local_opts) - else: - (local_opts,) = local_opts - if not name: - name = local_opts.__name__ - ret = TopoOptimizer( - local_opts, - order=order, - failure_callback=failure_callback, - **kwargs, - ) - if name: - ret.__name__ = name - return ret - - -in2out = partial(topogroup_optimizer, "in_to_out") -out2in = partial(topogroup_optimizer, "out_to_in") - - -class OpKeyOptimizer(NavigatorOptimizer): - r"""An optimizer that applies a `LocalOptimizer` to specific `Op`\s. - - The `Op`\s are provided by a :meth:`LocalOptimizer.op_key` method (either - as a list of `Op`\s or a single `Op`), and discovered within a - `FunctionGraph` using the `NodeFinder` `Feature`. - - This is similar to the ``tracks`` feature used by other optimizers. - - """ - - def __init__(self, local_opt, ignore_newtrees=False, failure_callback=None): - if not hasattr(local_opt, "op_key"): - raise TypeError(f"{local_opt} must have an `op_key` method.") - super().__init__(local_opt, ignore_newtrees, failure_callback) - - def apply(self, fgraph): - op = self.local_opt.op_key() - if isinstance(op, (list, tuple)): - q = reduce(list.__iadd__, map(fgraph.get_nodes, op)) - else: - q = list(fgraph.get_nodes(op)) - - def importer(node): - if node is not current_node: - if node.op == op: - q.append(node) - - u = self.attach_updater( - fgraph, importer, None, name=getattr(self, "name", None) - ) - try: - while q: - node = q.pop() - if node not in fgraph.apply_nodes: - continue - current_node = node - self.process_node(fgraph, node) - finally: - self.detach_updater(fgraph, u) - - def add_requirements(self, fgraph): - super().add_requirements(fgraph) - fgraph.attach_feature(NodeFinder()) - - -class ChangeTracker(Feature): - def __init__(self): - self.changed = False - self.nb_imported = 0 - - def clone(self): - return type(self)() - - def on_import(self, fgraph, node, reason): - self.nb_imported += 1 - self.changed = True - - def on_change_input(self, fgraph, node, i, r, new_r, reason): - self.changed = True - - def reset(self): - self.changed = False - - def on_attach(self, fgraph): - if hasattr(fgraph, "change_tracker"): - raise AlreadyThere() - fgraph.change_tracker = self - - def on_detach(self, fgraph): - del fgraph.change_tracker - - -def merge_dict(d1, d2): - r"""Merge two ``dict``\s by adding their values.""" - d = d1.copy() - for k, v in d2.items(): - if k in d: - d[k] += v - else: - d[k] = v - return d - - -class EquilibriumOptimizer(NavigatorOptimizer): - """An optimizer that applies an optimization until a fixed-point/equilibrium is reached. - - Parameters - ---------- - optimizers : list or set - Local or global optimizations to apply until equilibrium. - The global optimizer will be run at the start of each iteration before - the local optimizer. - max_use_ratio : int or float - Each optimizer can be applied at most ``(size of graph * this number)`` - times. - ignore_newtrees : - See :attr:`EquilibriumDB.ignore_newtrees`. - final_optimizers : - Global optimizers that will be run after each iteration. - cleanup_optimizers : - Global optimizers that apply a list of pre determined optimization. - They must not traverse the graph as they are called very frequently. - The MergeOptimizer is one example of optimization that respect this. - They are applied after all global optimizers, then when one local - optimizer is applied, then after all final optimizers. - - """ - - def __init__( - self, - optimizers, - failure_callback=None, - ignore_newtrees=True, - tracks_on_change_inputs=False, - max_use_ratio=None, - final_optimizers=None, - cleanup_optimizers=None, - ): - super().__init__( - None, ignore_newtrees=ignore_newtrees, failure_callback=failure_callback - ) - self.global_optimizers = [] - self.final_optimizers = [] - self.cleanup_optimizers = [] - self.tracks_on_change_inputs = tracks_on_change_inputs - - self.local_tracker = LocalOptTracker() - - for opt in optimizers: - if isinstance(opt, LocalOptimizer): - self.local_tracker.add_tracker(opt) - else: - self.global_optimizers.append(opt) - - if final_optimizers: - self.final_optimizers = final_optimizers - if cleanup_optimizers: - self.cleanup_optimizers = cleanup_optimizers - self.max_use_ratio = max_use_ratio - - def get_local_optimizers(self): - yield from self.local_tracker.get_rewriters() - - def add_requirements(self, fgraph): - super().add_requirements(fgraph) - for opt in self.get_local_optimizers(): - opt.add_requirements(fgraph) - for opt in self.global_optimizers: - opt.add_requirements(fgraph) - for opt in self.final_optimizers: - opt.add_requirements(fgraph) - for opt in self.cleanup_optimizers: - opt.add_requirements(fgraph) - - def apply(self, fgraph, start_from=None): - change_tracker = ChangeTracker() - fgraph.attach_feature(change_tracker) - if start_from is None: - start_from = fgraph.outputs - else: - for node in start_from: - assert node in fgraph.outputs - - changed = True - max_use_abort = False - opt_name = None - global_process_count = {} - start_nb_nodes = len(fgraph.apply_nodes) - max_nb_nodes = len(fgraph.apply_nodes) - max_use = max_nb_nodes * self.max_use_ratio - - loop_timing = [] - loop_process_count = [] - global_opt_timing = [] - time_opts = {} - io_toposort_timing = [] - nb_nodes = [] - node_created = {} - global_sub_profs = [] - final_sub_profs = [] - cleanup_sub_profs = [] - for opt in ( - self.global_optimizers - + list(self.get_local_optimizers()) - + self.final_optimizers - + self.cleanup_optimizers - ): - global_process_count.setdefault(opt, 0) - time_opts.setdefault(opt, 0) - node_created.setdefault(opt, 0) - - def apply_cleanup(profs_dict): - changed = False - for copt in self.cleanup_optimizers: - change_tracker.reset() - nb = change_tracker.nb_imported - t_opt = time.time() - sub_prof = copt.apply(fgraph) - time_opts[copt] += time.time() - t_opt - profs_dict[copt].append(sub_prof) - if change_tracker.changed: - process_count.setdefault(copt, 0) - process_count[copt] += 1 - global_process_count[copt] += 1 - changed = True - node_created[copt] += change_tracker.nb_imported - nb - return changed - - while changed and not max_use_abort: - process_count = {} - t0 = time.time() - changed = False - iter_cleanup_sub_profs = {} - for copt in self.cleanup_optimizers: - iter_cleanup_sub_profs[copt] = [] - - # apply global optimizers - sub_profs = [] - for gopt in self.global_optimizers: - change_tracker.reset() - nb = change_tracker.nb_imported - t_opt = time.time() - sub_prof = gopt.apply(fgraph) - time_opts[gopt] += time.time() - t_opt - sub_profs.append(sub_prof) - if change_tracker.changed: - process_count.setdefault(gopt, 0) - process_count[gopt] += 1 - global_process_count[gopt] += 1 - changed = True - node_created[gopt] += change_tracker.nb_imported - nb - if global_process_count[gopt] > max_use: - max_use_abort = True - opt_name = getattr(gopt, "name", None) or getattr( - gopt, "__name__", "" - ) - global_sub_profs.append(sub_profs) - - global_opt_timing.append(float(time.time() - t0)) - - # apply clean up as global opt can have done changes that - # request that - changed |= apply_cleanup(iter_cleanup_sub_profs) - - # apply local optimizer - topo_t0 = time.time() - q = deque(io_toposort(fgraph.inputs, start_from)) - io_toposort_timing.append(time.time() - topo_t0) - - nb_nodes.append(len(q)) - max_nb_nodes = max(max_nb_nodes, len(q)) - max_use = max_nb_nodes * self.max_use_ratio - - def importer(node): - if node is not current_node: - q.append(node) - - chin = None - if self.tracks_on_change_inputs: - - def chin(node, i, r, new_r, reason): - if node is not current_node and not isinstance(node, str): - q.append(node) - - u = self.attach_updater( - fgraph, importer, None, chin=chin, name=getattr(self, "name", None) - ) - try: - while q: - node = q.pop() - if node not in fgraph.apply_nodes: - continue - current_node = node - for lopt in self.local_tracker.get_trackers(node.op): - nb = change_tracker.nb_imported - t_opt = time.time() - lopt_change = self.process_node(fgraph, node, lopt) - time_opts[lopt] += time.time() - t_opt - if not lopt_change: - continue - process_count.setdefault(lopt, 0) - process_count[lopt] += 1 - global_process_count[lopt] += 1 - changed = True - node_created[lopt] += change_tracker.nb_imported - nb - changed |= apply_cleanup(iter_cleanup_sub_profs) - if global_process_count[lopt] > max_use: - max_use_abort = True - opt_name = getattr(lopt, "name", None) or getattr( - lopt, "__name__", "" - ) - if node not in fgraph.apply_nodes: - # go to next node - break - finally: - self.detach_updater(fgraph, u) - - # Apply final optimizers - sub_profs = [] - t_before_final_opt = time.time() - for gopt in self.final_optimizers: - change_tracker.reset() - nb = change_tracker.nb_imported - t_opt = time.time() - sub_prof = gopt.apply(fgraph) - time_opts[gopt] += time.time() - t_opt - sub_profs.append(sub_prof) - if change_tracker.changed: - process_count.setdefault(gopt, 0) - process_count[gopt] += 1 - global_process_count[gopt] += 1 - changed = True - node_created[gopt] += change_tracker.nb_imported - nb - if global_process_count[gopt] > max_use: - max_use_abort = True - opt_name = getattr(gopt, "name", None) or getattr( - gopt, "__name__", "" - ) - final_sub_profs.append(sub_profs) - - global_opt_timing[-1] += time.time() - t_before_final_opt - # apply clean up as final opt can have done changes that - # request that - changed |= apply_cleanup(iter_cleanup_sub_profs) - # merge clean up profiles during that iteration. - c_sub_profs = [] - for copt, sub_profs in iter_cleanup_sub_profs.items(): - sub_prof = sub_profs[0] - for s_p in sub_profs[1:]: - sub_prof = copt.merge_profile(sub_prof, s_p) - c_sub_profs.append(sub_prof) - cleanup_sub_profs.append(c_sub_profs) - - loop_process_count.append(process_count) - loop_timing.append(float(time.time() - t0)) - - end_nb_nodes = len(fgraph.apply_nodes) - - if max_use_abort: - msg = ( - f"EquilibriumOptimizer max'ed out by '{opt_name}'" - + ". You can safely raise the current threshold of " - + "{config.optdb__max_use_ratio:f} with the aesara flag 'optdb__max_use_ratio'." - ) - if config.on_opt_error == "raise": - raise AssertionError(msg) - else: - _logger.error(msg) - fgraph.remove_feature(change_tracker) - assert len(loop_process_count) == len(loop_timing) - assert len(loop_process_count) == len(global_opt_timing) - assert len(loop_process_count) == len(nb_nodes) - assert len(loop_process_count) == len(io_toposort_timing) - assert len(loop_process_count) == len(global_sub_profs) - assert len(loop_process_count) == len(final_sub_profs) - assert len(loop_process_count) == len(cleanup_sub_profs) - return ( - self, - loop_timing, - loop_process_count, - (start_nb_nodes, end_nb_nodes, max_nb_nodes), - global_opt_timing, - nb_nodes, - time_opts, - io_toposort_timing, - node_created, - global_sub_profs, - final_sub_profs, - cleanup_sub_profs, - ) - - def print_summary(self, stream=sys.stdout, level=0, depth=-1): - name = getattr(self, "name", None) - print( - f"{' ' * level}{self.__class__.__name__} {name} id={id(self)}", file=stream - ) - if depth != 0: - for lopt in self.get_local_optimizers(): - lopt.print_summary(stream, level=(level + 2), depth=(depth - 1)) - - @staticmethod - def print_profile(stream, prof, level=0): - ( - opt, - loop_timing, - loop_process_count, - (start_nb_nodes, end_nb_nodes, max_nb_nodes), - global_opt_timing, - nb_nodes, - time_opts, - io_toposort_timing, - node_created, - global_sub_profs, - final_sub_profs, - cleanup_sub_profs, - ) = prof - - blanc = " " * level - print(blanc, "EquilibriumOptimizer", end=" ", file=stream) - print(blanc, getattr(opt, "name", getattr(opt, "__name__", "")), file=stream) - print( - blanc, - f" time {sum(loop_timing):.3f}s for {len(loop_timing)} passes", - file=stream, - ) - print( - blanc, - f" nb nodes (start, end, max) {int(start_nb_nodes)} {int(end_nb_nodes)} {int(max_nb_nodes)}", - file=stream, - ) - print(blanc, f" time io_toposort {sum(io_toposort_timing):.3f}s", file=stream) - s = sum(time_opts[o] for o in opt.get_local_optimizers()) - print(blanc, f" time in local optimizers {s:.3f}s", file=stream) - s = sum(time_opts[o] for o in opt.global_optimizers) - print(blanc, f" time in global optimizers {s:.3f}s", file=stream) - s = sum(time_opts[o] for o in opt.final_optimizers) - print(blanc, f" time in final optimizers {s:.3f}s", file=stream) - s = sum(time_opts[o] for o in opt.cleanup_optimizers) - print(blanc, f" time in cleanup optimizers {s:.3f}s", file=stream) - for i in range(len(loop_timing)): - lopt = "" - if loop_process_count[i]: - d = list( - reversed(sorted(loop_process_count[i].items(), key=lambda a: a[1])) - ) - lopt = " ".join([str((str(k), v)) for k, v in d[:5]]) - if len(d) > 5: - lopt += " ..." - print( - blanc, - ( - f" {int(i):2d} - {loop_timing[i]:.3f}s {int(sum(loop_process_count[i].values()))} ({global_opt_timing[i]:.3f}s in global opts, " - f"{io_toposort_timing[i]:.3f}s io_toposort) - {int(nb_nodes[i])} nodes - {lopt}" - ), - file=stream, - ) - - count_opt = [] - not_used = [] - not_used_time = 0 - process_count = {} - for o in ( - opt.global_optimizers - + list(opt.get_local_optimizers()) - + list(opt.final_optimizers) - + list(opt.cleanup_optimizers) - ): - process_count.setdefault(o, 0) - for count in loop_process_count: - for o, v in count.items(): - process_count[o] += v - for o, count in process_count.items(): - if count > 0: - count_opt.append((time_opts[o], count, node_created[o], o)) - else: - not_used.append((time_opts[o], o)) - not_used_time += time_opts[o] - - if count_opt: - print( - blanc, " times - times applied - nb node created - name:", file=stream - ) - count_opt.sort() - for (t, count, n_created, o) in count_opt[::-1]: - print( - blanc, - f" {t:.3f}s - {int(count)} - {int(n_created)} - {o}", - file=stream, - ) - print( - blanc, - f" {not_used_time:.3f}s - in {len(not_used)} optimization that were not used (display only those with a runtime > 0)", - file=stream, - ) - not_used.sort(key=lambda nu: (nu[0], str(nu[1]))) - for (t, o) in not_used[::-1]: - if t > 0: - # Skip opt that have 0 times, they probably wasn't even tried. - print(blanc + " ", f" {t:.3f}s - {o}", file=stream) - print(file=stream) - gf_opts = [ - o - for o in ( - opt.global_optimizers - + list(opt.final_optimizers) - + list(opt.cleanup_optimizers) - ) - if o.print_profile.__code__ is not GlobalOptimizer.print_profile.__code__ - ] - if not gf_opts: - return - print(blanc, "Global, final and clean up optimizers", file=stream) - for i in range(len(loop_timing)): - print(blanc, f"Iter {int(i)}", file=stream) - for o, prof in zip(opt.global_optimizers, global_sub_profs[i]): - try: - o.print_profile(stream, prof, level + 2) - except NotImplementedError: - print(blanc, "merge not implemented for ", o) - for o, prof in zip(opt.final_optimizers, final_sub_profs[i]): - try: - o.print_profile(stream, prof, level + 2) - except NotImplementedError: - print(blanc, "merge not implemented for ", o) - for o, prof in zip(opt.cleanup_optimizers, cleanup_sub_profs[i]): - try: - o.print_profile(stream, prof, level + 2) - except NotImplementedError: - print(blanc, "merge not implemented for ", o) - - @staticmethod - def merge_profile(prof1, prof2): - # (opt, loop_timing, loop_process_count, max_nb_nodes, - # global_opt_timing, nb_nodes, time_opts, io_toposort_timing) = prof1 - local_optimizers = OrderedSet(prof1[0].get_local_optimizers()).union( - prof2[0].get_local_optimizers() - ) - global_optimizers = OrderedSet(prof1[0].global_optimizers).union( - prof2[0].global_optimizers - ) - final_optimizers = list( - OrderedSet(prof1[0].final_optimizers).union(prof2[0].final_optimizers) - ) - cleanup_optimizers = list( - OrderedSet(prof1[0].cleanup_optimizers).union(prof2[0].cleanup_optimizers) - ) - new_opt = EquilibriumOptimizer( - local_optimizers.union(global_optimizers), - max_use_ratio=1, - final_optimizers=final_optimizers, - cleanup_optimizers=cleanup_optimizers, - ) - - def add_append_list(l1, l2): - l = copy.copy(l1) - for idx, nb in enumerate(l2): - if idx < len(l): - l[idx] += nb - else: - l.append(nb) - return l - - loop_timing = add_append_list(prof1[1], prof2[1]) - - loop_process_count = list(prof1[2]) - global_sub_profs = [] - final_sub_profs = [] - cleanup_sub_profs = [] - - for i in range(min(len(loop_process_count), len(prof2[2]))): - process_count = loop_process_count[i] - for process, count in prof2[2][i].items(): - if process in process_count: - process_count[process] += count - else: - process_count[process] = count - - def merge(opts, attr, idx): - tmp = [] - for opt in opts: - o1 = getattr(prof1[0], attr) - o2 = getattr(prof2[0], attr) - if opt in o1 and opt in o2: - p1 = prof1[idx][i][o1.index(opt)] - p2 = prof2[idx][i][o2.index(opt)] - m = None - if hasattr(opt, "merge_profile"): - m = opt.merge_profile(p1, p2) - elif opt in o1: - m = prof1[idx][i][o1.index(opt)] - else: - m = prof2[idx][i][o2.index(opt)] - tmp.append(m) - return tmp - - global_sub_profs.append(merge(global_optimizers, "global_optimizers", 9)) - final_sub_profs.append(merge(final_optimizers, "final_optimizers", 10)) - cleanup_sub_profs.append( - merge(cleanup_optimizers, "cleanup_optimizers", 11) - ) - - # Add the iteration done by only one of the profile. - loop_process_count.extend(prof1[2][len(loop_process_count) :]) - global_sub_profs.extend(prof1[9][len(global_sub_profs) :]) - final_sub_profs.extend(prof1[10][len(final_sub_profs) :]) - cleanup_sub_profs.extend(prof1[11][len(cleanup_sub_profs) :]) - - global_sub_profs.extend(prof2[9][len(loop_process_count) :]) - final_sub_profs.extend(prof2[10][len(loop_process_count) :]) - cleanup_sub_profs.extend(prof2[11][len(loop_process_count) :]) - - max_nb_nodes = max(prof1[3], prof2[3]) - - global_opt_timing = add_append_list(prof1[4], prof2[4]) - - nb_nodes = add_append_list(prof1[5], prof2[5]) - - time_opts = merge_dict(prof1[6], prof2[6]) - io_toposort_timing = add_append_list(prof1[7], prof2[7]) - assert ( - len(loop_timing) - == len(global_opt_timing) - == len(global_sub_profs) - == len(io_toposort_timing) - == len(nb_nodes) - ) - assert len(loop_timing) == max(len(prof1[1]), len(prof2[1])) - - node_created = merge_dict(prof1[8], prof2[8]) - return ( - new_opt, - loop_timing, - loop_process_count, - max_nb_nodes, - global_opt_timing, - nb_nodes, - time_opts, - io_toposort_timing, - node_created, - global_sub_profs, - final_sub_profs, - cleanup_sub_profs, - ) - - -def _check_chain(r, chain): - """ - WRITEME - - """ - chain = list(reversed(chain)) - while chain: - elem = chain.pop() - if elem is None: - if r.owner is not None: - return False - elif r.owner is None: - return False - elif isinstance(elem, Op): - if r.owner.op != elem: - return False - else: - try: - if issubclass(elem, Op) and not isinstance(r.owner.op, elem): - return False - except TypeError: - return False - if chain: - r = r.owner.inputs[chain.pop()] - # print 'check_chain', _check_chain.n_calls - # _check_chain.n_calls += 1 - - # The return value will be used as a Boolean, but some Variables cannot - # be used as Booleans (the results of comparisons, for instance) - return r is not None - - -def check_chain(r, *chain): - """ - WRITEME - - """ - if isinstance(r, Apply): - r = r.outputs[0] - return _check_chain(r, reduce(list.__iadd__, ([x, 0] for x in chain))) - - -def pre_greedy_local_optimizer(fgraph, optimizations, out): - """Apply local optimizations to a graph. - - This function traverses the computation graph in the graph before the - variable `out` but that are not in the `fgraph`. It applies - `optimizations` to each variable on the traversed graph. - - .. warning:: - - This changes the nodes in a graph in-place. - - Its main use is to apply locally constant folding when generating - the graph of the indices of a subtensor. - - Changes should not be applied to nodes that are in an `fgraph`, - so we use `fgraph` to prevent that. - - Notes - ----- - This doesn't do an equilibrium optimization, so, if there is an - optimization--like `local_upcast_elemwise_constant_inputs`--in the list - that adds additional nodes to the inputs of the node, it might be necessary - to call this function multiple times. - - Parameters - ---------- - fgraph : FunctionGraph - The graph used to avoid/filter nodes. - optimizations : list of LocalOptimizer - The list of local optimizations to apply - out : Variable - A `Variable` specifying the graph to optimize. - - """ - - def local_recursive_function(list_opt, out, optimized_vars, depth): - if not getattr(out, "owner", None): - return [out], optimized_vars - node = out.owner - - if node in fgraph.apply_nodes: - return node.outputs, optimized_vars - - # Walk up the graph via the node's inputs - for idx, inp in enumerate(node.inputs): - if inp in optimized_vars: - nw_in = optimized_vars[inp] - else: - if inp.owner: - outs, optimized_vars = local_recursive_function( - list_opt, inp, optimized_vars, depth + 1 - ) - for k, v in zip(inp.owner.outputs, outs): - optimized_vars[k] = v - nw_in = outs[inp.owner.outputs.index(inp)] - - else: - nw_in = inp - optimized_vars[inp] = inp - - # XXX: An in-place change - node.inputs[idx] = nw_in - - # Apply the optimizations - results = node.outputs - for opt in list_opt: - ret = opt.transform(fgraph, node) - if ret is not False and ret is not None: - assert len(ret) == len(node.outputs), opt - for k, v in zip(node.outputs, ret): - optimized_vars[k] = v - results = ret - if ret[0].owner: - node = out.owner - else: - break - - return results, optimized_vars - - if out.owner: - out_index = out.owner.outputs.index(out) - else: - out_index = 0 - - final_outs, optimized_nodes = local_recursive_function(optimizations, out, {}, 0) - return final_outs[out_index] - - -def copy_stack_trace(from_var, to_var): - r"""Copy the stack traces from `from_var` to `to_var`. - - Parameters - ---------- - from_var : - `Variable` or list `Variable`\s to copy stack traces from. - to_var : - `Variable` or list `Variable`\s to copy stack traces to. - - Notes - ----- - The stacktrace is assumed to be of the form of a list of lists - of tuples. Each tuple contains the filename, line number, function name - and so on. Each list of tuples contains the truples belonging to a - particular `Variable`. - - """ - - # Store stack traces from from_var - tr = [] - if isinstance(from_var, Iterable) and not isinstance(from_var, Variable): - # If from_var is a list, store concatenated stack traces - for v in from_var: - tr += getattr(v.tag, "trace", []) - - else: - # If from_var is not a list, it must be a single tensor variable, - # so just store that particular stack trace - tr = getattr(from_var.tag, "trace", []) - - if tr and isinstance(tr[0], tuple): - # There was one single stack trace, we encapsulate it in a list - tr = [tr] - - # Copy over stack traces to to_var - if isinstance(to_var, Iterable) and not isinstance(to_var, Variable): - # Copy over stack traces from from_var to each variable in - # to_var, including the stack_trace of the to_var before - for v in to_var: - v.tag.trace = getattr(v.tag, "trace", []) + tr - else: - # Copy over stack traces from from_var to each variable to - # to_var, including the stack_trace of the to_var before - to_var.tag.trace = getattr(to_var.tag, "trace", []) + tr - return to_var - - -def check_stack_trace(f_or_fgraph, ops_to_check="last", bug_print="raise"): - r"""Checks if the outputs of specific `Op`\s have a stack trace. +warnings.warn( + "The module `aesara.graph.opt` is deprecated; use `aesara.graph.rewriting.basic` instead.", + DeprecationWarning, + stacklevel=2, +) - Parameters - ---------- - f_or_fgraph : Function or FunctionGraph - The compiled function or the function graph to be analysed. - ops_to_check - This value can be of four different types: - - classes or instances inheriting from `Op` - - tuple/list of classes or instances inheriting from `Op` - - string - - function returning a boolean and taking as input an instance of `Op` +from aesara.graph.rewriting.basic import * # noqa: F401 E402 F403 +from aesara.graph.rewriting.basic import DEPRECATED_NAMES # noqa: F401 E402 F403 - - if `ops_to_check` is a string, it should be either ``'last'`` or ``'all'``. - ``'last'`` will check only the last `Op` of the graph while ``'all'`` will - check all the `Op`\s of the graph. - - if `ops_to_check` is an `Op` or a tuple/list of `Op`\s, the function will - check that all the outputs of their occurrences in the graph have a - stack trace. - - if `ops_to_check` is a function, it should take as input a - `Op` and return a boolean indicating if the input `Op` should - be checked or not. - bug_print - This value is a string belonging to ``{'raise', 'warn', 'ignore'}``. - You can specify the behaviour of the function when the specified - `ops_to_check` are not in the graph of `f_or_fgraph`: it can either raise - an exception, write a warning or simply ignore it. +def __getattr__(name): + """Intercept module-level attribute access of deprecated symbols. - Returns - ------- - boolean - ``True`` if the outputs of the specified ops have a stack, ``False`` - otherwise. + Adapted from https://stackoverflow.com/a/55139609/3006474. """ - if isinstance(f_or_fgraph, aesara.compile.function.types.Function): - fgraph = f_or_fgraph.maker.fgraph - elif isinstance(f_or_fgraph, aesara.graph.fg.FunctionGraph): - fgraph = f_or_fgraph - else: - raise ValueError("The type of f_or_fgraph is not supported") - - if isinstance(ops_to_check, Op) or ( - inspect.isclass(ops_to_check) and issubclass(ops_to_check, Op) - ): - ops_to_check = (ops_to_check,) - - # if ops_to_check is a string - if isinstance(ops_to_check, str): - if ops_to_check == "last": - apply_nodes_to_check = [ - fgraph.outputs[i].owner for i in range(len(fgraph.outputs)) - ] - elif ops_to_check == "all": - apply_nodes_to_check = fgraph.apply_nodes - else: - raise ValueError("The string ops_to_check is not recognised") - - # if ops_to_check is a list/tuple of ops - elif isinstance(ops_to_check, (tuple, list)): - # Separate classes from instances in ops_to_check - op_instances = [] - op_classes = [] - for obj in ops_to_check: - if isinstance(obj, Op): - op_instances.append(obj) - else: - op_classes.append(obj) - op_classes = tuple(op_classes) - - apply_nodes_to_check = [ - node for node in fgraph.apply_nodes if node.op in ops_to_check - ] + [ - node - for node in fgraph.apply_nodes - if isinstance(node.op, op_classes) - or ( - hasattr(node.op, "scalar_op") - and isinstance(node.op.scalar_op, op_classes) - ) - ] - - # if ops_to_check is a function - elif callable(ops_to_check): - apply_nodes_to_check = [ - node for node in fgraph.apply_nodes if ops_to_check(node) - ] - - else: - raise ValueError("ops_to_check does not have the right type") - - if not apply_nodes_to_check: - msg = ( - "Provided op instances/classes are not in the graph or the " - "graph is empty" - ) - if bug_print == "warn": - warnings.warn(msg) - elif bug_print == "raise": - raise Exception(msg) - elif bug_print == "ignore": - pass - else: - raise ValueError("The string bug_print is not recognised") - - for node in apply_nodes_to_check: - for output in node.outputs: - if not hasattr(output.tag, "trace") or not output.tag.trace: - return False - - return True - - -class CheckStackTraceFeature(Feature): - def on_import(self, fgraph, node, reason): - # In optdb we only register the CheckStackTraceOptimization when - # config.check_stack_trace is not off but we also double check here. - if config.check_stack_trace != "off" and not check_stack_trace(fgraph, "all"): - if config.check_stack_trace == "raise": - raise AssertionError( - "Empty stack trace! The optimization that inserted this variable is " - + str(reason) - ) - elif config.check_stack_trace in ("log", "warn"): - apply_nodes_to_check = fgraph.apply_nodes - for node in apply_nodes_to_check: - for output in node.outputs: - if not hasattr(output.tag, "trace") or not output.tag.trace: - output.tag.trace = [ - [ - ( - "", - 0, - "Empty stack trace! The optimization that" - + "inserted this variable is " - + str(reason), - "", - ) - ] - ] - if config.check_stack_trace == "warn": - warnings.warn( - "Empty stack trace! The optimization that inserted this variable is" - + str(reason) - ) - + global DEPRECATED_NAMES -class CheckStackTraceOptimization(GlobalOptimizer): - """Optimizer that serves to add `CheckStackTraceOptimization` as a feature.""" + from warnings import warn - def add_requirements(self, fgraph): - if not hasattr(fgraph, "CheckStackTraceFeature"): - fgraph.attach_feature(CheckStackTraceFeature()) + for old_name, msg, old_object in DEPRECATED_NAMES: + if name == old_name: + warn(msg, DeprecationWarning, stacklevel=2) + return old_object - def apply(self, fgraph): - pass + raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/aesara/graph/opt_utils.py b/aesara/graph/opt_utils.py index 646c5803b9..4078e86851 100644 --- a/aesara/graph/opt_utils.py +++ b/aesara/graph/opt_utils.py @@ -1,223 +1,29 @@ -import copy -from typing import Generator, Sequence, Union, cast +import warnings -import aesara -from aesara.graph.basic import ( - Apply, - Variable, - equal_computations, - graph_inputs, - vars_between, -) -from aesara.graph.fg import FunctionGraph -from aesara.graph.optdb import OptimizationQuery - - -def optimize_graph( - fgraph: Union[Variable, FunctionGraph], - include: Sequence[str] = ["canonicalize"], - custom_opt=None, - clone: bool = False, - **kwargs -) -> Union[Variable, FunctionGraph]: - """Easily optimize a graph. - - Parameters - ========== - fgraph: - A ``FunctionGraph`` or ``Variable`` to be optimized. - include: - String names of the optimizations to be applied. The default - optimization is ``"canonicalization"``. - custom_opt: - A custom ``Optimization`` to also be applied. - clone: - Whether or not to clone the input graph before optimizing. - **kwargs: - Keyword arguments passed to the ``aesara.graph.optdb.OptimizationQuery`` object. - """ - from aesara.compile import optdb - - return_only_out = False - if not isinstance(fgraph, FunctionGraph): - fgraph = FunctionGraph(outputs=[fgraph], clone=clone) - return_only_out = True - - canonicalize_opt = optdb.query(OptimizationQuery(include=include, **kwargs)) - _ = canonicalize_opt.optimize(fgraph) - - if custom_opt: - custom_opt.optimize(fgraph) - - if return_only_out: - return fgraph.outputs[0] - else: - return fgraph - - -def is_same_graph_with_merge(var1, var2, givens=None): - """ - Merge-based implementation of `aesara.graph.basic.is_same_graph`. - - See help on `aesara.graph.basic.is_same_graph` for additional documentation. - """ - from aesara.graph.opt import MergeOptimizer - - if givens is None: - givens = {} - # Copy variables since the MergeOptimizer will modify them. - copied = copy.deepcopy([var1, var2, givens]) - vars = copied[0:2] - givens = copied[2] - # Create FunctionGraph. - inputs = list(graph_inputs(vars)) - # The clone isn't needed as we did a deepcopy and we cloning will - # break the mapping in givens. - fgraph = aesara.graph.fg.FunctionGraph(inputs, vars, clone=False) - # Perform Variable substitution. - for to_replace, replace_by in givens.items(): - fgraph.replace(to_replace, replace_by) - # Perform merge optimization. - MergeOptimizer().optimize(fgraph) - # When two variables perform the same computations, they will have the same - # owner in the optimized graph. - # We need to be careful with the special case where the owner is None, - # which happens when the graph is made of a single Variable. - # We also need to make sure we replace a Variable if it is present in - # `givens`. - vars_replaced = [givens.get(v, v) for v in fgraph.outputs] - o1, o2 = [v.owner for v in vars_replaced] - if o1 is None and o2 is None: - # Comparing two single-Variable graphs: they are equal if they are - # the same Variable. - return vars_replaced[0] == vars_replaced[1] - else: - return o1 is o2 - - -def is_same_graph(var1, var2, givens=None): - """ - Return True iff Variables `var1` and `var2` perform the same computation. - - By 'performing the same computation', we mean that they must share the same - graph, so that for instance this function will return False when comparing - (x * (y * z)) with ((x * y) * z). +warnings.warn( + "The module `aesara.graph.opt_utils` is deprecated; use `aesara.graph.rewriting.utils` instead.", + DeprecationWarning, + stacklevel=2, +) - The current implementation is not efficient since, when possible, it - verifies equality by calling two different functions that are expected to - return the same output. The goal is to verify this assumption, to - eventually get rid of one of them in the future. +from aesara.graph.rewriting.utils import * # noqa: F401 E402 F403 +from aesara.graph.rewriting.utils import DEPRECATED_NAMES # noqa: F401 E402 F403 - Parameters - ---------- - var1 - The first Variable to compare. - var2 - The second Variable to compare. - givens - Similar to the `givens` argument of `aesara.function`, it can be used - to perform substitutions in the computational graph of `var1` and - `var2`. This argument is associated to neither `var1` nor `var2`: - substitutions may affect both graphs if the substituted variable - is present in both. - Examples - -------- +def __getattr__(name): + """Intercept module-level attribute access of deprecated symbols. - ====== ====== ====== ====== - var1 var2 givens output - ====== ====== ====== ====== - x + 1 x + 1 {} True - x + 1 y + 1 {} False - x + 1 y + 1 {x: y} True - ====== ====== ====== ====== + Adapted from https://stackoverflow.com/a/55139609/3006474. """ - use_equal_computations = True - - if givens is None: - givens = {} - - if not isinstance(givens, dict): - givens = dict(givens) - - # Get result from the merge-based function. - rval1 = is_same_graph_with_merge(var1=var1, var2=var2, givens=givens) - - if givens: - # We need to build the `in_xs` and `in_ys` lists. To do this, we need - # to be able to tell whether a variable belongs to the computational - # graph of `var1` or `var2`. - # The typical case we want to handle is when `to_replace` belongs to - # one of these graphs, and `replace_by` belongs to the other one. In - # other situations, the current implementation of `equal_computations` - # is probably not appropriate, so we do not call it. - ok = True - in_xs = [] - in_ys = [] - # Compute the sets of all variables found in each computational graph. - inputs_var = list(map(graph_inputs, ([var1], [var2]))) - all_vars = [ - set(vars_between(v_i, v_o)) - for v_i, v_o in ((inputs_var[0], [var1]), (inputs_var[1], [var2])) - ] - - def in_var(x, k): - # Return True iff `x` is in computation graph of variable `vark`. - return x in all_vars[k - 1] + global DEPRECATED_NAMES - for to_replace, replace_by in givens.items(): - # Map a substitution variable to the computational graphs it - # belongs to. - inside = { - v: [in_var(v, k) for k in (1, 2)] for v in (to_replace, replace_by) - } - if ( - inside[to_replace][0] - and not inside[to_replace][1] - and inside[replace_by][1] - and not inside[replace_by][0] - ): - # Substitute variable in `var1` by one from `var2`. - in_xs.append(to_replace) - in_ys.append(replace_by) - elif ( - inside[to_replace][1] - and not inside[to_replace][0] - and inside[replace_by][0] - and not inside[replace_by][1] - ): - # Substitute variable in `var2` by one from `var1`. - in_xs.append(replace_by) - in_ys.append(to_replace) - else: - ok = False - break - if not ok: - # We cannot directly use `equal_computations`. - use_equal_computations = False - else: - in_xs = None - in_ys = None - if use_equal_computations: - rval2 = equal_computations(xs=[var1], ys=[var2], in_xs=in_xs, in_ys=in_ys) - assert rval2 == rval1 - return rval1 + from warnings import warn + for old_name, msg, old_object in DEPRECATED_NAMES: + if name == old_name: + warn(msg, DeprecationWarning, stacklevel=2) + return old_object -def get_clients_at_depth( - fgraph: FunctionGraph, node: Apply, depth: int -) -> Generator[Apply, None, None]: - """Yields node clients at given depth.""" - for var in node.outputs: - if depth > 0: - for out_node, _ in fgraph.clients[var]: - if out_node == "output": - continue - yield from get_clients_at_depth( - fgraph, cast(Apply, out_node), depth - 1 - ) - else: - assert var.owner is not None - yield var.owner + raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/aesara/graph/optdb.py b/aesara/graph/optdb.py index 693e9f2003..90f448cd0a 100644 --- a/aesara/graph/optdb.py +++ b/aesara/graph/optdb.py @@ -1,552 +1,29 @@ -import copy -import math -import sys -from functools import cmp_to_key -from io import StringIO -from typing import Dict, Iterable, Optional, Sequence, Tuple, Union +import warnings -from aesara.configdefaults import config -from aesara.graph import opt as aesara_opt -from aesara.misc.ordered_set import OrderedSet -from aesara.utils import DefaultOrderedDict +warnings.warn( + "The module `aesara.graph.optdb` is deprecated; use `aesara.graph.rewriting.db` instead.", + DeprecationWarning, + stacklevel=2, +) -OptimizersType = Union[aesara_opt.GlobalOptimizer, aesara_opt.LocalOptimizer] +from aesara.graph.rewriting.db import * # noqa: F401 E402 F403 +from aesara.graph.rewriting.db import DEPRECATED_NAMES # noqa: F401 E402 F403 -class OptimizationDatabase: - """A class that represents a collection/database of optimizations. +def __getattr__(name): + """Intercept module-level attribute access of deprecated symbols. - These databases are used to logically organize collections of optimizers - (i.e. ``GlobalOptimizer``s and ``LocalOptimizer``). - """ - - def __init__(self): - self.__db__ = DefaultOrderedDict(OrderedSet) - self._names = set() - # This will be reset by `self.register` (via `obj.name` by the thing - # doing the registering) - self.name = None - - def register( - self, - name: str, - optimizer: Union["OptimizationDatabase", OptimizersType], - *tags: str, - use_db_name_as_tag=True, - **kwargs, - ): - """Register a new optimizer to the database. - - Parameters - ---------- - name: - Name of the optimizer. - opt: - The optimizer to register. - tags: - Tag name that allow to select the optimizer. - use_db_name_as_tag: - Add the database's name as a tag, so that its name can be used in a - query. - By default, all optimizations registered in ``EquilibriumDB`` are - selected when the ``"EquilibriumDB"`` name is used as a tag. We do - not want this behavior for some optimizers like - ``local_remove_all_assert``. Setting `use_db_name_as_tag` to - ``False`` removes that behavior. This mean only the optimizer name - and the tags specified will enable that optimization. - - """ - if not isinstance( - optimizer, - ( - OptimizationDatabase, - aesara_opt.GlobalOptimizer, - aesara_opt.LocalOptimizer, - ), - ): - raise TypeError(f"{optimizer} is not a valid optimizer type.") - - if name in self.__db__: - raise ValueError(f"The tag '{name}' is already present in the database.") - - if use_db_name_as_tag: - if self.name is not None: - tags = tags + (self.name,) - - optimizer.name = name - # This restriction is there because in many place we suppose that - # something in the OptimizationDatabase is there only once. - if optimizer.name in self.__db__: - raise ValueError( - f"Tried to register {optimizer.name} again under the new name {name}. " - "The same optimization cannot be registered multiple times in" - " an ``OptimizationDatabase``; use ProxyDB instead." - ) - self.__db__[name] = OrderedSet([optimizer]) - self._names.add(name) - self.__db__[optimizer.__class__.__name__].add(optimizer) - self.add_tags(name, *tags) - - def add_tags(self, name, *tags): - obj = self.__db__[name] - assert len(obj) == 1 - obj = obj.copy().pop() - for tag in tags: - if tag in self._names: - raise ValueError( - f"The tag '{tag}' for the {obj} collides with an existing name." - ) - self.__db__[tag].add(obj) - - def remove_tags(self, name, *tags): - obj = self.__db__[name] - assert len(obj) == 1 - obj = obj.copy().pop() - for tag in tags: - if tag in self._names: - raise ValueError( - f"The tag '{tag}' for the {obj} collides with an existing name." - ) - self.__db__[tag].remove(obj) - - def __query__(self, q): - # The ordered set is needed for deterministic optimization. - variables = OrderedSet() - for tag in q.include: - variables.update(self.__db__[tag]) - for tag in q.require: - variables.intersection_update(self.__db__[tag]) - for tag in q.exclude: - variables.difference_update(self.__db__[tag]) - remove = OrderedSet() - add = OrderedSet() - for obj in variables: - if isinstance(obj, OptimizationDatabase): - def_sub_query = q - if q.extra_optimizations: - def_sub_query = copy.copy(q) - def_sub_query.extra_optimizations = [] - sq = q.subquery.get(obj.name, def_sub_query) - - replacement = obj.query(sq) - replacement.name = obj.name - remove.add(obj) - add.add(replacement) - variables.difference_update(remove) - variables.update(add) - return variables - - def query(self, *tags, **kwtags): - if len(tags) >= 1 and isinstance(tags[0], OptimizationQuery): - if len(tags) > 1 or kwtags: - raise TypeError( - "If the first argument to query is an `OptimizationQuery`," - " there should be no other arguments." - ) - return self.__query__(tags[0]) - include = [tag[1:] for tag in tags if tag.startswith("+")] - require = [tag[1:] for tag in tags if tag.startswith("&")] - exclude = [tag[1:] for tag in tags if tag.startswith("-")] - if len(include) + len(require) + len(exclude) < len(tags): - raise ValueError( - "All tags must start with one of the following" - " characters: '+', '&' or '-'" - ) - return self.__query__( - OptimizationQuery( - include=include, require=require, exclude=exclude, subquery=kwtags - ) - ) - - def __getitem__(self, name): - variables = self.__db__[name] - if not variables: - raise KeyError(f"Nothing registered for '{name}'") - elif len(variables) > 1: - raise ValueError(f"More than one match for {name} (please use query)") - for variable in variables: - return variable - - def __contains__(self, name): - return name in self.__db__ - - def print_summary(self, stream=sys.stdout): - print(f"{self.__class__.__name__} (id {id(self)})", file=stream) - print(" names", self._names, file=stream) - print(" db", self.__db__, file=stream) - - -# This is deprecated and will be removed. -DB = OptimizationDatabase - - -class OptimizationQuery: - """An object that specifies a set of optimizations by tag/name.""" - - def __init__( - self, - include: Iterable[str], - require: Optional[Union[OrderedSet, Sequence[str]]] = None, - exclude: Optional[Union[OrderedSet, Sequence[str]]] = None, - subquery: Optional[Dict[str, "OptimizationQuery"]] = None, - position_cutoff: float = math.inf, - extra_optimizations: Optional[ - Sequence[ - Tuple[Union["OptimizationQuery", OptimizersType], Union[int, float]] - ] - ] = None, - ): - """ - - Parameters - ========== - include: - A set of tags such that every optimization obtained through this - `OptimizationQuery` must have **one** of the tags listed. This - field is required and basically acts as a starting point for the - search. - require: - A set of tags such that every optimization obtained through this - `OptimizationQuery` must have **all** of these tags. - exclude: - A set of tags such that every optimization obtained through this - ``OptimizationQuery` must have **none** of these tags. - subquery: - A dictionary mapping the name of a sub-database to a special - `OptimizationQuery`. If no subquery is given for a sub-database, - the original `OptimizationQuery` will be used again. - position_cutoff: - Only optimizations with position less than the cutoff are returned. - extra_optimizations: - Extra optimizations to be added. - - """ - self.include = OrderedSet(include) - self.require = OrderedSet(require) if require else OrderedSet() - self.exclude = OrderedSet(exclude) if exclude else OrderedSet() - self.subquery = subquery or {} - self.position_cutoff = position_cutoff - self.name: Optional[str] = None - if extra_optimizations is None: - extra_optimizations = [] - self.extra_optimizations = list(extra_optimizations) - - def __str__(self): - return ( - "OptimizationQuery(" - + f"inc={self.include},ex={self.exclude}," - + f"require={self.require},subquery={self.subquery}," - + f"position_cutoff={self.position_cutoff}," - + f"extra_opts={self.extra_optimizations})" - ) - - def __setstate__(self, state): - self.__dict__.update(state) - if not hasattr(self, "extra_optimizations"): - self.extra_optimizations = [] - - def including(self, *tags: str) -> "OptimizationQuery": - """Add rewrites with the given tags.""" - return OptimizationQuery( - self.include.union(tags), - self.require, - self.exclude, - self.subquery, - self.position_cutoff, - self.extra_optimizations, - ) - - def excluding(self, *tags: str) -> "OptimizationQuery": - """Remove rewrites with the given tags.""" - return OptimizationQuery( - self.include, - self.require, - self.exclude.union(tags), - self.subquery, - self.position_cutoff, - self.extra_optimizations, - ) - - def requiring(self, *tags: str) -> "OptimizationQuery": - """Filter for rewrites with the given tags.""" - return OptimizationQuery( - self.include, - self.require.union(tags), - self.exclude, - self.subquery, - self.position_cutoff, - self.extra_optimizations, - ) - - def register( - self, *optimizations: Tuple["OptimizationQuery", Union[int, float]] - ) -> "OptimizationQuery": - """Include the given optimizations.""" - return OptimizationQuery( - self.include, - self.require, - self.exclude, - self.subquery, - self.position_cutoff, - self.extra_optimizations + list(optimizations), - ) - - -# This is deprecated and will be removed. -Query = OptimizationQuery - - -class EquilibriumDB(OptimizationDatabase): - """ - A set of potential optimizations which should be applied in an arbitrary - order until equilibrium is reached. - - Canonicalize, Stabilize, and Specialize are all equilibrium optimizations. - - Parameters - ---------- - ignore_newtrees - If False, we will apply local opt on new node introduced during local - optimization application. This could result in less fgraph iterations, - but this doesn't mean it will be faster globally. - - tracks_on_change_inputs - If True, we will re-apply local opt on nodes whose inputs - changed during local optimization application. This could - result in less fgraph iterations, but this doesn't mean it - will be faster globally. - - Notes - ----- - We can use `LocalOptimizer` and `GlobalOptimizer` since `EquilibriumOptimizer` - supports both. - - It is probably not a good idea to have ignore_newtrees=False and - tracks_on_change_inputs=True - - """ - - def __init__(self, ignore_newtrees=True, tracks_on_change_inputs=False): - """ - Parameters - ========== - ignore_newtrees: - If False, we will apply local opt on new node introduced during local - optimization application. This could result in less fgraph iterations, - but this doesn't mean it will be faster globally. - - tracks_on_change_inputs: - If True, we will re-apply local opt on nodes whose inputs - changed during local optimization application. This could - result in less fgraph iterations, but this doesn't mean it - will be faster globally. - """ - super().__init__() - self.ignore_newtrees = ignore_newtrees - self.tracks_on_change_inputs = tracks_on_change_inputs - self.__final__ = {} - self.__cleanup__ = {} - - def register(self, name, obj, *tags, final_opt=False, cleanup=False, **kwargs): - if final_opt and cleanup: - raise ValueError("`final_opt` and `cleanup` cannot both be true.") - super().register(name, obj, *tags, **kwargs) - self.__final__[name] = final_opt - self.__cleanup__[name] = cleanup - - def query(self, *tags, **kwtags): - _opts = super().query(*tags, **kwtags) - final_opts = [o for o in _opts if self.__final__.get(o.name, False)] - cleanup_opts = [o for o in _opts if self.__cleanup__.get(o.name, False)] - opts = [o for o in _opts if o not in final_opts and o not in cleanup_opts] - if len(final_opts) == 0: - final_opts = None - if len(cleanup_opts) == 0: - cleanup_opts = None - return aesara_opt.EquilibriumOptimizer( - opts, - max_use_ratio=config.optdb__max_use_ratio, - ignore_newtrees=self.ignore_newtrees, - tracks_on_change_inputs=self.tracks_on_change_inputs, - failure_callback=aesara_opt.NavigatorOptimizer.warn_inplace, - final_optimizers=final_opts, - cleanup_optimizers=cleanup_opts, - ) - - -class SequenceDB(OptimizationDatabase): - """A sequence of potential optimizations. - - Retrieve a sequence of optimizations (a SeqOptimizer) by calling query(). - - Each potential optimization is registered with a floating-point position. - No matter which optimizations are selected by a query, they are carried - out in order of increasing position. - - The optdb itself (`aesara.compile.mode.optdb`), from which (among many - other tags) fast_run and fast_compile optimizers are drawn is a SequenceDB. - - """ - - seq_opt = aesara_opt.SeqOptimizer - - def __init__(self, failure_callback=aesara_opt.SeqOptimizer.warn): - super().__init__() - self.__position__ = {} - self.failure_callback = failure_callback - - def register(self, name, obj, *tags, **kwargs): - super().register(name, obj, *tags, **kwargs) - position = kwargs.pop("position", "last") - if position == "last": - if len(self.__position__) == 0: - self.__position__[name] = 0 - else: - self.__position__[name] = max(self.__position__.values()) + 1 - elif isinstance(position, (int, float)): - self.__position__[name] = position - else: - raise TypeError(f"`position` must be numeric; got {position}") - - def query( - self, *tags, position_cutoff: Optional[Union[int, float]] = None, **kwtags - ): - """ - - Parameters - ---------- - position_cutoff : float or int - Only optimizations with position less than the cutoff are returned. - - """ - opts = super().query(*tags, **kwtags) - - if position_cutoff is None: - position_cutoff = config.optdb__position_cutoff - - position_dict = self.__position__ - - if len(tags) >= 1 and isinstance(tags[0], OptimizationQuery): - # the call to super should have raise an error with a good message - assert len(tags) == 1 - if getattr(tags[0], "position_cutoff", None): - position_cutoff = tags[0].position_cutoff - - # The OptimizationQuery instance might contain extra optimizations which need - # to be added the the sequence of optimizations (don't alter the - # original dictionary) - if len(tags[0].extra_optimizations) > 0: - position_dict = position_dict.copy() - for extra_opt in tags[0].extra_optimizations: - # Give a name to the extra optimization (include both the - # class name for descriptiveness and id to avoid name - # collisions) - opt, position = extra_opt - opt.name = f"{opt.__class__}_{id(opt)}" - - # Add the extra optimization to the optimization sequence - if position < position_cutoff: - opts.add(opt) - position_dict[opt.name] = position - - opts = [o for o in opts if position_dict[o.name] < position_cutoff] - opts.sort(key=lambda obj: (position_dict[obj.name], obj.name)) - - if self.failure_callback: - ret = self.seq_opt(opts, failure_callback=self.failure_callback) - else: - ret = self.seq_opt(opts) - - if hasattr(tags[0], "name"): - ret.name = tags[0].name - return ret - - def print_summary(self, stream=sys.stdout): - print(f"{self.__class__.__name__ } (id {id(self)})", file=stream) - positions = list(self.__position__.items()) - - def c(a, b): - return (a[1] > b[1]) - (a[1] < b[1]) - - positions.sort(key=cmp_to_key(c)) - - print("\tposition", positions, file=stream) - print("\tnames", self._names, file=stream) - print("\tdb", self.__db__, file=stream) - - def __str__(self): - sio = StringIO() - self.print_summary(sio) - return sio.getvalue() - - -class LocalGroupDB(SequenceDB): - """ - Generate a local optimizer of type LocalOptGroup instead - of a global optimizer. - - It supports the tracks, to only get applied to some Op. - - """ - - def __init__( - self, - apply_all_opts: bool = False, - profile: bool = False, - local_opt=aesara_opt.LocalOptGroup, - ): - super().__init__(failure_callback=None) - self.apply_all_opts = apply_all_opts - self.profile = profile - self.local_opt = local_opt - self.__name__: str = "" - - def register(self, name, obj, *tags, position="last", **kwargs): - super().register(name, obj, *tags, position=position, **kwargs) - - def query(self, *tags, **kwtags): - opts = list(super().query(*tags, **kwtags)) - ret = self.local_opt( - *opts, apply_all_opts=self.apply_all_opts, profile=self.profile - ) - return ret - - -class TopoDB(OptimizationDatabase): - """Generate a `GlobalOptimizer` of type TopoOptimizer.""" - - def __init__( - self, db, order="in_to_out", ignore_newtrees=False, failure_callback=None - ): - super().__init__() - self.db = db - self.order = order - self.ignore_newtrees = ignore_newtrees - self.failure_callback = failure_callback - - def query(self, *tags, **kwtags): - return aesara_opt.TopoOptimizer( - self.db.query(*tags, **kwtags), - self.order, - self.ignore_newtrees, - self.failure_callback, - ) - - -class ProxyDB(OptimizationDatabase): - """A object that wraps an existing ``OptimizationDatabase``. - - This is needed because we can't register the same ``OptimizationDatabase`` - multiple times in different positions in a ``SequentialDB``. + Adapted from https://stackoverflow.com/a/55139609/3006474. """ + global DEPRECATED_NAMES - def __init__(self, db): - if not isinstance(db, OptimizationDatabase): - raise TypeError("`db` must be an `OptimizationDatabase`.") + from warnings import warn - self.db = db + for old_name, msg, old_object in DEPRECATED_NAMES: + if name == old_name: + warn(msg, DeprecationWarning, stacklevel=2) + return old_object - def query(self, *tags, **kwtags): - return self.db.query(*tags, **kwtags) + raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/aesara/graph/rewriting/__init__.py b/aesara/graph/rewriting/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/aesara/graph/rewriting/basic.py b/aesara/graph/rewriting/basic.py new file mode 100644 index 0000000000..586dca33d3 --- /dev/null +++ b/aesara/graph/rewriting/basic.py @@ -0,0 +1,3264 @@ +"""This module defines the base classes for graph rewriting.""" +import abc +import copy +import functools +import inspect +import logging +import pdb +import sys +import time +import traceback +import warnings +from collections import UserList, defaultdict, deque +from collections.abc import Iterable +from functools import _compose_mro, partial, reduce # type: ignore +from itertools import chain +from typing import TYPE_CHECKING, Callable, Dict +from typing import Iterable as IterableType +from typing import List, Optional, Sequence, Tuple, Union, cast + +from typing_extensions import Literal + +import aesara +from aesara.configdefaults import config +from aesara.graph import destroyhandler as dh +from aesara.graph.basic import ( + Apply, + AtomicVariable, + Constant, + Variable, + applys_between, + io_toposort, + vars_between, +) +from aesara.graph.features import AlreadyThere, Feature, NodeFinder +from aesara.graph.fg import FunctionGraph +from aesara.graph.op import Op +from aesara.graph.utils import AssocList, InconsistencyError +from aesara.misc.ordered_set import OrderedSet +from aesara.utils import flatten + + +if TYPE_CHECKING: + from aesara.graph.rewriting.unify import Var + + +_logger = logging.getLogger("aesara.graph.rewriting.basic") + +RemoveKeyType = Literal["remove"] +TransformOutputType = Union[ + bool, + Sequence[Variable], + Dict[Union[Variable, Literal["remove"]], Union[Variable, Sequence[Variable]]], +] +FailureCallbackType = Callable[ + [ + Exception, + "NodeProcessingGraphRewriter", + List[Tuple[Variable, None]], + "NodeRewriter", + Apply, + ], + None, +] + + +class MetaNodeRewriterSkip(AssertionError): + """This is an `AssertionError`, but instead of having the + `MetaNodeRewriter` print the error, it just skip that + compilation. + + """ + + +class Rewriter(abc.ABC): + """Abstract base class for graph/term rewriters.""" + + name: Optional[str] = None + + @abc.abstractmethod + def add_requirements(self, fgraph: FunctionGraph): + r"""Add `Feature`\s and other requirements to a `FunctionGraph`.""" + + @abc.abstractmethod + def print_summary(self, stream=sys.stdout, level=0, depth=-1): + """Print a single-line, indented representation of the rewriter.""" + + def __eq__(self, other): + return self is other + + def __hash__(self): + return id(self) + + +class GraphRewriter(Rewriter): + """A rewriter that can be applied to a `FunctionGraph` in order to transform it. + + This class represents a generalized rewrite that includes the way a graph + is traversed and/or changed as a whole. + + """ + + @abc.abstractmethod + def apply(self, fgraph): + """Apply the rewriter to a `FunctionGraph`. + + It may use all the methods defined by the `FunctionGraph`. If the + `GraphRewriter` needs to use a certain tool, such as an + `InstanceFinder`, it can do so in its `add_requirements` method. + + """ + raise NotImplementedError() + + def optimize(self, *args, **kwargs): + warnings.warn( + "`GraphRewriter.optimize` is deprecated; use `GraphRewriter.rewrite` instead.", + DeprecationWarning, + stacklevel=2, + ) + self.rewrite(*args, **kwargs) + + def rewrite(self, fgraph, *args, **kwargs): + """ + + This is meant as a shortcut for the following:: + + self.add_requirements(fgraph) + self.apply(fgraph) + + """ + self.add_requirements(fgraph) + return self.apply(fgraph, *args, **kwargs) + + def __call__(self, fgraph): + """Rewrite a `FunctionGraph`.""" + return self.rewrite(fgraph) + + def add_requirements(self, fgraph): + ... + + def print_summary(self, stream=sys.stdout, level=0, depth=-1): + name = getattr(self, "name", None) + print( + f"{' ' * level}{self.__class__.__name__} {name} id={id(self)}", + file=stream, + ) + + @classmethod + def print_profile(cls, stream, prof, level=0): + if prof is not None: + raise NotImplementedError( + "The function `print_profile` must be overridden when the" + " rewriter returns profiling information." + ) + + +class NodeRewriter(Rewriter): + """A `Rewriter` that is applied to an `Apply` node.""" + + def tracks(self) -> Optional[Sequence[Op]]: + """Return the list of `Op` classes to which this rewrite applies. + + Returns ``None`` when the rewrite applies to all nodes. + + """ + return None + + @abc.abstractmethod + def transform( + self, fgraph: FunctionGraph, node: Apply, *args, **kwargs + ) -> TransformOutputType: + r"""Rewrite the sub-graph given by `node`. + + Subclasses should implement this function so that it returns one of the + following: + + - ``False`` to indicate that this rewrite cannot be applied to `node` + - A list of `Variable`\s to use in place of the `node`'s current outputs + - A ``dict`` mapping old `Variable`\s to `Variable`\s, or the key + ``"remove"`` mapping to a list of `Variable`\s to be removed. + + Parameters + ---------- + fgraph + A `FunctionGraph` containing `node`. + node + An `Apply` node to be rewritten. + + """ + + raise NotImplementedError() + + def add_requirements(self, fgraph: FunctionGraph): + r"""Add required `Feature`\s to `fgraph`.""" + + def print_summary(self, stream=sys.stdout, level=0, depth=-1): + print(f"{' ' * level}{self.__class__.__name__} id={id(self)}", file=stream) + + +class FromFunctionGraphRewriter(GraphRewriter): + """A `GraphRewriter` constructed from a given function.""" + + def __init__(self, fn, requirements=()): + self.fn = fn + self.requirements = requirements + + def apply(self, *args, **kwargs): + return self.fn(*args, **kwargs) + + def add_requirements(self, fgraph): + for req in self.requirements: + req(fgraph) + + def print_summary(self, stream=sys.stdout, level=0, depth=-1): + print(f"{' ' * level}{self.apply} id={id(self)}", file=stream) + + def __call__(self, *args, **kwargs): + return self.fn(*args, **kwargs) + + def __str__(self): + return self.__name__ + + +def graph_rewriter(f): + """Decorator for `FromFunctionGraphRewriter`.""" + rval = FromFunctionGraphRewriter(f) + rval.__name__ = f.__name__ + return rval + + +def inplace_graph_rewriter(f): + """Decorator for `FromFunctionGraphRewriter` that also adds the `DestroyHandler` features.""" + dh_handler = dh.DestroyHandler + requirements = (lambda fgraph: fgraph.attach_feature(dh_handler()),) + rval = FromFunctionGraphRewriter(f, requirements) + rval.__name__ = f.__name__ + return rval + + +class SequentialGraphRewriter(GraphRewriter, UserList): + """A `GraphRewriter` that applies a list of rewriters sequentially.""" + + @classmethod + def warn(cls, exc, self, rewriter): + """Default failure callback function for `SequentialGraphRewriter`.""" + _logger.error(f"{cls.__name__} apply {rewriter}") + _logger.error("Traceback:") + _logger.error(traceback.format_exc()) + if config.on_opt_error == "raise": + raise exc + elif config.on_opt_error == "pdb": + pdb.post_mortem(sys.exc_info()[2]) + + def __init__(self, *rewrites, failure_callback=None): + """ + Parameters + ---------- + *rewrites + The List of rewriters to be applied to a node + failure_callback + A callback used when a failure happens during rewriting. + + """ + if len(rewrites) == 1 and isinstance(rewrites[0], (list, tuple)): + rewrites = rewrites[0] + + super().__init__(rewrites) + + self.failure_callback = failure_callback + + def apply(self, fgraph): + """Applies each `GraphRewriter` in ``self.data`` to `fgraph`.""" + l = [] + if fgraph.profile: + validate_before = fgraph.profile.validate_time + sub_validate_time = [validate_before] + callbacks_before = fgraph.execute_callbacks_times.copy() + else: + sub_validate_time = [] + callbacks_before = [] + callback_before = fgraph.execute_callbacks_time + nb_node_before = len(fgraph.apply_nodes) + sub_profs = [] + nb_nodes = [] + + self.pre_profile = ( + self, + l, + -1, + -1, + nb_node_before, + -1, + sub_profs, + sub_validate_time, + nb_nodes, + {}, + ) + try: + for rewriter in self.data: + try: + nb_nodes_before = len(fgraph.apply_nodes) + t0 = time.time() + sub_prof = rewriter.apply(fgraph) + l.append(float(time.time() - t0)) + sub_profs.append(sub_prof) + nb_nodes.append((nb_nodes_before, len(fgraph.apply_nodes))) + if fgraph.profile: + sub_validate_time.append(fgraph.profile.validate_time) + except AssertionError: + # do not catch Assertion failures + raise + except Exception as e: + if self.failure_callback: + self.failure_callback(e, self, rewriter) + continue + else: + raise + finally: + + if fgraph.profile: + validate_time = fgraph.profile.validate_time - validate_before + callbacks_time = {} + for k, v in fgraph.execute_callbacks_times.items(): + if k in callbacks_before: + t = v - callbacks_before[k] + if t > 0: + callbacks_time[k] = t + else: + callbacks_time[k] = v + else: + validate_time = None + callbacks_time = {} + callback_time = fgraph.execute_callbacks_time - callback_before + self.pre_profile = ( + self, + l, + validate_time, + callback_time, + nb_node_before, + len(fgraph.apply_nodes), + sub_profs, + sub_validate_time, + nb_nodes, + callbacks_time, + ) + return self.pre_profile + + def __repr__(self): + return f"{type(self).__name__}({self.data})" + + def add_requirements(self, fgraph): + for rewrite in self.data: + rewrite.add_requirements(fgraph) + + def print_summary(self, stream=sys.stdout, level=0, depth=-1): + name = getattr(self, "name", None) + print( + f"{' ' * level}{self.__class__.__name__} {name} id={id(self)}", file=stream + ) + # This way, -1 will do all depth + if depth != 0: + depth -= 1 + for rewrite in self.data: + rewrite.print_summary(stream, level=(level + 2), depth=depth) + + @classmethod + def print_profile(cls, stream, prof, level=0): + ( + rewrites, + prof, + validate_time, + callback_time, + nb_node_before, + nb_node_after, + sub_profs, + sub_validate_time, + nb_nodes, + callbacks_time, + ) = prof + + validate_time = validate_time or float("nan") + callback_time = callback_time or float("nan") + + blanc = " " * level + + print(blanc, cls.__name__, end=" ", file=stream) + if hasattr(rewrites, "name"): + print(blanc, rewrites.name, end=" ", file=stream) + elif hasattr(rewrites, "__name__"): + print(blanc, rewrites.__name__, end=" ", file=stream) + print( + ( + f" time {sum(prof):.3f}s for {int(nb_node_before)}/{int(nb_node_after)} nodes" + " before/after rewriting" + ), + file=stream, + ) + print(blanc, f" {callback_time:.3f}s for callback", file=stream) + print(blanc, f" {validate_time:.3f}s for fgraph.validate()", file=stream) + if callback_time > 1: + print(blanc, " callbacks_time", file=stream) + for i in sorted(callbacks_time.items(), key=lambda a: -a[1]): + if i[1] > 0: + # We want to have the __str__ called, so we can't + # just print i. + print(blanc, " ", i[0], ",", i[1], file=stream) + + if level == 0: + print( + blanc, + " time - (name, class, index, nodes before, nodes after) - validate time", + file=stream, + ) + ll = [] + for (rewrite, nb_n) in zip(rewrites, nb_nodes): + if hasattr(rewrite, "__name__"): + name = rewrite.__name__ + else: + name = rewrite.name + idx = rewrites.index(rewrite) + ll.append((name, rewrite.__class__.__name__, idx) + nb_n) + lll = sorted(zip(prof, ll), key=lambda a: a[0]) + + for (t, rewrite) in lll[::-1]: + i = rewrite[2] + if sub_validate_time: + val_time = sub_validate_time[i + 1] - sub_validate_time[i] + print( + blanc, + f" {t:.6f}s - {rewrite} - {val_time:.3f}s", + file=stream, + ) + else: + print(blanc, f" {t:.6f}s - {rewrite}", file=stream) + + if sub_profs[i]: + rewrites[i].print_profile(stream, sub_profs[i], level=level + 1) + print(file=stream) + + @staticmethod + def merge_profile(prof1, prof2): + """Merge two profiles.""" + new_t = [] # the times for the rewrites + new_l = [] # the rewrites + new_sub_profile = [] + # Merge common (i.e. same object) rewrites + for l in set(prof1[0]).intersection(set(prof2[0])): + idx1 = prof1[0].index(l) + idx2 = prof2[0].index(l) + new_t.append(prof1[1][idx1] + prof2[1][idx2]) + new_l.append(l) + if hasattr(l, "merge_profile"): + assert len(prof1[6][idx1]) == len(prof2[6][idx2]) + new_sub_profile.append(l.merge_profile(prof1[6][idx1], prof2[6][idx2])) + else: + new_sub_profile.append(None) + + from io import StringIO + + for l in set(prof1[0]).symmetric_difference(set(prof2[0])): + # The set trick above only works for the same rewrite objects; it + # doesn't work for equivalent rewrites, so we try to merge + # equivalent rewrites here. + new_l_names = [o.name for o in new_l] + if l.name in new_l_names: + idx = new_l_names.index(l.name) + io1 = StringIO() + io2 = StringIO() + l.print_summary(io1) + new_l[idx].print_summary(io2) + if io1.read() == io2.read(): + if l in prof1[0]: + p = prof1 + else: + p = prof2 + new_t[idx] += p[1][p[0].index(l)] + if hasattr(l, "merge_profile"): + assert len(p[6][p[0].index(l)]) == len(new_sub_profile[idx]) + new_sub_profile[idx] = l.merge_profile( + new_sub_profile[idx], p[6][p[0].index(l)] + ) + else: + new_sub_profile[idx] = None + continue + if l in prof1[0]: + p = prof1 + else: + p = prof2 + new_t.append(p[1][p[0].index(l)]) + idx = p[0].index(l) + new_l.append(l) + new_sub_profile.append(p[6][idx]) + + new_rewrite = SequentialGraphRewriter(*new_l) + new_nb_nodes = [] + for p1, p2 in zip(prof1[8], prof2[8]): + new_nb_nodes.append((p1[0] + p2[0], p1[1] + p2[1])) + new_nb_nodes.extend(prof1[8][len(new_nb_nodes) :]) + new_nb_nodes.extend(prof2[8][len(new_nb_nodes) :]) + + new_callbacks_times = merge_dict(prof1[9], prof2[9]) + # We need to assert based on the name as we merge also based on + # the name. + assert {l.name for l in prof1[0]}.issubset({l.name for l in new_l}) + assert {l.name for l in prof2[0]}.issubset({l.name for l in new_l}) + assert len(new_t) == len(new_rewrite) == len(new_sub_profile) + return ( + new_rewrite, + new_t, + prof1[2] + prof2[2], + prof1[3] + prof2[3], + -1, + -1, + new_sub_profile, + [], + new_nb_nodes, + new_callbacks_times, + ) + + +class MergeFeature(Feature): + """Keeps track of variables in a `FunctionGraph` that cannot be merged together. + + That way, the `MergeOptimizer` can remember the result of the last + merge-pass on the `FunctionGraph`. + + """ + + def on_attach(self, fgraph): + if hasattr(fgraph, "merge_feature"): + raise AlreadyThere() + + fgraph.merge_feature = self + + self.seen_atomics = set() + self.atomic_sig = AssocList() + self.atomic_sig_inv = AssocList() + + # For all Apply nodes + # Set of distinct (not mergeable) nodes + self.nodes_seen = set() + # Ordered set of distinct (not mergeable) nodes without any input + self.noinput_nodes = OrderedSet() + + # Each element of scheduled is a list of list of (out, new_out) pairs. + # Each list of pairs represent the substitution needed to replace all + # the outputs of a node with the outputs of a replacement candidate. + # Each node can have several candidates. For instance, if "node" has + # 2 outputs, and there are 3 replacement candidates, we will have: + # shelf.scheduled = [ + # [[(node.out1, cand1.out1), (node.out2, cand1.out2)], + # [(node.out1, cand2.out1), (node.out2, cand2.out2)], + # [(node.out1, cand3.out1), (node.out2, cand3.out2)]]] + self.scheduled = [] + + # List of (node, candidate) pairs, where we tried to replace node by + # candidate, but it failed. This is used to avoid infinite loops + # during the replacement phase. + self.blacklist = [] + + for node in fgraph.toposort(): + self.on_import(fgraph, node, "on_attach") + + def clone(self): + return type(self)() + + def on_change_input(self, fgraph, node, i, r, new_r, reason): + if node in self.nodes_seen: + # If inputs to a node change, it's not guaranteed that the node is + # distinct from the other nodes in `self.nodes_seen`. + self.nodes_seen.discard(node) + self.process_node(fgraph, node) + + if isinstance(new_r, AtomicVariable): + self.process_atomic(fgraph, new_r) + + def on_import(self, fgraph, node, reason): + for c in node.inputs: + if isinstance(c, AtomicVariable): + self.process_atomic(fgraph, c) + + self.process_node(fgraph, node) + + def on_prune(self, fgraph, node, reason): + self.nodes_seen.discard(node) + if not node.inputs: + self.noinput_nodes.discard(node) + for c in node.inputs: + if isinstance(c, AtomicVariable) and len(fgraph.clients[c]) <= 1: + # This was the last node using this constant + sig = self.atomic_sig[c] + self.atomic_sig.discard(c) + self.atomic_sig_inv.discard(sig) + self.seen_atomics.discard(id(c)) + + def process_atomic(self, fgraph, c): + """Check if an atomic `c` can be merged, and queue that replacement.""" + if id(c) in self.seen_atomics: + return + sig = c.merge_signature() + other_c = self.atomic_sig_inv.get(sig, None) + if other_c is not None: + # multiple names will clobber each other.. + # we adopt convention to keep the last name + if c.name: + other_c.name = c.name + self.scheduled.append([[(c, other_c, "merge")]]) + else: + # this is a new constant + self.atomic_sig[c] = sig + self.atomic_sig_inv[sig] = c + self.seen_atomics.add(id(c)) + + def process_node(self, fgraph, node): + r"""Check if a `node` can be merged, and queue that replacement. + + When `node` is changed we check for other nodes (via the clients map) + that depend on the same inputs. If any of those other nodes have the + same inputs and `Op` as `node`, they are queued to be merged. + + """ + + if node in self.nodes_seen: + return + + if node.inputs: + # We use the smallest clients list. Some `Op`s like `Elemwise` + # have rewrites that put constants as the first inputs. Since + # constants generally have more clients than other types of nodes, + # using `node.inputs[0]` will make us look at more nodes on + # average, so by picking the smallest clients list, we might speed + # things up? + + clients = sorted( + (fgraph.clients[inp] for inp in node.inputs), key=lambda x: len(x) + )[0] + assert len(clients) > 0 + + merge_candidates = [c for c, i in clients if c in self.nodes_seen] + else: + # If two nodes have no input, but perform the same operation, + # they are not always constant-folded, so we want to merge them. + # In that case, the candidates are all the nodes without inputs. + merge_candidates = self.noinput_nodes + + replacement_candidates = [] + for candidate in merge_candidates: + + if candidate is node: + continue + if len(node.inputs) != len(candidate.inputs): + continue + + inputs_match = all( + node_in is cand_in + for node_in, cand_in in zip(node.inputs, candidate.inputs) + ) + + if inputs_match and node.op == candidate.op: + if (node, candidate) in self.blacklist: + # They were already tried, and there was an error + continue + + # Schedule transfer of clients from node to candidate + pairs = list( + zip( + node.outputs, + candidate.outputs, + ["merge"] * len(node.outputs), + ) + ) + + replacement_candidates.append(pairs) + + if replacement_candidates: + self.scheduled.append(replacement_candidates) + else: + self.nodes_seen.add(node) + if not node.inputs: + self.noinput_nodes.add(node) + + +class MergeOptimizer(GraphRewriter): + r"""Merges parts of the graph that are identical and redundant. + + The basic principle is that if two `Apply`\s have `Op`\s that compare equal, and + identical inputs, then they do not both need to be computed. The clients of + one are transferred to the other and one of them is removed from the graph. + This procedure is carried out in input-to-output order throughout the graph. + + The first step of merging is atomic variable-merging, so that all clients of a + :class:`Constant` like ``int(1)``, are transferred to just one particular + instance of ``int(1)``. :class:`NominalVariable`\s are not merged individually + like this; only the nodes that use them are. + + """ + + def add_requirements(self, fgraph): + if not hasattr(fgraph, "merge_feature"): + fgraph.attach_feature(MergeFeature()) + + def apply(self, fgraph): + sched = fgraph.merge_feature.scheduled + nb_fail = 0 + t0 = time.time() + if fgraph.profile: + validate_before = fgraph.profile.validate_time + callback_before = fgraph.execute_callbacks_time + callbacks_before = fgraph.execute_callbacks_times.copy() + + nb_merged = 0 + nb_atomic = 0 + while sched: + pairs_list = sched.pop() + success = True + for pairs_ in pairs_list: + # We must check again the equivalence, as the graph could've + # changed. If so, doing the replacement can introduce a node + # that depends on itself. Doing the full check of such cycles + # every time is very time consuming. I think this double check + # is faster than doing the full cycle check. The full cycle + # check is skipped by `Validator.validate` if the graph doesn't + # contain destroyers. + var, candidate_var, merge_mode = pairs_[0] + if merge_mode == "new_node" and var in fgraph.variables: + pass + elif ( + var not in fgraph.variables or candidate_var not in fgraph.variables + ): + continue + + # Keep len(item) == 2 for item in pairs + pairs = [pair[:2] for pair in pairs_] + + if var.owner and candidate_var.owner: + if merge_mode == "new_node": + inputs_match = True + else: + inputs_match = all( + node_in is cand_in + for node_in, cand_in in zip( + var.owner.inputs, candidate_var.owner.inputs + ) + ) + + # No need to compare the op again, as it don't change. + if not inputs_match: + continue + + if hasattr(fgraph, "destroy_handler"): + # If both nodes have clients that destroy them, we + # can't merge them. + clients = ( + fgraph.clients[pairs[0][0]] + fgraph.clients[pairs[0][1]] + ) + if any( + i in flatten(c.op.destroy_map.values()) + for c, i in clients + if c != "output" and c.op.destroy_map + ): + continue + + if len(pairs) == 1 and pairs[0][0].type != pairs[0][1].type: + res = pairs[0][0].type.convert_variable(pairs[0][1]) + + # Since the fgraph.replace only checks the convert_variable + # in one way, we change the order in the case that + # convert_variable will not be successful. + if not res: + pairs = [(pairs[0][1], pairs[0][0])] + + try: + # If they're all `AtomicVariable`s, there's no need to call validate. + if all(isinstance(old, AtomicVariable) for old, _ in pairs): + fgraph.replace_all(pairs, reason="MergeOptimizer") + else: + fgraph.replace_all_validate(pairs, reason="MergeOptimizer") + except InconsistencyError: + success = False + nb_fail += 1 + fgraph.merge_feature.blacklist.append( + (pairs[0][0].owner, pairs[0][1].owner) + ) + + if success: + nb_merged += len(pairs) + if isinstance(pairs[0][0], AtomicVariable): + nb_atomic += 1 + break + + if fgraph.profile: + validate_time = fgraph.profile.validate_time - validate_before + callback_time = fgraph.execute_callbacks_time - callback_before + callbacks_time = {} + for k, v in fgraph.execute_callbacks_times.items(): + if k in callbacks_before: + t = v - callbacks_before[k] + if t > 0: + callbacks_time[k] = t + else: + callbacks_time[k] = v + else: + validate_time = None + callback_time = None + callbacks_time = {} + + fgraph.merge_feature.blacklist = [] + + return ( + nb_fail, + time.time() - t0, + validate_time, + callback_time, + callbacks_time, + nb_merged, + nb_atomic, + ) + + def __str__(self): + return self.__class__.__name__ + + @classmethod + def print_profile(cls, stream, prof, level=0): + + ( + nb_fail, + replace_time, + validate_time, + callback_time, + callbacks_time, + nb_merged, + nb_atomic, + ) = prof + + validate_time = validate_time or float("nan") + callback_time = callback_time or float("nan") + + blanc = " " * level + print(blanc, cls.__name__, file=stream) + print( + blanc, + f" nb fail={nb_fail:5d} merged={nb_merged:5d} atomic={nb_atomic:5d}", + file=stream, + ) + print( + blanc, + f" time replace={replace_time:2.2f} validate={validate_time:2.2f} callback={callback_time:2.2f}", + file=stream, + ) + if callback_time > 1: + print(blanc, " callbacks_time", file=stream) + for i in sorted(callbacks_time.items(), key=lambda a: a[1]): + if i[1] > 0: + # We want to have the __str__ called, so we can't + # just print i. + print(blanc, " ", i[0], ",", i[1], file=stream) + + @staticmethod + def merge_profile(prof1, prof2): + def merge_none_number(v1, v2): + if v1 is None: + return v2 + if v2 is None: + return v1 + return v1 + v2 + + nb_fail = prof1[0] + prof2[0] + replace_time = prof1[1] + prof2[1] + validate_time = merge_none_number(prof1[2], prof2[2]) + callback_time = merge_none_number(prof1[3], prof2[3]) + callbacks_time = merge_dict(prof1[4], prof2[4]) + nb_merged = prof1[5] + prof2[5] + nb_atomic = prof1[6] + prof2[6] + return ( + nb_fail, + replace_time, + validate_time, + callback_time, + callbacks_time, + nb_merged, + nb_atomic, + ) + + +def pre_constant_merge(fgraph, variables): + """Merge constants in the graphs given by `variables`. + + .. warning:: + + This changes the nodes in a graph in-place! + + Parameters + ---------- + fgraph + A `FunctionGraph` instance in which some of these `variables` may + reside. + + We want to avoid terms in `variables` that are contained in `fgraph`. + The reason for that: it will break consistency of `fgraph` and its + features (e.g. `ShapeFeature`). + + variables + A list of nodes for which we want to merge constant inputs. + + Notes + ----- + It is used to pre-merge nodes generated inside an rewrite. It is + useful if there are many such replacements to make, so that `DebugMode` + will not check each of them. + + """ + seen_var = set() + # signature -> variable (for constants) + const_sig_inv = {} + if isinstance(variables, Variable): + variables = [variables] + + def recursive_merge(var): + + if var in seen_var: + return var + + if not hasattr(var, "owner"): + return var + + # We don't want to merge constants that are *within* the + # `FunctionGraph` + if var.owner in fgraph.apply_nodes: + return var + + seen_var.add(var) + + if isinstance(var, Constant): + sig = var.signature() + + if sig in const_sig_inv: + return const_sig_inv[sig] + + const_sig_inv[sig] = var + + return var + + if var.owner: + for idx, inp in enumerate(var.owner.inputs): + # XXX: This is changing the graph in place! + var.owner.inputs[idx] = recursive_merge(inp) + return var + + return [recursive_merge(v) for v in variables] + + +class MetaNodeRewriter(NodeRewriter): + r""" + Base class for meta-rewriters that try a set of `NodeRewriter`\s + to replace a node and choose the one that executes the fastest. + + If the error `MetaNodeRewriterSkip` is raised during + compilation, we will skip that function compilation and not print + the error. + + """ + + def __init__(self): + self.verbose = config.metaopt__verbose + self.track_dict = defaultdict(lambda: []) + self.tag_dict = defaultdict(lambda: []) + self._tracks = [] + self.rewriters = [] + + def register(self, rewriter: NodeRewriter, tag_list: IterableType[str]): + self.rewriters.append(rewriter) + + tracks = rewriter.tracks() + if tracks: + for c in tracks: + self.track_dict[c].append(rewriter) + self._tracks.append(c) + + for tag in tag_list: + self.tag_dict[tag].append(rewriter) + + def tracks(self): + return self._tracks + + def transform(self, fgraph, node, *args, **kwargs): + # safety check: depending on registration, tracks may have been ignored + if self._tracks is not None: + if not isinstance(node.op, tuple(self._tracks)): + return + # first, we need to provide dummy values for all inputs + # to the node that are not shared variables anyway + givens = {} + missing = set() + for input in node.inputs: + if isinstance(input, aesara.compile.SharedVariable): + pass + elif hasattr(input.tag, "test_value"): + givens[input] = aesara.shared( + input.type.filter(input.tag.test_value), + input.name, + shape=input.broadcastable, + borrow=True, + ) + else: + missing.add(input) + if missing: + givens.update(self.provide_inputs(node, missing)) + missing.difference_update(givens.keys()) + # ensure we have data for all input variables that need it + if missing: + if self.verbose > 0: + print( + f"{self.__class__.__name__} cannot meta-rewrite {node}, " + f"{len(missing)} of {int(node.nin)} input shapes unknown" + ) + return + # now we can apply the different rewrites in turn, + # compile the resulting subgraphs and time their execution + if self.verbose > 1: + print( + f"{self.__class__.__name__} meta-rewriting {node} ({len(self.get_rewrites(node))} choices):" + ) + timings = [] + for node_rewriter in self.get_rewrites(node): + outputs = node_rewriter.transform(fgraph, node, *args, **kwargs) + if outputs: + try: + fn = aesara.function( + [], outputs, givens=givens, on_unused_input="ignore" + ) + fn.trust_input = True + timing = min(self.time_call(fn) for _ in range(2)) + except MetaNodeRewriterSkip: + continue + except Exception as e: + if self.verbose > 0: + print(f"* {node_rewriter}: exception", e) + continue + else: + if self.verbose > 1: + print(f"* {node_rewriter}: {timing:.5g} sec") + timings.append((timing, outputs, node_rewriter)) + else: + if self.verbose > 0: + print(f"* {node_rewriter}: not applicable") + # finally, we choose the fastest one + if timings: + timings.sort() + if self.verbose > 1: + print(f"= {timings[0][2]}") + return timings[0][1] + return + + def provide_inputs(self, node, inputs): + """Return a dictionary mapping some `inputs` to `SharedVariable` instances of with dummy values. + + The `node` argument can be inspected to infer required input shapes. + + """ + raise NotImplementedError() + + def get_rewrites(self, node): + """Return the rewrites that apply to `node`. + + This uses ``self.track_dict[type(node.op)]`` by default. + """ + return self.track_dict[type(node.op)] + + def time_call(self, fn): + start = time.time() + fn() + return time.time() - start + + +class FromFunctionNodeRewriter(NodeRewriter): + """A `NodeRewriter` constructed from a function.""" + + def __init__(self, fn, tracks=None, requirements=()): + self.fn = fn + self._tracks = tracks + self._tracked_types = ( + tuple(t for t in tracks if isinstance(t, type)) if tracks else () + ) + self.requirements = requirements + + def transform(self, fgraph, node): + if self._tracks: + if not ( + node.op in self._tracks or isinstance(node.op, self._tracked_types) + ): + return False + + return self.fn(fgraph, node) + + def add_requirements(self, fgraph): + for req in self.requirements: + req(fgraph) + + def tracks(self): + return self._tracks + + def __str__(self): + return getattr(self, "__name__", repr(self)) + + def __repr__(self): + return f"FromFunctionNodeRewriter({repr(self.fn)}, {repr(self._tracks)}, {repr(self.requirements)})" + + def print_summary(self, stream=sys.stdout, level=0, depth=-1): + print(f"{' ' * level}{self.transform} id={id(self)}", file=stream) + + +def node_rewriter( + tracks: Optional[Sequence[Union[Op, type]]], + inplace: bool = False, + requirements: Optional[Tuple[type, ...]] = (), +): + r"""A decorator used to construct `FromFunctionNodeRewriter` instances. + + Parameters + ---------- + tracks + The `Op` types or instances to which this rewrite applies. + Use ``None`` instead of an empty list to have the rewrite apply to + all `Op`\s. + inplace + A boolean indicating whether or not the rewrite works in-place. + If ``True``, a `DestroyHandler` `Feature` is added automatically added + to the `FunctionGraph`\s applied to this rewrite. + requirements + `Feature` types required by this rewrite. + + """ + + if requirements is None: + requirements = () + + def decorator(f): + if tracks is not None: + if len(tracks) == 0: + raise ValueError( + "Use `None` instead of an empty list to make an rewrite apply to all nodes." + ) + for t in tracks: + if not ( + isinstance(t, Op) or (isinstance(t, type) and issubclass(t, Op)) + ): + raise TypeError( + "`tracks` must consist of `Op` classes or instances." + ) + req = requirements + if inplace: + dh_handler = dh.DestroyHandler + req = tuple(requirements) + ( + lambda fgraph: fgraph.attach_feature(dh_handler()), + ) + rval = FromFunctionNodeRewriter(f, tracks, req) + rval.__name__ = f.__name__ + return rval + + return decorator + + +class OpToRewriterTracker: + r"""A container that maps `NodeRewriter`\s to `Op` instances and `Op`-type inheritance.""" + + def __init__(self): + self.tracked_instances: Dict[Op, List[NodeRewriter]] = {} + self.tracked_types: Dict[type, List[NodeRewriter]] = {} + self.untracked_rewrites: List[NodeRewriter] = [] + + def add_tracker(self, rw: NodeRewriter): + """Add a `NodeRewriter` to be keyed by its `NodeRewriter.tracks` or applied generally.""" + tracks = rw.tracks() + + if tracks is None: + self.untracked_rewrites.append(rw) + else: + for c in tracks: + if isinstance(c, type): + self.tracked_types.setdefault(c, []).append(rw) + else: + self.tracked_instances.setdefault(c, []).append(rw) + + def _find_impl(self, cls) -> List[NodeRewriter]: + r"""Returns the `NodeRewriter`\s that apply to `cls` based on inheritance. + + This based on `functools._find_impl`. + """ + mro = _compose_mro(cls, self.tracked_types.keys()) + matches = [] + for t in mro: + match = self.tracked_types.get(t, None) + if match: + matches.extend(match) + return matches + + @functools.lru_cache() + def get_trackers(self, op: Op) -> List[NodeRewriter]: + """Get all the rewrites applicable to `op`.""" + return ( + self._find_impl(type(op)) + + self.tracked_instances.get(op, []) + + self.untracked_rewrites + ) + + def get_rewriters(self): + return chain( + chain.from_iterable( + chain(self.tracked_types.values(), self.tracked_instances.values()) + ), + self.untracked_rewrites, + ) + + +class SequentialNodeRewriter(NodeRewriter): + r"""An rewriter that applies a list of `NodeRewriter`\s to a node. + + Attributes + ---------- + reentrant : bool + Some global rewriters, like `NodeProcessingGraphRewriter`, use this value to + determine if they should ignore new nodes. + retains_inputs : bool + States whether or not the inputs of a transformed node are transferred + to the outputs. + """ + + def __init__( + self, + *rewriters: Rewriter, + apply_all_rewrites: bool = False, + profile: bool = False, + ): + """ + + Parameters + ---------- + rewriters + A list of rewriters to be applied to nodes. + apply_all_rewrites + If ``False``, it will return after the first successfully applied + rewrite; otherwise, it will apply every applicable rewrite + incrementally. + profile + Whether or not to profile the rewrites. + + """ + super().__init__() + + self.rewrites: Sequence[Rewriter] = rewriters + assert isinstance(self.rewrites, tuple) + + self.reentrant = any( + getattr(rewrite, "reentrant", True) for rewrite in rewriters + ) + self.retains_inputs = all( + getattr(rewrite, "retains_inputs", False) for rewrite in rewriters + ) + + self.apply_all_rewrites = apply_all_rewrites + + self.profile = profile + if self.profile: + self.time_rewrites: Dict[Rewriter, float] = {} + self.process_count: Dict[Rewriter, int] = {} + self.applied_true: Dict[Rewriter, int] = {} + self.node_created: Dict[Rewriter, int] = {} + + self.tracker = OpToRewriterTracker() + + for o in self.rewrites: + + self.tracker.add_tracker(o) + + if self.profile: + self.time_rewrites.setdefault(o, 0.0) + self.process_count.setdefault(o, 0) + self.applied_true.setdefault(o, 0) + self.node_created.setdefault(o, 0) + + def __str__(self): + return getattr( + self, + "__name__", + f"{type(self).__name__}({','.join([str(o) for o in self.rewrites])})", + ) + + def tracks(self): + t = [] + for l in self.rewrites: + at = l.tracks() + if at: + t.extend(at) + return t + + def transform(self, fgraph, node): + if len(self.rewrites) == 0: + return + + repl = None + + while True: + rewrites = self.tracker.get_trackers(node.op) + + new_repl = None + for rewrite in rewrites: + rewrite_start = time.time() + new_repl = rewrite.transform(fgraph, node) + rewrite_finish = time.time() + if self.profile: + self.time_rewrites[rewrite] += rewrite_start - rewrite_finish + self.process_count[rewrite] += 1 + if not new_repl: + continue + if isinstance(new_repl, (tuple, list)): + new_vars = new_repl + else: # It must be a dict + new_vars = list(new_repl.values()) + + if config.optimizer_verbose: + print( + f"rewriting: rewrite {rewrite} replaces node {node} with {new_repl}" + ) + + if self.profile: + self.node_created[rewrite] += len( + list(applys_between(fgraph.variables, new_vars)) + ) + self.applied_true[rewrite] += 1 + break + if not new_repl: # No rewrites applied in the last iteration + return repl + # only 1 iteration + if not self.apply_all_rewrites: + return new_repl + if not new_vars[0].owner: + # We are at the start of the graph. + return new_repl + if len(new_repl) > 1: + s = {v.owner for v in new_repl} + assert len(s) == 1 + repl = new_repl + node = new_vars[0].owner + + @classmethod + def print_profile(cls, stream, prof, level=0): + (time_rewrites, process_count, applied_true, node_created, profile) = prof + + if not profile: + return + + blanc = " " * int(level) + print(blanc, f"{cls.__name__}", file=stream) + print(blanc, "---------------------", file=stream) + count_rewrite = [] + not_used = [] + not_used_time = 0 + for o, count in process_count.items(): + if count > 0: + count_rewrite.append( + (time_rewrites[o], applied_true[o], count, o, node_created[o]) + ) + else: + not_used.append((time_rewrites[o], o)) + not_used_time += time_rewrites[o] + if count_rewrite: + print( + blanc, + " time taken - times applied - times tried - name - node_created:", + file=stream, + ) + count_rewrite.sort() + for (t, a_t, count, o, n_c) in count_rewrite[::-1]: + print( + blanc, + f" {t:.3f}s - {int(a_t)} - {int(count)} - {o} - {int(n_c)}", + file=stream, + ) + print( + blanc, + ( + f" {not_used_time:.3f}s - in {len(not_used)} rewrite(s) that were not used " + "(displaying only those with a runtime greater than 0)" + ), + file=stream, + ) + not_used.sort(key=lambda nu: (nu[0], str(nu[1]))) + for (t, o) in not_used[::-1]: + if t > 0: + # Skip rewrites that have 0 times; they probably weren't even tried. + print(blanc + " ", f" {t:.3f}s - {o}", file=stream) + else: + print(blanc, " The rewriter wasn't successful ", file=stream) + + print(file=stream) + + @staticmethod + def merge_profile(prof1, prof2): + raise NotImplementedError + + def print_summary(self, stream=sys.stdout, level=0, depth=-1): + print(f"{' ' * level}{self.__class__.__name__} id={id(self)}", file=stream) + if depth != 0: + depth -= 1 + for lrewrite in self.rewrites: + lrewrite.print_summary(stream, level=(level + 2), depth=depth) + + def add_requirements(self, fgraph): + for rewrite in self.rewrites: + rewrite.add_requirements(fgraph) + + +class SubstitutionNodeRewriter(NodeRewriter): + """ + + Replaces the application of a certain `Op` by the application of + another `Op` that takes the same inputs as what it is replacing. + + Parameters + ---------- + op1, op2 + ``op1.make_node`` and ``op2.make_node`` must take the same number of + inputs and have the same number of outputs. + + Examples + -------- + + SubstitutionNodeRewriter(add, sub) ==> + add(div(x, y), add(y, x)) -> sub(div(x, y), sub(y, x)) + + """ + + # an SubstitutionNodeRewriter does not apply to the nodes it produces + reentrant = False + # all the inputs of the original node are transferred to the outputs + retains_inputs = True + + def __init__(self, op1, op2, transfer_tags=True): + self.op1 = op1 + self.op2 = op2 + self.transfer_tags = transfer_tags + + def op_key(self): + return self.op1 + + def tracks(self): + return [self.op1] + + def transform(self, fgraph, node): + if node.op != self.op1: + return False + repl = self.op2.make_node(*node.inputs) + if self.transfer_tags: + repl.tag = copy.copy(node.tag) + for output, new_output in zip(node.outputs, repl.outputs): + new_output.tag = copy.copy(output.tag) + return repl.outputs + + def __str__(self): + return f"{self.op1} -> {self.op2}" + + +class RemovalNodeRewriter(NodeRewriter): + """ + Removes all applications of an `Op` by transferring each of its + outputs to the corresponding input. + + """ + + reentrant = False # no nodes are added at all + + def __init__(self, op): + self.op = op + + def op_key(self): + return self.op + + def tracks(self): + return [self.op] + + def transform(self, fgraph, node): + if node.op != self.op: + return False + return node.inputs + + def __str__(self): + return f"{self.op}(x) -> x" + + def print_summary(self, stream=sys.stdout, level=0, depth=-1): + print( + f"{' ' * level}{self.__class__.__name__}(self.op) id={id(self)}", + file=stream, + ) + + +class PatternNodeRewriter(NodeRewriter): + """Replace all occurrences of an input pattern with an output pattern. + + The input and output patterns have the following syntax: + + input_pattern ::= (op, , , ...) + input_pattern ::= dict(pattern = , + constraint = ) + sub_pattern ::= input_pattern + sub_pattern ::= string + sub_pattern ::= a Constant instance + sub_pattern ::= int + sub_pattern ::= float + constraint ::= lambda fgraph, expr: additional matching condition + + output_pattern ::= (op, , , ...) + output_pattern ::= string + output_pattern ::= int + output_pattern ::= float + + Each string in the input pattern is a variable that will be set to + whatever expression is found in its place. If the same string is + used more than once, the same expression must be found in those + places. If a string used in the input pattern is used in the + output pattern, the matching expression will be inserted in its + place. The input pattern cannot just be a string but the output + pattern can. + + If you put a constant variable in the input pattern, there will be a + match iff a constant variable with the same value and the same type + is found in its place. + + You can add a constraint to the match by using the ``dict(...)`` form + described above with a ``'constraint'`` key. The constraint must be a + function that takes the fgraph and the current Variable that we are + trying to match and returns True or False according to an + arbitrary criterion. + + The constructor creates a `PatternNodeRewriter` that replaces occurrences of + `in_pattern` by occurrences of `out_pattern`. + + Examples + -------- + + PatternNodeRewriter((add, 'x', 'y'), (add, 'y', 'x')) + PatternNodeRewriter((multiply, 'x', 'x'), (square, 'x')) + PatternNodeRewriter((subtract, (add, 'x', 'y'), 'y'), 'x') + PatternNodeRewriter((power, 'x', Constant(double, 2.0)), (square, 'x')) + PatternNodeRewriter((boggle, {'pattern': 'x', + 'constraint': lambda expr: expr.type == scrabble}), + (scrabble, 'x')) + + """ + + def __init__( + self, + in_pattern, + out_pattern, + allow_multiple_clients: bool = False, + skip_identities_fn=None, + name: Optional[str] = None, + tracks=(), + get_nodes=None, + values_eq_approx=None, + ): + """ + + Parameters + ---------- + in_pattern + The input pattern that we want to replace. + out_pattern + The replacement pattern. + allow_multiple_clients + If ``False``, the pattern matching will fail if one of the subpatterns has + more than one client. + skip_identities_fn + TODO + name + Set the name of this rewriter. + tracks + The values that :meth:`self.tracks` will return. + get_nodes + If you provide `tracks`, you must provide this parameter. It must be a + function that takes the tracked node and returns a list of nodes on + which we will try this rewrite. + + Notes + ----- + `tracks` and `get_nodes` can be used to make this rewrite track a less + frequent `Op`, which will prevent the rewrite from being tried as + often. + + """ + from aesara.graph.rewriting.unify import convert_strs_to_vars + + var_map: Dict[str, "Var"] = {} + self.in_pattern = convert_strs_to_vars(in_pattern, var_map=var_map) + self.out_pattern = convert_strs_to_vars(out_pattern, var_map=var_map) + self.values_eq_approx = values_eq_approx + if isinstance(in_pattern, (list, tuple)): + self.op = self.in_pattern[0] + elif isinstance(in_pattern, dict): + self.op = self.in_pattern["pattern"][0] + else: + raise TypeError( + "The pattern to search for must start with a specific Op instance." + ) + self.__doc__ = f"{self.__class__.__doc__}\n\nThis instance does: {self}\n" + self.allow_multiple_clients = allow_multiple_clients + self.skip_identities_fn = skip_identities_fn + if name: + self.__name__ = name + self._tracks = tracks + self.get_nodes = get_nodes + if tracks != (): + assert get_nodes + + def op_key(self): + return self.op + + def tracks(self): + if self._tracks != (): + return self._tracks + return [self.op] + + def transform(self, fgraph, node, get_nodes=True): + """Check if the graph from node corresponds to ``in_pattern``. + + If it does, it constructs ``out_pattern`` and performs the replacement. + + """ + from etuples.core import ExpressionTuple + from unification import reify, unify + + # TODO: We shouldn't need to iterate like this. + if not self.allow_multiple_clients and any( + len(fgraph.clients.get(v)) > 1 + for v in vars_between(fgraph.inputs, node.outputs) + if v not in fgraph.inputs + ): + return False + + if get_nodes and self.get_nodes is not None: + for real_node in self.get_nodes(fgraph, node): + if real_node == "output": + continue + ret = self.transform(fgraph, real_node, get_nodes=False) + if ret is not False and ret is not None: + return dict(zip(real_node.outputs, ret)) + + if node.op != self.op: + return False + + s = unify(self.in_pattern, node.out) + + if s is False: + return False + + ret = reify(self.out_pattern, s) + + if isinstance(ret, ExpressionTuple): + ret = ret.evaled_obj + + if self.values_eq_approx: + ret.tag.values_eq_approx = self.values_eq_approx + + if ret.owner: + if not ( + len(node.outputs) == len(ret.owner.outputs) + and all( + o.type.is_super(new_o.type) + for o, new_o in zip(node.outputs, ret.owner.outputs) + ) + ): + return False + else: + # ret is just an input variable + assert len(node.outputs) == 1 + if not node.outputs[0].type.is_super(ret.type): + return False + + return [ret] + + def __str__(self): + if getattr(self, "__name__", None): + return self.__name__ + + def pattern_to_str(pattern): + if isinstance(pattern, (list, tuple)): + return "{}({})".format( + str(pattern[0]), + ", ".join([pattern_to_str(p) for p in pattern[1:]]), + ) + elif isinstance(pattern, dict): + return "{} subject to {}".format( + pattern_to_str(pattern["pattern"]), + str(pattern.get("constraint", "no conditions")), + ) + else: + return str(pattern) + + return "{} -> {}".format( + pattern_to_str(self.in_pattern), + pattern_to_str(self.out_pattern), + ) + + def __repr__(self): + return str(self) + + def print_summary(self, stream=sys.stdout, level=0, depth=-1): + name = getattr(self, "__name__", getattr(self, "name", None)) + print( + f"{' ' * level}{self.__class__.__name__} {name}({self.in_pattern}, {self.out_pattern}) id={id(self)}", + file=stream, + ) + + +class DispatchingFeature(Feature): + """A `Feature` consisting of user-defined functions implementing each `Feature` callback method.""" + + def __init__(self, importer, pruner, chin, name=None): + self.importer = importer + self.pruner = pruner + self.chin = chin + self.name = name + + def __str__(self): + return f"{type(self).__name__}{{{self.name}}}" + + def on_import(self, fgraph, node, reason): + if self.importer: + self.importer(node) + + def on_prune(self, fgraph, node, reason): + if self.pruner: + self.pruner(node) + + def on_change_input(self, fgraph, node, i, r, new_r, reason): + if self.chin: + self.chin(node, i, r, new_r, reason) + + def on_detach(self, fgraph): + # To allow pickling this object + self.importer = None + self.pruner = None + self.chin = None + + +class NodeProcessingGraphRewriter(GraphRewriter): + r"""A class providing a base implementation for applying `NodeRewriter.transform` results to a graph. + + This rewriter accepts the output of `NodeRewriter.transform` + implementations and applies them to a `FunctionGraph`. + + It accepts a sequence of new output nodes or ``dict``s. Entries in + these ``dict``\s can be `Variable`\s and their new values. It also accepts + a special ``"remove"`` key. A sequence of `Variable`\s mapped to the key + ``"remove"`` are removed from the `FunctionGraph`. + + It also adds some interface elements for simple reentrant/recursive + application of rewrites. The parameter `NodeRewriter.ignore_newtrees` is + intended to be used by subclasses, alongside the + `NodeRewriter.attach_updater` and `NodeRewriter.detach_updater` methods, to + determine whether or not sub-graphs created by rewrites are to have the + same rewrites applied to them. + + """ + + @classmethod + def warn(cls, exc, nav, repl_pairs, node_rewriter, node): + """A failure callback that prints a traceback.""" + if config.on_opt_error != "ignore": + _logger.error(f"Rewrite failure due to: {node_rewriter}") + _logger.error(f"node: {node}") + _logger.error("TRACEBACK:") + _logger.error(traceback.format_exc()) + if config.on_opt_error == "pdb": + pdb.post_mortem(sys.exc_info()[2]) + elif isinstance(exc, AssertionError) or config.on_opt_error == "raise": + # We always crash on AssertionError because something may be + # seriously wrong if such an exception is raised. + raise exc + + @classmethod + def warn_inplace(cls, exc, nav, repl_pairs, node_rewriter, node): + r"""A failure callback that ignores `InconsistencyError`\s and prints a traceback. + + If the error occurred during replacement, `repl_pairs` is set; + otherwise, its value is ``None``. + + """ + if isinstance(exc, InconsistencyError): + return + return cls.warn(exc, nav, repl_pairs, node_rewriter, node) + + @classmethod + def warn_ignore(cls, exc, nav, repl_pairs, node_rewriter, node): + """A failure callback that ignores all errors.""" + + def __init__( + self, + node_rewriter: Optional[NodeRewriter], + ignore_newtrees: Literal[True, False, "auto"], + failure_callback: Optional[FailureCallbackType] = None, + ): + """ + + Parameters + ---------- + node_rewriter + A `NodeRewriter` to apply over a `FunctionGraph` (or ``None``). + ignore_newtrees + - ``True``: new subgraphs returned by an `NodeRewriter` are not a + candidate for rewriting. + - ``False``: new subgraphs returned by an `NodeRewriter` is a + candidate for rewriting. + - ``'auto'``: let the `node_rewriter` set this parameter via its + :attr:`reentrant` attribute. + failure_callback + A function with the signature + ``(exception, navigator, [(old, new), (old,new),...])`` + that is called when there's an exception. + + If the exception is raised in `node_rewriter.transform`, the + ``new`` variables will be ``None``. + + If the exception is raised during validation (e.g. the new types + don't match) then the new variables will be the ones created by + ``self.transform``. + + If this parameter is ``None``, then exceptions are not caught here + and are raised normally. + + """ + self.node_rewriter = node_rewriter + if ignore_newtrees == "auto": + self.ignore_newtrees = not getattr(node_rewriter, "reentrant", True) + else: + self.ignore_newtrees = ignore_newtrees + self.failure_callback = failure_callback + super().__init__() + + def attach_updater( + self, + fgraph: FunctionGraph, + importer: Optional[Callable], + pruner: Optional[Callable], + chin: Optional[Callable] = None, + name: Optional[str] = None, + ) -> Optional[DispatchingFeature]: + r"""Install `FunctionGraph` listeners to help the navigator deal with the recursion-related functionality. + + Parameters + ---------- + importer + Function to be called when a rewrite adds something to the graph. + pruner + Function to be called when a rewrite removes something from the + graph. + chin + Function to be called when a node's inputs change. + name + Name of the `DispatchingFeature` to attach. + + Returns + ------- + The `FunctionGraph` plugin that handles the three tasks. + Keep this around so that `Feature`\s can be detached later. + + """ + if self.ignore_newtrees: + importer = None + + if importer is None and pruner is None: + return None + + u = DispatchingFeature(importer, pruner, chin, name=name) + fgraph.attach_feature(u) + return u + + def detach_updater( + self, fgraph: FunctionGraph, updater: Optional[DispatchingFeature] + ): + """Undo the work of `attach_updater`. + + Parameters + ---------- + fgraph + The `FunctionGraph`. + updater + The `DispatchingFeature` to remove. + + Returns + ------- + None + + """ + if updater is not None: + fgraph.remove_feature(updater) + + def process_node( + self, + fgraph: FunctionGraph, + node: Apply, + node_rewriter: Optional[NodeRewriter] = None, + ): + r"""Apply `node_rewriter` to `node`. + + The :meth:`node_rewriter.transform` method will return either ``False``, a + list of `Variable`\s that are intended to replace :attr:`node.outputs`, or + a ``dict`` specifying replacements--or the key ``"remove"`` mapped to a + sequence of `Variable`\s to be removed. + + Parameters + ---------- + fgraph + A `FunctionGraph`. + node + An `Apply` instance in `fgraph` + node_rewriter + A `NodeRewriter` instance that may have a better idea for + how to compute node's outputs. + + Returns + ------- + bool + If `fgraph` accepts the replacement, then the rewrite is + successful and this function returns ``True``. If there are no + replacement candidates, or the `fgraph` rejects the replacements, + this function returns ``False``. + + + """ + node_rewriter = node_rewriter or self.node_rewriter + # TODO FIXME: This class's interface is broken + assert node_rewriter is not None + try: + replacements = node_rewriter.transform(fgraph, node) + except Exception as e: + if self.failure_callback is not None: + self.failure_callback( + e, self, [(x, None) for x in node.outputs], node_rewriter, node + ) + return False + else: + raise + if replacements is False or replacements is None: + return False + old_vars = node.outputs + remove: List[Variable] = [] + if isinstance(replacements, dict): + if "remove" in replacements: + remove = list(cast(Sequence[Variable], replacements.pop("remove"))) + old_vars = list(cast(Sequence[Variable], replacements.keys())) + replacements = list(cast(Sequence[Variable], replacements.values())) + elif not isinstance(replacements, (tuple, list)): + raise TypeError( + f"Node rewriter {node_rewriter} gave wrong type of replacement. " + f"Expected list or tuple; got {replacements}" + ) + if len(old_vars) != len(replacements): + raise ValueError( + f"Node rewriter {node_rewriter} gave wrong number of replacements" + ) + # None in the replacement mean that this variable isn't used + # and we want to remove it + for r, rnew in zip(old_vars, replacements): + if rnew is None and len(fgraph.clients[r]) > 0: + raise ValueError( + f"Node rewriter {node_rewriter} tried to remove a variable" + f" that is being used: {r}" + ) + # If an output would be replaced by itself, no need to perform + # the replacement + repl_pairs = [ + (r, rnew) + for r, rnew in zip(old_vars, replacements) + if rnew is not r and rnew is not None + ] + + if len(repl_pairs) == 0: + return False + try: + fgraph.replace_all_validate_remove( # type: ignore + repl_pairs, reason=node_rewriter, remove=remove + ) + return True + except Exception as e: + # This means the replacements were rejected by the fgraph. + # + # This is not supposed to happen. The default failure_callback + # will print a traceback as a warning. + if self.failure_callback is not None: + self.failure_callback(e, self, repl_pairs, node_rewriter, node) + return False + else: + raise + + def add_requirements(self, fgraph): + super().add_requirements(fgraph) + # Added by default + # fgraph.attach_feature(ReplaceValidate()) + if self.node_rewriter: + self.node_rewriter.add_requirements(fgraph) + + def print_summary(self, stream=sys.stdout, level=0, depth=-1): + print(f"{' ' * level}{self.__class__.__name__} id={id(self)}", file=stream) + if depth != 0: + self.node_rewriter.print_summary( + stream, level=(level + 2), depth=(depth - 1) + ) + + +class WalkingGraphRewriter(NodeProcessingGraphRewriter): + """A rewriter that applies a single `NodeRewriter` to each node in topological order (or reverse).""" + + def __init__( + self, + node_rewriter: NodeRewriter, + order: Literal["out_to_in", "in_to_out"] = "in_to_out", + ignore_newtrees: bool = False, + failure_callback: Optional[FailureCallbackType] = None, + ): + if order not in ("out_to_in", "in_to_out"): + raise ValueError("order must be 'out_to_in' or 'in_to_out'") + self.order = order + super().__init__(node_rewriter, ignore_newtrees, failure_callback) + + def apply(self, fgraph, start_from=None): + if start_from is None: + start_from = fgraph.outputs + callback_before = fgraph.execute_callbacks_time + nb_nodes_start = len(fgraph.apply_nodes) + t0 = time.time() + q = deque(io_toposort(fgraph.inputs, start_from)) + io_t = time.time() - t0 + + def importer(node): + if node is not current_node: + q.append(node) + + u = self.attach_updater( + fgraph, importer, None, name=getattr(self, "name", None) + ) + nb = 0 + try: + t0 = time.time() + while q: + if self.order == "out_to_in": + node = q.pop() + else: + node = q.popleft() + if node not in fgraph.apply_nodes: + continue + current_node = node + nb += self.process_node(fgraph, node) + loop_t = time.time() - t0 + finally: + self.detach_updater(fgraph, u) + + callback_time = fgraph.execute_callbacks_time - callback_before + nb_nodes_end = len(fgraph.apply_nodes) + return ( + self, + nb, + nb_nodes_start, + nb_nodes_end, + io_t, + loop_t, + callback_time, + self.node_rewriter, + ) + + @classmethod + def print_profile(cls, stream, prof, level=0): + blanc = " " * level + if prof is None: # Happen as merge_profile() isn't implemented + print(blanc, f"{cls.__name__} merge_profile not implemented", file=stream) + return + + ( + rewrite, + nb, + nb_nodes_start, + nb_nodes_end, + io_t, + loop_t, + callback_time, + node_rewriter, + ) = prof + + print( + blanc, + f"{cls.__name__} ", + getattr(rewrite, "name", getattr(rewrite, "__name__", "")), + file=stream, + ) + + print( + blanc, + " nb_node (start, end, changed)", + (nb_nodes_start, nb_nodes_end, nb), + file=stream, + ) + print(blanc, " init io_toposort", io_t, file=stream) + print(blanc, " loop time", loop_t, file=stream) + print(blanc, " callback_time", callback_time, file=stream) + if isinstance(node_rewriter, SequentialNodeRewriter): + if node_rewriter.profile: + node_rewriter.print_profile( + stream, + ( + node_rewriter.time_rewrites, + node_rewriter.process_count, + node_rewriter.applied_true, + node_rewriter.node_created, + node_rewriter.profile, + ), + level=level + 1, + ) + + def __str__(self): + return getattr(self, "__name__", super().__str__()) + + +def walking_rewriter( + order, + *node_rewriters, + name=None, + failure_callback=WalkingGraphRewriter.warn_inplace, + **kwargs, +): + r"""Apply `node_rewriters` from the input/output nodes to the output/input nodes of a graph. + + This constructs `WalkingGraphRewriter`\s, and uses a `SequentialNodeRewriter` when there's + more than one entry in `node_rewriters`. + """ + if len(node_rewriters) > 1: + # Don't wrap it uselessly if there is only one rewrite. + node_rewriters = SequentialNodeRewriter(*node_rewriters) + else: + (node_rewriters,) = node_rewriters + if not name: + name = node_rewriters.__name__ + ret = WalkingGraphRewriter( + node_rewriters, + order=order, + failure_callback=failure_callback, + **kwargs, + ) + if name: + ret.__name__ = name + return ret + + +in2out = partial(walking_rewriter, "in_to_out") +out2in = partial(walking_rewriter, "out_to_in") + + +class OpKeyGraphRewriter(NodeProcessingGraphRewriter): + r"""A rewriter that applies a `NodeRewriter` to specific `Op`\s. + + The `Op`\s are provided by a :meth:`NodeRewriter.op_key` method (either + as a list of `Op`\s or a single `Op`), and discovered within a + `FunctionGraph` using the `NodeFinder` `Feature`. + + This is similar to the `Op`-based tracking feature used by other rewriters. + + """ + + def __init__(self, node_rewriter, ignore_newtrees=False, failure_callback=None): + if not hasattr(node_rewriter, "op_key"): + raise TypeError(f"{node_rewriter} must have an `op_key` method.") + super().__init__(node_rewriter, ignore_newtrees, failure_callback) + + def apply(self, fgraph): + op = self.node_rewriter.op_key() + if isinstance(op, (list, tuple)): + q = reduce(list.__iadd__, map(fgraph.get_nodes, op)) + else: + q = list(fgraph.get_nodes(op)) + + def importer(node): + if node is not current_node: + if node.op == op: + q.append(node) + + u = self.attach_updater( + fgraph, importer, None, name=getattr(self, "name", None) + ) + try: + while q: + node = q.pop() + if node not in fgraph.apply_nodes: + continue + current_node = node + self.process_node(fgraph, node) + finally: + self.detach_updater(fgraph, u) + + def add_requirements(self, fgraph): + super().add_requirements(fgraph) + fgraph.attach_feature(NodeFinder()) + + +class ChangeTracker(Feature): + def __init__(self): + self.changed = False + self.nb_imported = 0 + + def clone(self): + return type(self)() + + def on_import(self, fgraph, node, reason): + self.nb_imported += 1 + self.changed = True + + def on_change_input(self, fgraph, node, i, r, new_r, reason): + self.changed = True + + def reset(self): + self.changed = False + + def on_attach(self, fgraph): + if hasattr(fgraph, "change_tracker"): + raise AlreadyThere() + fgraph.change_tracker = self + + def on_detach(self, fgraph): + del fgraph.change_tracker + + +def merge_dict(d1, d2): + r"""Merge two ``dict``\s by adding their values.""" + d = d1.copy() + for k, v in d2.items(): + if k in d: + d[k] += v + else: + d[k] = v + return d + + +class EquilibriumGraphRewriter(NodeProcessingGraphRewriter): + """A `Rewriter` that applies its rewrites until a fixed-point/equilibrium is reached.""" + + def __init__( + self, + rewriters: Sequence[Rewriter], + failure_callback: Optional[FailureCallbackType] = None, + ignore_newtrees: bool = True, + tracks_on_change_inputs: bool = False, + max_use_ratio: Optional[float] = None, + final_rewriters: Optional[Sequence[GraphRewriter]] = None, + cleanup_rewriters: Optional[Sequence[GraphRewriter]] = None, + ): + """ + + Parameters + ---------- + rewriters + Node or graph rewriters to apply until equilibrium. + The global rewriter will be run at the start of each iteration before + the node rewriter. + failure_callback + See :attr:`NodeProcessingGraphRewriter.failure_callback`. + ignore_newtrees + See :attr:`NodeProcessingGraphRewriter.ignore_newtrees`. + tracks_on_change_inputs + See :attr:`NodeProcessingGraphRewriter.tracks_on_change_inputs`. + max_use_ratio + Each rewriter can be applied at most ``(size_of_graph * max_use_ratio)`` + times. + final_rewriters + Rewriters that will be run after each iteration. + cleanup_rewriters + Rewriters applied after all graph rewriters, then when one + `NodeRewriter` is applied, then after all final rewriters. + They should not traverse the entire graph, since they are called + very frequently. The `MergeOptimizer` is one example of a rewriter + that respects this. + + """ + super().__init__( + None, ignore_newtrees=ignore_newtrees, failure_callback=failure_callback + ) + self.global_rewriters: List[GraphRewriter] = [] + self.tracks_on_change_inputs = tracks_on_change_inputs + + self.node_tracker = OpToRewriterTracker() + + for rewriter in rewriters: + if isinstance(rewriter, NodeRewriter): + self.node_tracker.add_tracker(rewriter) + else: + assert isinstance(rewriter, GraphRewriter) + self.global_rewriters.append(rewriter) + + if final_rewriters: + self.final_rewriters = list(final_rewriters) + else: + self.final_rewriters = [] + + if cleanup_rewriters: + self.cleanup_rewriters = list(cleanup_rewriters) + else: + self.cleanup_rewriters = [] + + self.max_use_ratio = max_use_ratio + + def get_node_rewriters(self): + yield from self.node_tracker.get_rewriters() + + def get_local_optimizers(self): + warnings.warn( + "`get_local_optimizers` is deprecated; use `get_node_rewriters` instead.", + DeprecationWarning, + stacklevel=2, + ) + yield from self.get_node_rewriters() + + def add_requirements(self, fgraph): + super().add_requirements(fgraph) + for rewriter in self.get_node_rewriters(): + rewriter.add_requirements(fgraph) + for rewriter in self.global_rewriters: + rewriter.add_requirements(fgraph) + for rewriter in self.final_rewriters: + rewriter.add_requirements(fgraph) + for rewriter in self.cleanup_rewriters: + rewriter.add_requirements(fgraph) + + def apply(self, fgraph, start_from=None): + change_tracker = ChangeTracker() + fgraph.attach_feature(change_tracker) + if start_from is None: + start_from = fgraph.outputs + else: + for node in start_from: + assert node in fgraph.outputs + + changed = True + max_use_abort = False + rewriter_name = None + global_process_count = {} + start_nb_nodes = len(fgraph.apply_nodes) + max_nb_nodes = len(fgraph.apply_nodes) + max_use = max_nb_nodes * self.max_use_ratio + + loop_timing = [] + loop_process_count = [] + global_rewriter_timing = [] + time_rewriters = {} + io_toposort_timing = [] + nb_nodes = [] + node_created = {} + global_sub_profs = [] + final_sub_profs = [] + cleanup_sub_profs = [] + for rewriter in ( + self.global_rewriters + + list(self.get_node_rewriters()) + + self.final_rewriters + + self.cleanup_rewriters + ): + global_process_count.setdefault(rewriter, 0) + time_rewriters.setdefault(rewriter, 0) + node_created.setdefault(rewriter, 0) + + def apply_cleanup(profs_dict): + changed = False + for crewriter in self.cleanup_rewriters: + change_tracker.reset() + nb = change_tracker.nb_imported + t_rewrite = time.time() + sub_prof = crewriter.apply(fgraph) + time_rewriters[crewriter] += time.time() - t_rewrite + profs_dict[crewriter].append(sub_prof) + if change_tracker.changed: + process_count.setdefault(crewriter, 0) + process_count[crewriter] += 1 + global_process_count[crewriter] += 1 + changed = True + node_created[crewriter] += change_tracker.nb_imported - nb + return changed + + while changed and not max_use_abort: + process_count = {} + t0 = time.time() + changed = False + iter_cleanup_sub_profs = {} + for crewrite in self.cleanup_rewriters: + iter_cleanup_sub_profs[crewrite] = [] + + # Apply global rewriters + sub_profs = [] + for grewrite in self.global_rewriters: + change_tracker.reset() + nb = change_tracker.nb_imported + t_rewrite = time.time() + sub_prof = grewrite.apply(fgraph) + time_rewriters[grewrite] += time.time() - t_rewrite + sub_profs.append(sub_prof) + if change_tracker.changed: + process_count.setdefault(grewrite, 0) + process_count[grewrite] += 1 + global_process_count[grewrite] += 1 + changed = True + node_created[grewrite] += change_tracker.nb_imported - nb + if global_process_count[grewrite] > max_use: + max_use_abort = True + rewriter_name = getattr(grewrite, "name", None) or getattr( + grewrite, "__name__", "" + ) + global_sub_profs.append(sub_profs) + + global_rewriter_timing.append(float(time.time() - t0)) + + changed |= apply_cleanup(iter_cleanup_sub_profs) + + topo_t0 = time.time() + q = deque(io_toposort(fgraph.inputs, start_from)) + io_toposort_timing.append(time.time() - topo_t0) + + nb_nodes.append(len(q)) + max_nb_nodes = max(max_nb_nodes, len(q)) + max_use = max_nb_nodes * self.max_use_ratio + + def importer(node): + if node is not current_node: + q.append(node) + + chin = None + if self.tracks_on_change_inputs: + + def chin(node, i, r, new_r, reason): + if node is not current_node and not isinstance(node, str): + q.append(node) + + u = self.attach_updater( + fgraph, importer, None, chin=chin, name=getattr(self, "name", None) + ) + try: + while q: + node = q.pop() + if node not in fgraph.apply_nodes: + continue + current_node = node + for node_rewriter in self.node_tracker.get_trackers(node.op): + nb = change_tracker.nb_imported + t_rewrite = time.time() + node_rewriter_change = self.process_node( + fgraph, node, node_rewriter + ) + time_rewriters[node_rewriter] += time.time() - t_rewrite + if not node_rewriter_change: + continue + process_count.setdefault(node_rewriter, 0) + process_count[node_rewriter] += 1 + global_process_count[node_rewriter] += 1 + changed = True + node_created[node_rewriter] += change_tracker.nb_imported - nb + changed |= apply_cleanup(iter_cleanup_sub_profs) + if global_process_count[node_rewriter] > max_use: + max_use_abort = True + rewriter_name = getattr( + node_rewriter, "name", None + ) or getattr(node_rewriter, "__name__", "") + if node not in fgraph.apply_nodes: + # go to next node + break + finally: + self.detach_updater(fgraph, u) + + # Apply final rewriters + sub_profs = [] + t_before_final_rewrites = time.time() + for grewrite in self.final_rewriters: + change_tracker.reset() + nb = change_tracker.nb_imported + t_rewrite = time.time() + sub_prof = grewrite.apply(fgraph) + time_rewriters[grewrite] += time.time() - t_rewrite + sub_profs.append(sub_prof) + if change_tracker.changed: + process_count.setdefault(grewrite, 0) + process_count[grewrite] += 1 + global_process_count[grewrite] += 1 + changed = True + node_created[grewrite] += change_tracker.nb_imported - nb + if global_process_count[grewrite] > max_use: + max_use_abort = True + rewriter_name = getattr(grewrite, "name", None) or getattr( + grewrite, "__name__", "" + ) + final_sub_profs.append(sub_profs) + + global_rewriter_timing[-1] += time.time() - t_before_final_rewrites + + changed |= apply_cleanup(iter_cleanup_sub_profs) + + # Merge clean up profiles during that iteration + c_sub_profs = [] + for crewrite, sub_profs in iter_cleanup_sub_profs.items(): + sub_prof = sub_profs[0] + for s_p in sub_profs[1:]: + sub_prof = crewrite.merge_profile(sub_prof, s_p) + c_sub_profs.append(sub_prof) + cleanup_sub_profs.append(c_sub_profs) + + loop_process_count.append(process_count) + loop_timing.append(float(time.time() - t0)) + + end_nb_nodes = len(fgraph.apply_nodes) + + if max_use_abort: + msg = ( + f"{type(self).__name__} max'ed out by {rewriter_name}." + "You can safely raise the current threshold of " + f"{config.optdb__max_use_ratio} with the option `optdb__max_use_ratio`." + ) + if config.on_opt_error == "raise": + raise AssertionError(msg) + else: + _logger.error(msg) + fgraph.remove_feature(change_tracker) + assert len(loop_process_count) == len(loop_timing) + assert len(loop_process_count) == len(global_rewriter_timing) + assert len(loop_process_count) == len(nb_nodes) + assert len(loop_process_count) == len(io_toposort_timing) + assert len(loop_process_count) == len(global_sub_profs) + assert len(loop_process_count) == len(final_sub_profs) + assert len(loop_process_count) == len(cleanup_sub_profs) + return ( + self, + loop_timing, + loop_process_count, + (start_nb_nodes, end_nb_nodes, max_nb_nodes), + global_rewriter_timing, + nb_nodes, + time_rewriters, + io_toposort_timing, + node_created, + global_sub_profs, + final_sub_profs, + cleanup_sub_profs, + ) + + def print_summary(self, stream=sys.stdout, level=0, depth=-1): + name = getattr(self, "name", None) + print( + f"{' ' * level}{self.__class__.__name__} {name} id={id(self)}", file=stream + ) + if depth != 0: + for node_rewriter in self.get_node_rewriters(): + node_rewriter.print_summary( + stream, level=(level + 2), depth=(depth - 1) + ) + + @classmethod + def print_profile(cls, stream, prof, level=0): + ( + rewrite, + loop_timing, + loop_process_count, + (start_nb_nodes, end_nb_nodes, max_nb_nodes), + global_rewrite_timing, + nb_nodes, + time_rewrites, + io_toposort_timing, + node_created, + global_sub_profs, + final_sub_profs, + cleanup_sub_profs, + ) = prof + + blanc = " " * level + print(blanc, cls.__name__, end=" ", file=stream) + print( + blanc, + getattr(rewrite, "name", getattr(rewrite, "__name__", "")), + file=stream, + ) + print( + blanc, + f" time {sum(loop_timing):.3f}s for {len(loop_timing)} passes", + file=stream, + ) + print( + blanc, + f" nb nodes (start, end, max) {int(start_nb_nodes)} {int(end_nb_nodes)} {int(max_nb_nodes)}", + file=stream, + ) + print(blanc, f" time io_toposort {sum(io_toposort_timing):.3f}s", file=stream) + s = sum(time_rewrites[o] for o in rewrite.get_node_rewriters()) + print(blanc, f" time in node rewriters {s:.3f}s", file=stream) + s = sum(time_rewrites[o] for o in rewrite.global_rewriters) + print(blanc, f" time in graph rewriters {s:.3f}s", file=stream) + s = sum(time_rewrites[o] for o in rewrite.final_rewriters) + print(blanc, f" time in final rewriters {s:.3f}s", file=stream) + s = sum(time_rewrites[o] for o in rewrite.cleanup_rewriters) + print(blanc, f" time in cleanup rewriters {s:.3f}s", file=stream) + for i in range(len(loop_timing)): + loop_times = "" + if loop_process_count[i]: + d = list( + reversed(sorted(loop_process_count[i].items(), key=lambda a: a[1])) + ) + loop_times = " ".join([str((str(k), v)) for k, v in d[:5]]) + if len(d) > 5: + loop_times += " ..." + print( + blanc, + ( + f" {int(i):2d} - {loop_timing[i]:.3f}s {int(sum(loop_process_count[i].values()))} ({global_rewrite_timing[i]:.3f}s in graph rewriters, " + f"{io_toposort_timing[i]:.3f}s io_toposort) - {int(nb_nodes[i])} nodes - {loop_times}" + ), + file=stream, + ) + + count_rewrite = [] + not_used = [] + not_used_time = 0 + process_count = {} + for o in ( + rewrite.global_rewriters + + list(rewrite.get_node_rewriters()) + + list(rewrite.final_rewriters) + + list(rewrite.cleanup_rewriters) + ): + process_count.setdefault(o, 0) + for count in loop_process_count: + for o, v in count.items(): + process_count[o] += v + for o, count in process_count.items(): + if count > 0: + count_rewrite.append((time_rewrites[o], count, node_created[o], o)) + else: + not_used.append((time_rewrites[o], o)) + not_used_time += time_rewrites[o] + + if count_rewrite: + print( + blanc, " times - times applied - nb node created - name:", file=stream + ) + count_rewrite.sort() + for (t, count, n_created, o) in count_rewrite[::-1]: + print( + blanc, + f" {t:.3f}s - {int(count)} - {int(n_created)} - {o}", + file=stream, + ) + print( + blanc, + f" {not_used_time:.3f}s - in {len(not_used)} rewrites that were not used (i.e. those with a run-time of zero)", + file=stream, + ) + not_used.sort(key=lambda nu: (nu[0], str(nu[1]))) + for (t, o) in not_used[::-1]: + if t > 0: + # Skip rewrites that have no run-times; they probably weren't even tried. + print(blanc + " ", f" {t:.3f}s - {o}", file=stream) + print(file=stream) + gf_rewrites = [ + o + for o in ( + rewrite.global_rewrites + + list(rewrite.final_rewriters) + + list(rewrite.cleanup_rewriters) + ) + if o.print_profile.__code__ is not GraphRewriter.print_profile.__code__ + ] + if not gf_rewrites: + return + print(blanc, "Global, final, and clean up rewriters", file=stream) + for i in range(len(loop_timing)): + print(blanc, f"Iter {int(i)}", file=stream) + for o, prof in zip(rewrite.global_rewriters, global_sub_profs[i]): + try: + o.print_profile(stream, prof, level + 2) + except NotImplementedError: + print(blanc, "merge not implemented for ", o) + for o, prof in zip(rewrite.final_rewriters, final_sub_profs[i]): + try: + o.print_profile(stream, prof, level + 2) + except NotImplementedError: + print(blanc, "merge not implemented for ", o) + for o, prof in zip(rewrite.cleanup_rewriters, cleanup_sub_profs[i]): + try: + o.print_profile(stream, prof, level + 2) + except NotImplementedError: + print(blanc, "merge not implemented for ", o) + + @staticmethod + def merge_profile(prof1, prof2): + node_rewriters = OrderedSet(prof1[0].get_node_rewriters()).union( + prof2[0].get_node_rewriters() + ) + global_rewriters = OrderedSet(prof1[0].global_rewriters).union( + prof2[0].global_rewriters + ) + final_rewriters = list( + OrderedSet(prof1[0].final_rewriters).union(prof2[0].final_rewriters) + ) + cleanup_rewriters = list( + OrderedSet(prof1[0].cleanup_rewriters).union(prof2[0].cleanup_rewriters) + ) + new_rewriter = EquilibriumGraphRewriter( + node_rewriters.union(global_rewriters), + max_use_ratio=1, + final_rewriters=final_rewriters, + cleanup_rewriters=cleanup_rewriters, + ) + + def add_append_list(l1, l2): + l = copy.copy(l1) + for idx, nb in enumerate(l2): + if idx < len(l): + l[idx] += nb + else: + l.append(nb) + return l + + loop_timing = add_append_list(prof1[1], prof2[1]) + + loop_process_count = list(prof1[2]) + global_sub_profs = [] + final_sub_profs = [] + cleanup_sub_profs = [] + + for i in range(min(len(loop_process_count), len(prof2[2]))): + process_count = loop_process_count[i] + for process, count in prof2[2][i].items(): + if process in process_count: + process_count[process] += count + else: + process_count[process] = count + + def merge(rewriters, attr, idx): + tmp = [] + for rewriter in rewriters: + o1 = getattr(prof1[0], attr) + o2 = getattr(prof2[0], attr) + if rewriter in o1 and rewriter in o2: + p1 = prof1[idx][i][o1.index(rewriter)] + p2 = prof2[idx][i][o2.index(rewriter)] + m = None + if hasattr(rewriter, "merge_profile"): + m = rewriter.merge_profile(p1, p2) + elif rewriter in o1: + m = prof1[idx][i][o1.index(rewriter)] + else: + m = prof2[idx][i][o2.index(rewriter)] + tmp.append(m) + return tmp + + global_sub_profs.append(merge(global_rewriters, "global_rewriters", 9)) + final_sub_profs.append(merge(final_rewriters, "final_rewriters", 10)) + cleanup_sub_profs.append(merge(cleanup_rewriters, "cleanup_rewriters", 11)) + + # Add the iteration done by only one of the profile. + loop_process_count.extend(prof1[2][len(loop_process_count) :]) + global_sub_profs.extend(prof1[9][len(global_sub_profs) :]) + final_sub_profs.extend(prof1[10][len(final_sub_profs) :]) + cleanup_sub_profs.extend(prof1[11][len(cleanup_sub_profs) :]) + + global_sub_profs.extend(prof2[9][len(loop_process_count) :]) + final_sub_profs.extend(prof2[10][len(loop_process_count) :]) + cleanup_sub_profs.extend(prof2[11][len(loop_process_count) :]) + + max_nb_nodes = max(prof1[3], prof2[3]) + + global_rewrite_timing = add_append_list(prof1[4], prof2[4]) + + nb_nodes = add_append_list(prof1[5], prof2[5]) + + time_rewrites = merge_dict(prof1[6], prof2[6]) + io_toposort_timing = add_append_list(prof1[7], prof2[7]) + assert ( + len(loop_timing) + == len(global_rewrite_timing) + == len(global_sub_profs) + == len(io_toposort_timing) + == len(nb_nodes) + ) + assert len(loop_timing) == max(len(prof1[1]), len(prof2[1])) + + node_created = merge_dict(prof1[8], prof2[8]) + return ( + new_rewriter, + loop_timing, + loop_process_count, + max_nb_nodes, + global_rewrite_timing, + nb_nodes, + time_rewrites, + io_toposort_timing, + node_created, + global_sub_profs, + final_sub_profs, + cleanup_sub_profs, + ) + + +def _check_chain(r, chain): + """ + WRITEME + + """ + chain = list(reversed(chain)) + while chain: + elem = chain.pop() + if elem is None: + if r.owner is not None: + return False + elif r.owner is None: + return False + elif isinstance(elem, Op): + if r.owner.op != elem: + return False + else: + try: + if issubclass(elem, Op) and not isinstance(r.owner.op, elem): + return False + except TypeError: + return False + if chain: + r = r.owner.inputs[chain.pop()] + # print 'check_chain', _check_chain.n_calls + # _check_chain.n_calls += 1 + + # The return value will be used as a Boolean, but some Variables cannot + # be used as Booleans (the results of comparisons, for instance) + return r is not None + + +def check_chain(r, *chain): + """ + WRITEME + + """ + if isinstance(r, Apply): + r = r.outputs[0] + return _check_chain(r, reduce(list.__iadd__, ([x, 0] for x in chain))) + + +def pre_greedy_node_rewriter( + fgraph: FunctionGraph, rewrites: Sequence[NodeRewriter], out: Variable +) -> Variable: + """Apply node rewriters throughout a graph in a greedy, pre-traversal way. + + This function traverses the computation graph in the graph before the + variable `out` but that are not in the `fgraph`. It applies + `rewrites` to each variable on the traversed graph. + + .. warning:: + + This changes the nodes in a graph in-place. + + Its main use is to apply locally constant folding when generating + the graph of the indices of a `Subtensor`. + + Changes should not be applied to nodes that are in an `fgraph`, + so we use `fgraph` to prevent that. + + Notes + ----- + This doesn't do an equilibrium rewrite, so, if there is a rewrite--like + `local_upcast_elemwise_constant_inputs`--in the list that adds additional + nodes to the inputs of the node, it might be necessary to call this + function multiple times. + + Parameters + ---------- + fgraph + The graph used to avoid/filter nodes. + rewrites + A sequence of rewrites to apply. + out + The graph to rewrite. + + """ + + def local_recursive_function( + rewrite_list: Sequence[NodeRewriter], + out: Variable, + rewritten_vars: Dict[Variable, Variable], + depth: int, + ) -> Tuple[List[Variable], Dict[Variable, Variable]]: + if not getattr(out, "owner", None): + return [out], rewritten_vars + node = out.owner + + if node in fgraph.apply_nodes: + return node.outputs, rewritten_vars + + # Walk up the graph via the node's inputs + for idx, inp in enumerate(node.inputs): + if inp in rewritten_vars: + nw_in = rewritten_vars[inp] + else: + if inp.owner: + outs, rewritten_vars = local_recursive_function( + rewrite_list, inp, rewritten_vars, depth + 1 + ) + for k, v in zip(inp.owner.outputs, outs): + rewritten_vars[k] = v + nw_in = outs[inp.owner.outputs.index(inp)] + + else: + nw_in = inp + rewritten_vars[inp] = inp + + # XXX: An in-place change + node.inputs[idx] = nw_in + + # Apply the rewrites + results = node.outputs + for rewrite in rewrite_list: + ret = rewrite.transform(fgraph, node) + if ret is not False and ret is not None: + assert isinstance(ret, Sequence) + assert len(ret) == len(node.outputs), rewrite + for k, v in zip(node.outputs, ret): + rewritten_vars[k] = v + results = ret + if ret[0].owner: + node = out.owner + else: + break + + return results, rewritten_vars + + if out.owner: + out_index: int = out.owner.outputs.index(out) + else: + out_index = 0 + + final_outs, rewritten_nodes = local_recursive_function(rewrites, out, {}, 0) + return final_outs[out_index] + + +def copy_stack_trace(from_var, to_var): + r"""Copy the stack traces from `from_var` to `to_var`. + + Parameters + ---------- + from_var : + `Variable` or list `Variable`\s to copy stack traces from. + to_var : + `Variable` or list `Variable`\s to copy stack traces to. + + Notes + ----- + The stacktrace is assumed to be of the form of a list of lists + of tuples. Each tuple contains the filename, line number, function name + and so on. Each list of tuples contains the truples belonging to a + particular `Variable`. + + """ + + # Store stack traces from from_var + tr = [] + if isinstance(from_var, Iterable) and not isinstance(from_var, Variable): + # If from_var is a list, store concatenated stack traces + for v in from_var: + tr += getattr(v.tag, "trace", []) + + else: + # If from_var is not a list, it must be a single tensor variable, + # so just store that particular stack trace + tr = getattr(from_var.tag, "trace", []) + + if tr and isinstance(tr[0], tuple): + # There was one single stack trace, we encapsulate it in a list + tr = [tr] + + # Copy over stack traces to to_var + if isinstance(to_var, Iterable) and not isinstance(to_var, Variable): + # Copy over stack traces from from_var to each variable in + # to_var, including the stack_trace of the to_var before + for v in to_var: + v.tag.trace = getattr(v.tag, "trace", []) + tr + else: + # Copy over stack traces from from_var to each variable to + # to_var, including the stack_trace of the to_var before + to_var.tag.trace = getattr(to_var.tag, "trace", []) + tr + return to_var + + +def check_stack_trace(f_or_fgraph, ops_to_check="last", bug_print="raise"): + r"""Checks if the outputs of specific `Op`\s have a stack trace. + + Parameters + ---------- + f_or_fgraph : Function or FunctionGraph + The compiled function or the function graph to be analysed. + ops_to_check + This value can be of four different types: + - classes or instances inheriting from `Op` + - tuple/list of classes or instances inheriting from `Op` + - string + - function returning a boolean and taking as input an instance of `Op` + + - if `ops_to_check` is a string, it should be either ``'last'`` or ``'all'``. + ``'last'`` will check only the last `Op` of the graph while ``'all'`` will + check all the `Op`\s of the graph. + - if `ops_to_check` is an `Op` or a tuple/list of `Op`\s, the function will + check that all the outputs of their occurrences in the graph have a + stack trace. + - if `ops_to_check` is a function, it should take as input a + `Op` and return a boolean indicating if the input `Op` should + be checked or not. + + bug_print + This value is a string belonging to ``{'raise', 'warn', 'ignore'}``. + You can specify the behaviour of the function when the specified + `ops_to_check` are not in the graph of `f_or_fgraph`: it can either raise + an exception, write a warning or simply ignore it. + + Returns + ------- + boolean + ``True`` if the outputs of the specified ops have a stack, ``False`` + otherwise. + + """ + if isinstance(f_or_fgraph, aesara.compile.function.types.Function): + fgraph = f_or_fgraph.maker.fgraph + elif isinstance(f_or_fgraph, aesara.graph.fg.FunctionGraph): + fgraph = f_or_fgraph + else: + raise ValueError("The type of f_or_fgraph is not supported") + + if isinstance(ops_to_check, Op) or ( + inspect.isclass(ops_to_check) and issubclass(ops_to_check, Op) + ): + ops_to_check = (ops_to_check,) + + # if ops_to_check is a string + if isinstance(ops_to_check, str): + if ops_to_check == "last": + apply_nodes_to_check = [ + fgraph.outputs[i].owner for i in range(len(fgraph.outputs)) + ] + elif ops_to_check == "all": + apply_nodes_to_check = fgraph.apply_nodes + else: + raise ValueError("The string ops_to_check is not recognised") + + # if ops_to_check is a list/tuple of ops + elif isinstance(ops_to_check, (tuple, list)): + # Separate classes from instances in ops_to_check + op_instances = [] + op_classes = [] + for obj in ops_to_check: + if isinstance(obj, Op): + op_instances.append(obj) + else: + op_classes.append(obj) + op_classes = tuple(op_classes) + + apply_nodes_to_check = [ + node for node in fgraph.apply_nodes if node.op in ops_to_check + ] + [ + node + for node in fgraph.apply_nodes + if isinstance(node.op, op_classes) + or ( + hasattr(node.op, "scalar_op") + and isinstance(node.op.scalar_op, op_classes) + ) + ] + + # if ops_to_check is a function + elif callable(ops_to_check): + apply_nodes_to_check = [ + node for node in fgraph.apply_nodes if ops_to_check(node) + ] + + else: + raise ValueError("ops_to_check does not have the right type") + + if not apply_nodes_to_check: + msg = ( + "Provided op instances/classes are not in the graph or the " + "graph is empty" + ) + if bug_print == "warn": + warnings.warn(msg) + elif bug_print == "raise": + raise Exception(msg) + elif bug_print == "ignore": + pass + else: + raise ValueError("The string bug_print is not recognised") + + for node in apply_nodes_to_check: + for output in node.outputs: + if not hasattr(output.tag, "trace") or not output.tag.trace: + return False + + return True + + +class CheckStackTraceFeature(Feature): + def on_import(self, fgraph, node, reason): + # In `optdb` we only register the `CheckStackTraceRewriter` when + # `config.check_stack_trace` is not off, but we also double check here. + if config.check_stack_trace != "off" and not check_stack_trace(fgraph, "all"): + if config.check_stack_trace == "raise": + raise AssertionError( + f"Empty stack trace. The rewrite that inserted this variable is {reason}." + ) + elif config.check_stack_trace in ("log", "warn"): + apply_nodes_to_check = fgraph.apply_nodes + for node in apply_nodes_to_check: + for output in node.outputs: + if not hasattr(output.tag, "trace") or not output.tag.trace: + output.tag.trace = [ + [ + ( + "", + 0, + f"Empty stack trace. The rewrite that inserted this variable is {reason}.", + "", + ) + ] + ] + if config.check_stack_trace == "warn": + warnings.warn( + f"Empty stack trace. The rewrite that inserted this variable is {reason}." + ) + + +class CheckStackTraceRewriter(GraphRewriter): + """Rewriter that serves to add `CheckStackTraceRewriter` as a feature.""" + + def add_requirements(self, fgraph): + if not hasattr(fgraph, "CheckStackTraceFeature"): + fgraph.attach_feature(CheckStackTraceFeature()) + + def apply(self, fgraph): + pass + + +DEPRECATED_NAMES = [ + ( + "LocalMetaOptimizerSkipAssertionError", + "`LocalMetaOptimizerSkipAssertionError` is deprecated: use `MetaNodeRewriterSkip` instead.", + MetaNodeRewriterSkip, + ), + ( + "GlobalOptimizer", + "`GlobalOptimizer` is deprecated: use `GraphRewriter` instead.", + GraphRewriter, + ), + ( + "LocalOptimizer", + "`LocalOptimizer` is deprecated: use `NodeRewriter` instead.", + NodeRewriter, + ), + ( + "local_optimizer", + "`local_optimizer` is deprecated: use `node_rewriter` instead.", + node_rewriter, + ), + ( + "pre_greedy_local_optimizer", + "`pre_greedy_local_optimizer` is deprecated: use `pre_greedy_node_rewriter` instead.", + pre_greedy_node_rewriter, + ), + ( + "FromFunctionOptimizer", + "`FromFunctionOptimizer` is deprecated: use `FromFunctionGraphRewriter` instead.", + FromFunctionGraphRewriter, + ), + ( + "optimizer", + "`optimizer` is deprecated: use `graph_rewriter` instead.", + graph_rewriter, + ), + ( + "inplace_optimizer", + "`inplace_optimizer` is deprecated: use `graph_rewriter` instead.", + graph_rewriter, + ), + ( + "LocalMetaOptimizer", + "`LocalMetaOptimizer` is deprecated: use `MetaNodeRewriter` instead.", + MetaNodeRewriter, + ), + ( + "SeqOptimizer", + "`SeqOptimizer` is deprecated: use `SequentialGraphRewriter` instead.", + SequentialGraphRewriter, + ), + ( + "FromFunctionLocalOptimizer", + "`FromFunctionLocalOptimizer` is deprecated: use `FromFunctionNodeRewriter` instead.", + FromFunctionNodeRewriter, + ), + ( + "LocalOptTracker", + "`LocalOptTracker` is deprecated: use `OpToRewriterTracker` instead.", + OpToRewriterTracker, + ), + ( + "LocalOptGroup", + "`LocalOptGroup` is deprecated: use `SequentialNodeRewriter` instead.", + SequentialNodeRewriter, + ), + ( + "OpSub", + "`OpSub` is deprecated: use `SubstitutionNodeRewriter` instead.", + SubstitutionNodeRewriter, + ), + ( + "OpRemove", + "`OpRemove` is deprecated: use `RemovalNodeRewriter` instead.", + RemovalNodeRewriter, + ), + ( + "PatternSub", + "`PatternSub` is deprecated: use `PatternNodeRewriter` instead.", + PatternNodeRewriter, + ), + ( + "NavigatorOptimizer", + "`NavigatorOptimizer` is deprecated: use `NodeProcessingGraphRewriter` instead.", + NodeProcessingGraphRewriter, + ), + ( + "TopoOptimizer", + "`TopoOptimizer` is deprecated: use `WalkingGraphRewriter` instead.", + WalkingGraphRewriter, + ), + ( + "topogroup_optimizer", + "`topogroup_optimizer` is deprecated: use `walking_rewriter` instead.", + walking_rewriter, + ), + ( + "OpKeyOptimizer", + "`OpKeyOptimizer` is deprecated: use `OpKeyGraphRewriter` instead.", + OpKeyGraphRewriter, + ), + ( + "EquilibriumOptimizer", + "`EquilibriumOptimizer` is deprecated: use `EquilibriumGraphRewriter` instead.", + EquilibriumGraphRewriter, + ), +] + + +def __getattr__(name): + """Intercept module-level attribute access of deprecated symbols. + + Adapted from https://stackoverflow.com/a/55139609/3006474. + + """ + from warnings import warn + + for old_name, msg, old_object in DEPRECATED_NAMES: + if name == old_name: + warn(msg, DeprecationWarning, stacklevel=2) + return old_object + + raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/aesara/graph/rewriting/db.py b/aesara/graph/rewriting/db.py new file mode 100644 index 0000000000..ad9d4f5863 --- /dev/null +++ b/aesara/graph/rewriting/db.py @@ -0,0 +1,578 @@ +import copy +import math +import sys +from functools import cmp_to_key +from io import StringIO +from typing import Dict, Iterable, Optional, Sequence, Tuple, Union + +from aesara.configdefaults import config +from aesara.graph.rewriting import basic as aesara_rewriting +from aesara.misc.ordered_set import OrderedSet +from aesara.utils import DefaultOrderedDict + + +RewritesType = Union[aesara_rewriting.GraphRewriter, aesara_rewriting.NodeRewriter] + + +class RewriteDatabase: + r"""A class that represents a collection/database of rewrites. + + These databases are used to logically organize collections of rewrites + (i.e. `GraphRewriter`\s and `NodeRewriter`). + """ + + def __init__(self): + self.__db__ = DefaultOrderedDict(OrderedSet) + self._names = set() + # This will be reset by `self.register` (via `obj.name` by the thing + # doing the registering) + self.name = None + + def register( + self, + name: str, + rewriter: Union["RewriteDatabase", RewritesType], + *tags: str, + use_db_name_as_tag=True, + ): + """Register a new rewriter to the database. + + Parameters + ---------- + name: + Name of the rewriter. + rewriter: + The rewriter to register. + tags: + Tag name that allows one to select the rewrite using a + `RewriteDatabaseQuery`. + use_db_name_as_tag: + Add the database's name as a tag, so that its name can be used in a + query. + By default, all rewrites registered to an `EquilibriumDB` are + selected when the ``"EquilibriumDB"`` name is used as a tag. We do + not want this behavior for some rewrites like + ``local_remove_all_assert``. Setting `use_db_name_as_tag` to + ``False`` removes that behavior. This means that only the rewrite's name + and/or its tags will enable it. + + """ + if not isinstance( + rewriter, + ( + RewriteDatabase, + aesara_rewriting.GraphRewriter, + aesara_rewriting.NodeRewriter, + ), + ): + raise TypeError(f"{rewriter} is not a valid rewrite type.") + + if name in self.__db__: + raise ValueError(f"The tag '{name}' is already present in the database.") + + if use_db_name_as_tag: + if self.name is not None: + tags = tags + (self.name,) + + rewriter.name = name + # This restriction is there because in many place we suppose that + # something in the RewriteDatabase is there only once. + if rewriter.name in self.__db__: + raise ValueError( + f"Tried to register {rewriter.name} again under the new name {name}. " + "The same rewrite cannot be registered multiple times in" + " an `RewriteDatabase`; use `ProxyDB` instead." + ) + self.__db__[name] = OrderedSet([rewriter]) + self._names.add(name) + self.__db__[rewriter.__class__.__name__].add(rewriter) + self.add_tags(name, *tags) + + def add_tags(self, name, *tags): + obj = self.__db__[name] + assert len(obj) == 1 + obj = obj.copy().pop() + for tag in tags: + if tag in self._names: + raise ValueError( + f"The tag '{tag}' for the {obj} collides with an existing name." + ) + self.__db__[tag].add(obj) + + def remove_tags(self, name, *tags): + obj = self.__db__[name] + assert len(obj) == 1 + obj = obj.copy().pop() + for tag in tags: + if tag in self._names: + raise ValueError( + f"The tag '{tag}' for the {obj} collides with an existing name." + ) + self.__db__[tag].remove(obj) + + def __query__(self, q): + # The ordered set is needed for deterministic rewriting. + variables = OrderedSet() + for tag in q.include: + variables.update(self.__db__[tag]) + for tag in q.require: + variables.intersection_update(self.__db__[tag]) + for tag in q.exclude: + variables.difference_update(self.__db__[tag]) + remove = OrderedSet() + add = OrderedSet() + for obj in variables: + if isinstance(obj, RewriteDatabase): + def_sub_query = q + if q.extra_rewrites: + def_sub_query = copy.copy(q) + def_sub_query.extra_rewrites = [] + sq = q.subquery.get(obj.name, def_sub_query) + + replacement = obj.query(sq) + replacement.name = obj.name + remove.add(obj) + add.add(replacement) + variables.difference_update(remove) + variables.update(add) + return variables + + def query(self, *tags, **kwtags): + if len(tags) >= 1 and isinstance(tags[0], RewriteDatabaseQuery): + if len(tags) > 1 or kwtags: + raise TypeError( + "If the first argument to query is an `RewriteDatabaseQuery`," + " there should be no other arguments." + ) + return self.__query__(tags[0]) + include = [tag[1:] for tag in tags if tag.startswith("+")] + require = [tag[1:] for tag in tags if tag.startswith("&")] + exclude = [tag[1:] for tag in tags if tag.startswith("-")] + if len(include) + len(require) + len(exclude) < len(tags): + raise ValueError( + "All tags must start with one of the following" + " characters: '+', '&' or '-'" + ) + return self.__query__( + RewriteDatabaseQuery( + include=include, require=require, exclude=exclude, subquery=kwtags + ) + ) + + def __getitem__(self, name): + variables = self.__db__[name] + if not variables: + raise KeyError(f"Nothing registered for '{name}'") + elif len(variables) > 1: + raise ValueError(f"More than one match for {name} (please use query)") + for variable in variables: + return variable + + def __contains__(self, name): + return name in self.__db__ + + def print_summary(self, stream=sys.stdout): + print(f"{self.__class__.__name__} (id {id(self)})", file=stream) + print(" names", self._names, file=stream) + print(" db", self.__db__, file=stream) + + +class RewriteDatabaseQuery: + """An object that specifies a set of rewrites by tag/name.""" + + def __init__( + self, + include: Iterable[Union[str, None]], + require: Optional[Union[OrderedSet, Sequence[str]]] = None, + exclude: Optional[Union[OrderedSet, Sequence[str]]] = None, + subquery: Optional[Dict[str, "RewriteDatabaseQuery"]] = None, + position_cutoff: float = math.inf, + extra_rewrites: Optional[ + Sequence[ + Tuple[Union["RewriteDatabaseQuery", RewritesType], Union[int, float]] + ] + ] = None, + ): + """ + + Parameters + ========== + include: + A set of tags such that every rewirte obtained through this + `RewriteDatabaseQuery` must have **one** of the tags listed. This + field is required and basically acts as a starting point for the + search. + require: + A set of tags such that every rewrite obtained through this + `RewriteDatabaseQuery` must have **all** of these tags. + exclude: + A set of tags such that every rewrite obtained through this + ``RewriteDatabaseQuery` must have **none** of these tags. + subquery: + A dictionary mapping the name of a sub-database to a special + `RewriteDatabaseQuery`. If no subquery is given for a sub-database, + the original `RewriteDatabaseQuery` will be used again. + position_cutoff: + Only rewrites with position less than the cutoff are returned. + extra_rewrites: + Extra rewrites to be added. + + """ + self.include = OrderedSet(include) + self.require = OrderedSet(require) if require else OrderedSet() + self.exclude = OrderedSet(exclude) if exclude else OrderedSet() + self.subquery = subquery or {} + self.position_cutoff = position_cutoff + self.name: Optional[str] = None + if extra_rewrites is None: + extra_rewrites = [] + self.extra_rewrites = list(extra_rewrites) + + def __str__(self): + return ( + "RewriteDatabaseQuery(" + + f"inc={self.include},ex={self.exclude}," + + f"require={self.require},subquery={self.subquery}," + + f"position_cutoff={self.position_cutoff}," + + f"extra_rewrites={self.extra_rewrites})" + ) + + def __setstate__(self, state): + self.__dict__.update(state) + if not hasattr(self, "extra_rewrites"): + self.extra_rewrites = [] + + def including(self, *tags: str) -> "RewriteDatabaseQuery": + """Add rewrites with the given tags.""" + return RewriteDatabaseQuery( + self.include.union(tags), + self.require, + self.exclude, + self.subquery, + self.position_cutoff, + self.extra_rewrites, + ) + + def excluding(self, *tags: str) -> "RewriteDatabaseQuery": + """Remove rewrites with the given tags.""" + return RewriteDatabaseQuery( + self.include, + self.require, + self.exclude.union(tags), + self.subquery, + self.position_cutoff, + self.extra_rewrites, + ) + + def requiring(self, *tags: str) -> "RewriteDatabaseQuery": + """Filter for rewrites with the given tags.""" + return RewriteDatabaseQuery( + self.include, + self.require.union(tags), + self.exclude, + self.subquery, + self.position_cutoff, + self.extra_rewrites, + ) + + def register( + self, *rewrites: Tuple["RewriteDatabaseQuery", Union[int, float]] + ) -> "RewriteDatabaseQuery": + """Include the given rewrites.""" + return RewriteDatabaseQuery( + self.include, + self.require, + self.exclude, + self.subquery, + self.position_cutoff, + self.extra_rewrites + list(rewrites), + ) + + +class EquilibriumDB(RewriteDatabase): + """A database of rewrites that should be applied until equilibrium is reached. + + Canonicalize, Stabilize, and Specialize are all equilibrium rewriters. + + Notes + ----- + We can use `NodeRewriter` and `GraphRewriter` since `EquilibriumGraphRewriter` + supports both. + + It is probably not a good idea to have both ``ignore_newtrees == False`` + and ``tracks_on_change_inputs == True``. + + """ + + def __init__( + self, ignore_newtrees: bool = True, tracks_on_change_inputs: bool = False + ): + """ + + Parameters + ---------- + ignore_newtrees + If ``False``, apply rewrites to new nodes introduced during + rewriting. + + tracks_on_change_inputs + If ``True``, re-apply rewrites on nodes with changed inputs. + + """ + super().__init__() + self.ignore_newtrees = ignore_newtrees + self.tracks_on_change_inputs = tracks_on_change_inputs + self.__final__: Dict[str, bool] = {} + self.__cleanup__: Dict[str, bool] = {} + + def register( + self, + name: str, + rewriter: Union["RewriteDatabase", RewritesType], + *tags: str, + final_rewriter: bool = False, + cleanup: bool = False, + **kwargs, + ): + if final_rewriter and cleanup: + raise ValueError("`final_rewriter` and `cleanup` cannot both be true.") + super().register(name, rewriter, *tags, **kwargs) + self.__final__[name] = final_rewriter + self.__cleanup__[name] = cleanup + + def query(self, *tags, **kwtags): + _rewriters = super().query(*tags, **kwtags) + final_rewriters = [o for o in _rewriters if self.__final__.get(o.name, False)] + cleanup_rewriters = [ + o for o in _rewriters if self.__cleanup__.get(o.name, False) + ] + rewriters = [ + o + for o in _rewriters + if o not in final_rewriters and o not in cleanup_rewriters + ] + if len(final_rewriters) == 0: + final_rewriters = None + if len(cleanup_rewriters) == 0: + cleanup_rewriters = None + return aesara_rewriting.EquilibriumGraphRewriter( + rewriters, + max_use_ratio=config.optdb__max_use_ratio, + ignore_newtrees=self.ignore_newtrees, + tracks_on_change_inputs=self.tracks_on_change_inputs, + failure_callback=aesara_rewriting.NodeProcessingGraphRewriter.warn_inplace, + final_rewriters=final_rewriters, + cleanup_rewriters=cleanup_rewriters, + ) + + +class SequenceDB(RewriteDatabase): + """A sequence of potential rewrites. + + Retrieve a sequence of rewrites as a `SequentialGraphRewriter` by calling + `SequenceDB.query`. + + Each potential rewrite is registered with a floating-point position. + No matter which rewrites are selected by a query, they are carried + out in order of increasing position. + + """ + + seq_rewriter_type = aesara_rewriting.SequentialGraphRewriter + + def __init__(self, failure_callback=aesara_rewriting.SequentialGraphRewriter.warn): + super().__init__() + self.__position__ = {} + self.failure_callback = failure_callback + + def register(self, name, obj, *tags, **kwargs): + position = kwargs.pop("position", "last") + + super().register(name, obj, *tags, **kwargs) + + if position == "last": + if len(self.__position__) == 0: + self.__position__[name] = 0 + else: + self.__position__[name] = max(self.__position__.values()) + 1 + elif isinstance(position, (int, float)): + self.__position__[name] = position + else: + raise TypeError(f"`position` must be numeric; got {position}") + + def query( + self, *tags, position_cutoff: Optional[Union[int, float]] = None, **kwtags + ): + """ + + Parameters + ---------- + position_cutoff : float or int + Only rewrites with position less than the cutoff are returned. + + """ + rewrites = super().query(*tags, **kwtags) + + if position_cutoff is None: + position_cutoff = config.optdb__position_cutoff + + position_dict = self.__position__ + + if len(tags) >= 1 and isinstance(tags[0], RewriteDatabaseQuery): + # the call to super should have raise an error with a good message + assert len(tags) == 1 + if getattr(tags[0], "position_cutoff", None): + position_cutoff = tags[0].position_cutoff + + # The RewriteDatabaseQuery instance might contain extra rewrites which need + # to be added the the sequence of rewrites (don't alter the + # original dictionary) + if len(tags[0].extra_rewrites) > 0: + position_dict = position_dict.copy() + for extra_rewrite in tags[0].extra_rewrites: + # Give a name to the extra rewrites (include both the + # class name for descriptiveness and id to avoid name + # collisions) + rewrite, position = extra_rewrite + rewrite.name = f"{rewrite.__class__}_{id(rewrite)}" + + if position < position_cutoff: + rewrites.add(rewrite) + position_dict[rewrite.name] = position + + rewrites = [o for o in rewrites if position_dict[o.name] < position_cutoff] + rewrites.sort(key=lambda obj: (position_dict[obj.name], obj.name)) + + if self.failure_callback: + ret = self.seq_rewriter_type( + rewrites, failure_callback=self.failure_callback + ) + else: + ret = self.seq_rewriter_type(rewrites) + + if hasattr(tags[0], "name"): + ret.name = tags[0].name + return ret + + def print_summary(self, stream=sys.stdout): + print(f"{self.__class__.__name__ } (id {id(self)})", file=stream) + positions = list(self.__position__.items()) + + def c(a, b): + return (a[1] > b[1]) - (a[1] < b[1]) + + positions.sort(key=cmp_to_key(c)) + + print("\tposition", positions, file=stream) + print("\tnames", self._names, file=stream) + print("\tdb", self.__db__, file=stream) + + def __str__(self): + sio = StringIO() + self.print_summary(sio) + return sio.getvalue() + + +class LocalGroupDB(SequenceDB): + r"""A database that generates `NodeRewriter`\s of type `SequentialNodeRewriter`.""" + + def __init__( + self, + apply_all_rewrites: bool = False, + profile: bool = False, + node_rewriter=aesara_rewriting.SequentialNodeRewriter, + ): + super().__init__(failure_callback=None) + self.apply_all_rewrites = apply_all_rewrites + self.profile = profile + self.node_rewriter = node_rewriter + self.__name__: str = "" + + def register(self, name, obj, *tags, position="last", **kwargs): + super().register(name, obj, *tags, position=position, **kwargs) + + def query(self, *tags, **kwtags): + rewrites = list(super().query(*tags, **kwtags)) + ret = self.node_rewriter( + *rewrites, apply_all_rewrites=self.apply_all_rewrites, profile=self.profile + ) + return ret + + +class TopoDB(RewriteDatabase): + """Generate a `GraphRewriter` of type `WalkingGraphRewriter`.""" + + def __init__( + self, db, order="in_to_out", ignore_newtrees=False, failure_callback=None + ): + super().__init__() + self.db = db + self.order = order + self.ignore_newtrees = ignore_newtrees + self.failure_callback = failure_callback + + def query(self, *tags, **kwtags): + return aesara_rewriting.WalkingGraphRewriter( + self.db.query(*tags, **kwtags), + self.order, + self.ignore_newtrees, + self.failure_callback, + ) + + +class ProxyDB(RewriteDatabase): + """A object that wraps an existing ``RewriteDatabase``. + + This is needed because we can't register the same ``RewriteDatabase`` + multiple times in different positions in a ``SequentialDB``. + + """ + + def __init__(self, db): + if not isinstance(db, RewriteDatabase): + raise TypeError("`db` must be an `RewriteDatabase`.") + + self.db = db + + def query(self, *tags, **kwtags): + return self.db.query(*tags, **kwtags) + + +DEPRECATED_NAMES = [ + ( + "DB", + "`DB` is deprecated; use `RewriteDatabase` instead.", + RewriteDatabase, + ), + ( + "Query", + "`Query` is deprecated; use `RewriteDatabaseQuery` instead.", + RewriteDatabaseQuery, + ), + ( + "OptimizationDatabase", + "`OptimizationDatabase` is deprecated; use `RewriteDatabase` instead.", + RewriteDatabase, + ), + ( + "OptimizationQuery", + "`OptimizationQuery` is deprecated; use `RewriteDatabaseQuery` instead.", + RewriteDatabaseQuery, + ), +] + + +def __getattr__(name): + """Intercept module-level attribute access of deprecated symbols. + + Adapted from https://stackoverflow.com/a/55139609/3006474. + + """ + from warnings import warn + + for old_name, msg, old_object in DEPRECATED_NAMES: + if name == old_name: + warn(msg, DeprecationWarning, stacklevel=2) + return old_object + + raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/aesara/graph/rewriting/kanren.py b/aesara/graph/rewriting/kanren.py new file mode 100644 index 0000000000..212d86d02d --- /dev/null +++ b/aesara/graph/rewriting/kanren.py @@ -0,0 +1,100 @@ +from typing import Callable, Iterator, List, Optional, Union + +from etuples.core import ExpressionTuple +from kanren import run +from unification import var +from unification.variable import Var + +from aesara.graph.basic import Apply, Variable +from aesara.graph.rewriting.basic import NodeRewriter +from aesara.graph.rewriting.unify import eval_if_etuple + + +class KanrenRelationSub(NodeRewriter): + r"""A rewriter that uses `kanren` to match and replace terms. + + See `kanren `__ for more information + miniKanren and the API for constructing `kanren` goals. + + Example + ------- + + ..code-block:: python + + from kanren import eq, conso, var + + import aesara.tensor as at + from aesara.graph.rewriting.kanren import KanrenRelationSub + + + def relation(in_lv, out_lv): + # A `kanren` goal that changes `at.log` terms to `at.exp` + cdr_lv = var() + return eq(conso(at.log, cdr_lv, in_lv), + conso(at.exp, cdr_lv, out_lv)) + + + kanren_sub_opt = KanrenRelationSub(relation) + + """ + + reentrant = True + + def __init__( + self, + kanren_relation: Callable[[Variable, Var], Callable], + results_filter: Optional[ + Callable[[Iterator], Optional[List[Union[ExpressionTuple, Variable]]]] + ] = None, + node_filter: Callable[[Apply], bool] = lambda x: True, + ): + r"""Create a `KanrenRelationSub`. + + Parameters + ---------- + kanren_relation + A function that takes an input graph and an output logic variable and + returns a `kanren` goal. + results_filter + A function that takes the direct output of ``kanren.run(None, ...)`` + and returns a single result. The default implementation returns + the first result. + node_filter + A function taking a single node and returns ``True`` when the node + should be processed. + """ + if results_filter is None: + + def results_filter( + x: Iterator, + ) -> Optional[List[Union[ExpressionTuple, Variable]]]: + return next(x, None) + + self.kanren_relation = kanren_relation + self.results_filter = results_filter + self.node_filter = node_filter + super().__init__() + + def transform(self, fgraph, node): + if self.node_filter(node) is False: + return False + + try: + input_expr = node.default_output() + except ValueError: + input_expr = node.outputs + + q = var() + kanren_results = run(None, q, self.kanren_relation(input_expr, q)) + + chosen_res = self.results_filter(kanren_results) + + if chosen_res: + if isinstance(chosen_res, list): + new_outputs = [eval_if_etuple(v) for v in chosen_res] + else: + new_outputs = [eval_if_etuple(chosen_res)] + + return new_outputs + else: + return False diff --git a/aesara/graph/rewriting/unify.py b/aesara/graph/rewriting/unify.py new file mode 100644 index 0000000000..4a56212fc8 --- /dev/null +++ b/aesara/graph/rewriting/unify.py @@ -0,0 +1,293 @@ +""" +If you have two expressions containing unification variables, these expressions +can be "unified" if there exists an assignment to all unification variables +such that the two expressions are equal. + +For instance, [5, A, B] and [A, C, 9] can be unified if A=C=5 and B=9, +yielding [5, 5, 9]. +[5, [A, B]] and [A, [1, 2]] cannot be unified because there is no value for A +that satisfies the constraints. That's useful for pattern matching. + +""" + +from collections.abc import Mapping +from numbers import Number +from typing import Dict, Optional, Tuple, Union + +import numpy as np +from cons.core import ConsError, _car, _cdr +from etuples import apply, etuple, etuplize +from etuples.core import ExpressionTuple +from unification.core import _unify, assoc +from unification.utils import transitive_get as walk +from unification.variable import Var, isvar, var + +from aesara.graph.basic import Constant, Variable +from aesara.graph.op import Op +from aesara.graph.type import Type + + +def eval_if_etuple(x): + if isinstance(x, ExpressionTuple): + return x.evaled_obj + return x + + +class ConstrainedVar(Var): + """A logical variable with a constraint. + + These will unify with other `Var`s regardless of the constraints. + """ + + __slots__ = ("constraint",) + + def __new__(cls, constraint, token=None, prefix=""): + if token is None: + token = f"{prefix}_{Var._id}" + Var._id += 1 + + key = (token, constraint) + obj = cls._refs.get(key, None) + + if obj is None: + obj = object.__new__(cls) + obj.token = token + obj.constraint = constraint + cls._refs[key] = obj + + return obj + + def __eq__(self, other): + if type(self) == type(other): + return self.token == other.token and self.constraint == other.constraint + return NotImplemented + + def __hash__(self): + return hash((type(self), self.token, self.constraint)) + + def __str__(self): + return f"~{self.token} [{self.constraint}]" + + def __repr__(self): + return f"{type(self).__name__}({repr(self.constraint)}, {self.token})" + + +def car_Variable(x): + if x.owner: + return x.owner.op + else: + raise ConsError("Not a cons pair.") + + +_car.add((Variable,), car_Variable) + + +def cdr_Variable(x): + if x.owner: + x_e = etuple(_car(x), *x.owner.inputs, evaled_obj=x) + else: + raise ConsError("Not a cons pair.") + + return x_e[1:] + + +_cdr.add((Variable,), cdr_Variable) + + +def car_Op(x): + if hasattr(x, "__props__"): + return type(x) + + raise ConsError("Not a cons pair.") + + +_car.add((Op,), car_Op) + + +def cdr_Op(x): + if not hasattr(x, "__props__"): + raise ConsError("Not a cons pair.") + + x_e = etuple( + _car(x), + *[getattr(x, p) for p in getattr(x, "__props__", ())], + evaled_obj=x, + ) + return x_e[1:] + + +_cdr.add((Op,), cdr_Op) + + +def car_Type(x): + return type(x) + + +_car.add((Type,), car_Type) + + +def cdr_Type(x): + x_e = etuple( + _car(x), *[getattr(x, p) for p in getattr(x, "__props__", ())], evaled_obj=x + ) + return x_e[1:] + + +_cdr.add((Type,), cdr_Type) + + +def apply_Op_ExpressionTuple(op, etuple_arg): + res = op.make_node(*etuple_arg) + + try: + return res.default_output() + except ValueError: + return res.outputs + + +apply.add((Op, ExpressionTuple), apply_Op_ExpressionTuple) + + +def _unify_etuplize_first_arg(u, v, s): + try: + u_et = etuplize(u, shallow=True) + yield _unify(u_et, v, s) + except TypeError: + yield False + return + + +_unify.add((Op, ExpressionTuple, Mapping), _unify_etuplize_first_arg) +_unify.add( + (ExpressionTuple, Op, Mapping), lambda u, v, s: _unify_etuplize_first_arg(v, u, s) +) + +_unify.add((Type, ExpressionTuple, Mapping), _unify_etuplize_first_arg) +_unify.add( + (ExpressionTuple, Type, Mapping), lambda u, v, s: _unify_etuplize_first_arg(v, u, s) +) + + +def _unify_Variable_Variable(u, v, s): + # Avoid converting to `etuple`s, when possible + if u == v: + yield s + return + + if not u.owner and not v.owner: + yield False + return + + yield _unify( + etuplize(u, shallow=True) if u.owner else u, + etuplize(v, shallow=True) if v.owner else v, + s, + ) + + +_unify.add((Variable, Variable, Mapping), _unify_Variable_Variable) + + +def _unify_Constant_Constant(u, v, s): + # XXX: This ignores shape and type differences. It's only implemented this + # way for backward compatibility + if np.array_equiv(u.data, v.data): + yield s + else: + yield False + + +_unify.add((Constant, Constant, Mapping), _unify_Constant_Constant) + + +def _unify_Variable_ExpressionTuple(u, v, s): + # `Constant`s are "atomic" + if not u.owner: + yield False + return + + yield _unify(etuplize(u, shallow=True), v, s) + + +_unify.add( + (Variable, ExpressionTuple, Mapping), + _unify_Variable_ExpressionTuple, +) +_unify.add( + (ExpressionTuple, Variable, Mapping), + lambda u, v, s: _unify_Variable_ExpressionTuple(v, u, s), +) + + +@_unify.register(ConstrainedVar, (ConstrainedVar, Var, object), Mapping) +def _unify_ConstrainedVar_object(u, v, s): + u_w = walk(u, s) + + if isvar(v): + v_w = walk(v, s) + else: + v_w = v + + if u_w == v_w: + yield s + elif isvar(u_w): + if ( + not isvar(v_w) + and isinstance(u_w, ConstrainedVar) + and not u_w.constraint(eval_if_etuple(v_w)) + ): + yield False + return + yield assoc(s, u_w, v_w) + elif isvar(v_w): + if ( + not isvar(u_w) + and isinstance(v_w, ConstrainedVar) + and not v_w.constraint(eval_if_etuple(u_w)) + ): + yield False + return + yield assoc(s, v_w, u_w) + else: + yield _unify(u_w, v_w, s) + + +_unify.add((object, ConstrainedVar, Mapping), _unify_ConstrainedVar_object) + + +def convert_strs_to_vars( + x: Union[Tuple, str, Dict], var_map: Optional[Dict[str, Var]] = None +) -> Union[ExpressionTuple, Var]: + r"""Convert tuples and strings to `etuple`\s and logic variables, respectively. + + Constrained logic variables are specified via `dict`s with the keys + `"pattern"`, which specifies the logic variable as a string, and + `"constraint"`, which provides the `Callable` constraint. + """ + if var_map is None: + var_map = {} + + def _convert(y): + if isinstance(y, str): + v = var_map.get(y, var(y)) + var_map[y] = v + return v + elif isinstance(y, dict): + pattern = y["pattern"] + if not isinstance(pattern, str): + raise TypeError( + "Constraints can only be assigned to logic variables (i.e. strings)" + ) + constraint = y["constraint"] + v = var_map.get(pattern, ConstrainedVar(constraint, pattern)) + var_map[pattern] = v + return v + elif isinstance(y, tuple): + return etuple(*tuple(_convert(e) for e in y)) + elif isinstance(y, (Number, np.ndarray)): + from aesara.tensor import as_tensor_variable + + return as_tensor_variable(y) + return y + + return _convert(x) diff --git a/aesara/graph/rewriting/utils.py b/aesara/graph/rewriting/utils.py new file mode 100644 index 0000000000..536c45620b --- /dev/null +++ b/aesara/graph/rewriting/utils.py @@ -0,0 +1,275 @@ +import copy +import warnings +from typing import TYPE_CHECKING, Generator, Optional, Sequence, Union, cast + +import aesara +from aesara.graph.basic import ( + Apply, + Variable, + equal_computations, + graph_inputs, + vars_between, +) +from aesara.graph.fg import FunctionGraph +from aesara.graph.rewriting.db import RewriteDatabaseQuery + + +if TYPE_CHECKING: + from aesara.graph.rewriting.basic import GraphRewriter + + +def rewrite_graph( + graph: Union[Variable, Sequence[Variable], FunctionGraph], + include: Sequence[str] = ("canonicalize",), + custom_rewrite: Optional["GraphRewriter"] = None, + clone: bool = False, + custom_opt: Optional["GraphRewriter"] = None, + **kwargs, +) -> Union[Variable, Sequence[Variable], FunctionGraph]: + """Easily apply rewrites to a graph. + + Parameters + ---------- + graph + A `FunctionGraph` or `Variable` to be rewritten. + include + String names of the rewrites to be queried, via a + `RewriteDatabaseQuery` instance, and applied. The default rewrite + query string is ``"canonicalization"``. + custom_rewrite + A custom `Rewriter` to also be applied. + clone + Whether or not to clone the input graph before rewriting. + **kwargs + Keyword arguments passed to a `RewriteDatabaseQuery` object. + """ + from aesara.compile import optdb + + return_fgraph = False + if isinstance(graph, FunctionGraph): + outputs: Sequence[Variable] = graph.outputs + fgraph = graph + return_fgraph = True + else: + if isinstance(graph, (list, tuple)): + outputs = graph + else: + assert isinstance(graph, Variable) + outputs = [graph] + + fgraph = FunctionGraph(outputs=outputs, clone=clone) + + query_rewrites = optdb.query(RewriteDatabaseQuery(include=include, **kwargs)) + _ = query_rewrites.rewrite(fgraph) + + if custom_opt is not None: + warnings.warn( + "`custom_opt` is deprecated; use `custom_rewrite` instead.", + DeprecationWarning, + stacklevel=2, + ) + custom_rewrite = custom_opt + + if custom_rewrite: + custom_rewrite.rewrite(fgraph) + + if return_fgraph: + return fgraph + else: + if isinstance(graph, (list, tuple)): + return fgraph.outputs + else: + return fgraph.outputs[0] + + +def is_same_graph_with_merge(var1, var2, givens=None): + """ + Merge-based implementation of `aesara.graph.basic.is_same_graph`. + + See help on `aesara.graph.basic.is_same_graph` for additional documentation. + + """ + from aesara.graph.rewriting.basic import MergeOptimizer + + if givens is None: + givens = {} + # Copy variables since the MergeOptimizer will modify them. + copied = copy.deepcopy([var1, var2, givens]) + vars = copied[0:2] + givens = copied[2] + # Create FunctionGraph. + inputs = list(graph_inputs(vars)) + # The clone isn't needed as we did a deepcopy and we cloning will + # break the mapping in givens. + fgraph = aesara.graph.fg.FunctionGraph(inputs, vars, clone=False) + # Perform Variable substitution. + for to_replace, replace_by in givens.items(): + fgraph.replace(to_replace, replace_by) + # Perform merge optimization. + MergeOptimizer().rewrite(fgraph) + # When two variables perform the same computations, they will have the same + # owner in the rewritten graph. + # We need to be careful with the special case where the owner is None, + # which happens when the graph is made of a single Variable. + # We also need to make sure we replace a Variable if it is present in + # `givens`. + vars_replaced = [givens.get(v, v) for v in fgraph.outputs] + o1, o2 = [v.owner for v in vars_replaced] + if o1 is None and o2 is None: + # Comparing two single-Variable graphs: they are equal if they are + # the same Variable. + return vars_replaced[0] == vars_replaced[1] + else: + return o1 is o2 + + +def is_same_graph(var1, var2, givens=None): + """ + Return True iff Variables `var1` and `var2` perform the same computation. + + By 'performing the same computation', we mean that they must share the same + graph, so that for instance this function will return False when comparing + (x * (y * z)) with ((x * y) * z). + + The current implementation is not efficient since, when possible, it + verifies equality by calling two different functions that are expected to + return the same output. The goal is to verify this assumption, to + eventually get rid of one of them in the future. + + Parameters + ---------- + var1 + The first Variable to compare. + var2 + The second Variable to compare. + givens + Similar to the `givens` argument of `aesara.function`, it can be used + to perform substitutions in the computational graph of `var1` and + `var2`. This argument is associated to neither `var1` nor `var2`: + substitutions may affect both graphs if the substituted variable + is present in both. + + Examples + -------- + + ====== ====== ====== ====== + var1 var2 givens output + ====== ====== ====== ====== + x + 1 x + 1 {} True + x + 1 y + 1 {} False + x + 1 y + 1 {x: y} True + ====== ====== ====== ====== + + """ + use_equal_computations = True + + if givens is None: + givens = {} + + if not isinstance(givens, dict): + givens = dict(givens) + + # Get result from the merge-based function. + rval1 = is_same_graph_with_merge(var1=var1, var2=var2, givens=givens) + + if givens: + # We need to build the `in_xs` and `in_ys` lists. To do this, we need + # to be able to tell whether a variable belongs to the computational + # graph of `var1` or `var2`. + # The typical case we want to handle is when `to_replace` belongs to + # one of these graphs, and `replace_by` belongs to the other one. In + # other situations, the current implementation of `equal_computations` + # is probably not appropriate, so we do not call it. + ok = True + in_xs = [] + in_ys = [] + # Compute the sets of all variables found in each computational graph. + inputs_var = list(map(graph_inputs, ([var1], [var2]))) + all_vars = [ + set(vars_between(v_i, v_o)) + for v_i, v_o in ((inputs_var[0], [var1]), (inputs_var[1], [var2])) + ] + + def in_var(x, k): + # Return True iff `x` is in computation graph of variable `vark`. + return x in all_vars[k - 1] + + for to_replace, replace_by in givens.items(): + # Map a substitution variable to the computational graphs it + # belongs to. + inside = { + v: [in_var(v, k) for k in (1, 2)] for v in (to_replace, replace_by) + } + if ( + inside[to_replace][0] + and not inside[to_replace][1] + and inside[replace_by][1] + and not inside[replace_by][0] + ): + # Substitute variable in `var1` by one from `var2`. + in_xs.append(to_replace) + in_ys.append(replace_by) + elif ( + inside[to_replace][1] + and not inside[to_replace][0] + and inside[replace_by][0] + and not inside[replace_by][1] + ): + # Substitute variable in `var2` by one from `var1`. + in_xs.append(replace_by) + in_ys.append(to_replace) + else: + ok = False + break + if not ok: + # We cannot directly use `equal_computations`. + use_equal_computations = False + else: + in_xs = None + in_ys = None + if use_equal_computations: + rval2 = equal_computations(xs=[var1], ys=[var2], in_xs=in_xs, in_ys=in_ys) + assert rval2 == rval1 + return rval1 + + +def get_clients_at_depth( + fgraph: FunctionGraph, node: Apply, depth: int +) -> Generator[Apply, None, None]: + """Yields node clients at given depth.""" + for var in node.outputs: + if depth > 0: + for out_node, _ in fgraph.clients[var]: + if out_node == "output": + continue + yield from get_clients_at_depth( + fgraph, cast(Apply, out_node), depth - 1 + ) + else: + assert var.owner is not None + yield var.owner + + +DEPRECATED_NAMES = [ + ( + "optimize_graph", + "`optimize_graph` is deprecated: use `rewrite_graph` instead.", + rewrite_graph, + ), +] + + +def __getattr__(name): + """Intercept module-level attribute access of deprecated symbols. + + Adapted from https://stackoverflow.com/a/55139609/3006474. + + """ + from warnings import warn + + for old_name, msg, old_object in DEPRECATED_NAMES: + if name == old_name: + warn(msg, DeprecationWarning, stacklevel=2) + return old_object + + raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/aesara/graph/unify.py b/aesara/graph/unify.py index ffa7502a18..72604dafdf 100644 --- a/aesara/graph/unify.py +++ b/aesara/graph/unify.py @@ -1,293 +1,10 @@ -""" -If you have two expressions containing unification variables, these expressions -can be "unified" if there exists an assignment to all unification variables -such that the two expressions are equal. +import warnings -For instance, [5, A, B] and [A, C, 9] can be unified if A=C=5 and B=9, -yielding [5, 5, 9]. -[5, [A, B]] and [A, [1, 2]] cannot be unified because there is no value for A -that satisfies the constraints. That's useful for pattern matching. -""" - -from collections.abc import Mapping -from numbers import Number -from typing import Dict, Optional, Tuple, Union - -import numpy as np -from cons.core import ConsError, _car, _cdr -from etuples import apply, etuple, etuplize -from etuples.core import ExpressionTuple -from unification.core import _unify, assoc -from unification.utils import transitive_get as walk -from unification.variable import Var, isvar, var - -from aesara.graph.basic import Constant, Variable -from aesara.graph.op import Op -from aesara.graph.type import Type - - -def eval_if_etuple(x): - if isinstance(x, ExpressionTuple): - return x.evaled_obj - return x - - -class ConstrainedVar(Var): - """A logical variable with a constraint. - - These will unify with other `Var`s regardless of the constraints. - """ - - __slots__ = ("constraint",) - - def __new__(cls, constraint, token=None, prefix=""): - if token is None: - token = f"{prefix}_{Var._id}" - Var._id += 1 - - key = (token, constraint) - obj = cls._refs.get(key, None) - - if obj is None: - obj = object.__new__(cls) - obj.token = token - obj.constraint = constraint - cls._refs[key] = obj - - return obj - - def __eq__(self, other): - if type(self) == type(other): - return self.token == other.token and self.constraint == other.constraint - return NotImplemented - - def __hash__(self): - return hash((type(self), self.token, self.constraint)) - - def __str__(self): - return f"~{self.token} [{self.constraint}]" - - def __repr__(self): - return f"ConstrainedVar({repr(self.constraint)}, {self.token})" - - -def car_Variable(x): - if x.owner: - return x.owner.op - else: - raise ConsError("Not a cons pair.") - - -_car.add((Variable,), car_Variable) - - -def cdr_Variable(x): - if x.owner: - x_e = etuple(_car(x), *x.owner.inputs, evaled_obj=x) - else: - raise ConsError("Not a cons pair.") - - return x_e[1:] - - -_cdr.add((Variable,), cdr_Variable) - - -def car_Op(x): - if hasattr(x, "__props__"): - return type(x) - - raise ConsError("Not a cons pair.") - - -_car.add((Op,), car_Op) - - -def cdr_Op(x): - if not hasattr(x, "__props__"): - raise ConsError("Not a cons pair.") - - x_e = etuple( - _car(x), - *[getattr(x, p) for p in getattr(x, "__props__", ())], - evaled_obj=x, - ) - return x_e[1:] - - -_cdr.add((Op,), cdr_Op) - - -def car_Type(x): - return type(x) - - -_car.add((Type,), car_Type) - - -def cdr_Type(x): - x_e = etuple( - _car(x), *[getattr(x, p) for p in getattr(x, "__props__", ())], evaled_obj=x - ) - return x_e[1:] - - -_cdr.add((Type,), cdr_Type) - - -def apply_Op_ExpressionTuple(op, etuple_arg): - res = op.make_node(*etuple_arg) - - try: - return res.default_output() - except ValueError: - return res.outputs - - -apply.add((Op, ExpressionTuple), apply_Op_ExpressionTuple) - - -def _unify_etuplize_first_arg(u, v, s): - try: - u_et = etuplize(u, shallow=True) - yield _unify(u_et, v, s) - except TypeError: - yield False - return - - -_unify.add((Op, ExpressionTuple, Mapping), _unify_etuplize_first_arg) -_unify.add( - (ExpressionTuple, Op, Mapping), lambda u, v, s: _unify_etuplize_first_arg(v, u, s) +warnings.warn( + "The module `aesara.graph.unify` is deprecated; use `aesara.graph.rewriting.unify` instead.", + DeprecationWarning, + stacklevel=2, ) -_unify.add((Type, ExpressionTuple, Mapping), _unify_etuplize_first_arg) -_unify.add( - (ExpressionTuple, Type, Mapping), lambda u, v, s: _unify_etuplize_first_arg(v, u, s) -) - - -def _unify_Variable_Variable(u, v, s): - # Avoid converting to `etuple`s, when possible - if u == v: - yield s - return - - if not u.owner and not v.owner: - yield False - return - - yield _unify( - etuplize(u, shallow=True) if u.owner else u, - etuplize(v, shallow=True) if v.owner else v, - s, - ) - - -_unify.add((Variable, Variable, Mapping), _unify_Variable_Variable) - - -def _unify_Constant_Constant(u, v, s): - # XXX: This ignores shape and type differences. It's only implemented this - # way for backward compatibility - if np.array_equiv(u.data, v.data): - yield s - else: - yield False - - -_unify.add((Constant, Constant, Mapping), _unify_Constant_Constant) - - -def _unify_Variable_ExpressionTuple(u, v, s): - # `Constant`s are "atomic" - if not u.owner: - yield False - return - - yield _unify(etuplize(u, shallow=True), v, s) - - -_unify.add( - (Variable, ExpressionTuple, Mapping), - _unify_Variable_ExpressionTuple, -) -_unify.add( - (ExpressionTuple, Variable, Mapping), - lambda u, v, s: _unify_Variable_ExpressionTuple(v, u, s), -) - - -@_unify.register(ConstrainedVar, (ConstrainedVar, Var, object), Mapping) -def _unify_ConstrainedVar_object(u, v, s): - u_w = walk(u, s) - - if isvar(v): - v_w = walk(v, s) - else: - v_w = v - - if u_w == v_w: - yield s - elif isvar(u_w): - if ( - not isvar(v_w) - and isinstance(u_w, ConstrainedVar) - and not u_w.constraint(eval_if_etuple(v_w)) - ): - yield False - return - yield assoc(s, u_w, v_w) - elif isvar(v_w): - if ( - not isvar(u_w) - and isinstance(v_w, ConstrainedVar) - and not v_w.constraint(eval_if_etuple(u_w)) - ): - yield False - return - yield assoc(s, v_w, u_w) - else: - yield _unify(u_w, v_w, s) - - -_unify.add((object, ConstrainedVar, Mapping), _unify_ConstrainedVar_object) - - -def convert_strs_to_vars( - x: Union[Tuple, str, Dict], var_map: Optional[Dict[str, Var]] = None -) -> Union[ExpressionTuple, Var]: - r"""Convert tuples and strings to `etuple`\s and logic variables, respectively. - - Constrained logic variables are specified via `dict`s with the keys - `"pattern"`, which specifies the logic variable as a string, and - `"constraint"`, which provides the `Callable` constraint. - """ - if var_map is None: - var_map = {} - - def _convert(y): - if isinstance(y, str): - v = var_map.get(y, var(y)) - var_map[y] = v - return v - elif isinstance(y, dict): - pattern = y["pattern"] - if not isinstance(pattern, str): - raise TypeError( - "Constraints can only be assigned to logic variables (i.e. strings)" - ) - constraint = y["constraint"] - v = var_map.get(pattern, ConstrainedVar(constraint, pattern)) - var_map[pattern] = v - return v - elif isinstance(y, tuple): - return etuple(*tuple(_convert(e) for e in y)) - elif isinstance(y, (Number, np.ndarray)): - from aesara.tensor import as_tensor_variable - - return as_tensor_variable(y) - return y - - return _convert(x) +from aesara.graph.rewriting.unify import * # noqa: F401 E402 F403 diff --git a/aesara/graph/utils.py b/aesara/graph/utils.py index 2cdc5a162e..85f528a2fa 100644 --- a/aesara/graph/utils.py +++ b/aesara/graph/utils.py @@ -58,10 +58,9 @@ def simple_extract_stack( if len(trace) == 0: rm = False for p in skips: - # Julian: I added the 'tests' exception together with - # Arnaud. Otherwise, we'd lose the stack trace during - # in our test cases (e.g. in test_opt.py). We're not - # sure this is the right way to do it though. + # The 'tests' exception was added; otherwise, we'd lose the + # stack trace during in our test cases. We're not sure this is + # the right way to do it, though. if p in filename and "tests" not in filename: rm = True break diff --git a/aesara/ifelse.py b/aesara/ifelse.py index cc6f01b14b..29e240674d 100644 --- a/aesara/ifelse.py +++ b/aesara/ifelse.py @@ -11,61 +11,53 @@ is a global operation with a scalar condition. """ -import logging from copy import deepcopy -from typing import List, Sequence, Union +from typing import TYPE_CHECKING, Any, Optional, Sequence, Union import numpy as np import aesara.tensor as at +from aesara import as_symbolic from aesara.compile import optdb from aesara.configdefaults import config from aesara.graph.basic import Apply, Variable, clone_replace, is_in_ancestors from aesara.graph.op import _NoPythonOp -from aesara.graph.opt import GlobalOptimizer, in2out, local_optimizer +from aesara.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter +from aesara.graph.type import HasDataType, HasShape from aesara.tensor.shape import Reshape, Shape, SpecifyShape, Unbroadcast -__docformat__ = "restructedtext en" -__authors__ = ( - "Razvan Pascanu " - "James Bergstra " - "Dumitru Erhan " - "David Warde-Farley" - "PyMC Developers" - "Aesara Developers" -) -__copyright__ = "(c) 2010, Universite de Montreal" - -_logger = logging.getLogger("aesara.ifelse") +if TYPE_CHECKING: + from aesara.tensor import TensorLike class IfElse(_NoPythonOp): - """ - Op that provides conditional graph evaluation if used with the CVM/VM - linkers. Note that there exist a helpful function `ifelse` that should - be used to instantiate the op! + r"""An `Op` that provides conditional graph evaluation. - According to a scalar condition `condition` the op evaluates and then - returns all the tensors provided on the `then` branch, otherwise it - evaluates and returns the tensors provided on the `else` branch. The op + According to a scalar condition, this `Op` evaluates and then + returns all the tensors provided on the "then"-branch, otherwise it + evaluates and returns the tensors provided on the "else"-branch. The `Op` supports multiple tensors on each branch, with the condition that the same - number of tensors are on the `then` as on the `else` and there is a one - to one correspondence between them (shape and dtype wise). + number of tensors are on the "then"-branch as on the "else"-branch and + there is a one to one correspondence between their dtypes and numbers of + dimensions. - The `then` branch is defined as the first N tensors (after the - condition), while the `else` branch is defined as the last N tensors. + The "then"-branch is defined as the first ``N`` tensors (after the + condition), while the "else"-branch is defined as the last ``N`` tensors. Example usage: - ``rval = ifelse(condition, rval_if_true1, .., rval_if_trueN, - rval_if_false1, rval_if_false2, .., rval_if_falseN)`` + .. code-block:: + + rval = ifelse(condition, + rval_if_true_1, ..., rval_if_true_N, + rval_if_false_1, ..., rval_if_false_N) .. note: - Other Linkers then CVM and VM are INCOMPATIBLE with this Op, and - will ignore its lazy characteristic, computing both the True and - False branch before picking one. + `Linker`\s other than `CVM`, and some other `VM` subclasses, are + incompatible with this `Op`, and will ignore its lazy characteristic, + computing both the true and false branches before returning one. """ @@ -158,86 +150,142 @@ def infer_shape(self, fgraph, node, inputs_shapes): return out_shapes - def make_node(self, c, *args): - if len(args) != 2 * self.n_outs: + def make_node(self, condition: "TensorLike", *true_false_branches: Any): + if len(true_false_branches) != 2 * self.n_outs: raise ValueError( - f"Wrong number of arguments to make_node: expected " - f"{int(2 * self.n_outs)}, got {len(args)}" + f"Wrong number of arguments: expected " + f"{int(2 * self.n_outs)}, got {len(true_false_branches)}" ) - c = at.basic.as_tensor_variable(c) - nw_args = [] - for x in args: - if isinstance(x, Variable): - nw_args.append(x) - else: - nw_args.append(at.as_tensor_variable(x)) - args = nw_args - aes = args[: self.n_outs] - fs = args[self.n_outs :] - for t, f in zip(aes, fs): - # TODO: Attempt to convert types so that they match? - # new_f = t.type.filter_variable(f) + condition = at.basic.as_tensor_variable(condition) + + if condition.type.ndim > 0: + raise TypeError("The condition argument must be a truthy scalar value") + + inputs_true_branch = true_false_branches[: self.n_outs] + inputs_false_branch = true_false_branches[self.n_outs :] + + output_vars = [] + new_inputs_true_branch = [] + new_inputs_false_branch = [] + for input_t, input_f in zip(inputs_true_branch, inputs_false_branch): + + if not isinstance(input_t, Variable): + input_t = as_symbolic(input_t) + if not isinstance(input_f, Variable): + input_f = as_symbolic(input_f) - if not t.type.is_super(f.type): + if type(input_f.type) != type(input_t.type): # noqa: E721 raise TypeError( - "IfElse requires compatible types for true and false return values: " - f"true_branch={t.type}, false_branch={f.type}" + f"Input types {type(input_t.type)} and {type(input_f.type)} do not match." ) - if c.ndim > 0: - raise TypeError( - "Condition given to the op has to be a scalar " - "with 0 standing for False, anything else " - "for True" - ) - return Apply(self, [c] + list(args), [t.type() for t in aes]) + + if isinstance(input_t.type, HasDataType) and isinstance( + input_f.type, HasDataType + ): + # TODO: Be smarter about dtype casting. + # up_dtype = aes.upcast(input_t.type.dtype, input_f.type.dtype) + + if input_t.type.dtype != input_f.type.dtype: + raise TypeError( + "IfElse requires compatible dtypes for both branches: got " + f"true_branch={input_t.type.dtype}, false_branch={input_f.type.dtype}" + ) + + if isinstance(input_t.type, HasShape) and isinstance( + input_f.type, HasShape + ): + + if input_t.type.ndim != input_f.type.ndim: + raise TypeError( + "IfElse requires compatible ndim values for both branches: got " + f"true_branch={input_t.type.ndim}, false_branch={input_f.type.ndim}" + ) + + # We can only use static shape information that corresponds + # in both branches, because the outputs of this `Op` are + # allowed to have distinct shapes from either branch + new_shape = tuple( + s_t if s_t == s_f else None + for s_t, s_f in zip(input_t.type.shape, input_f.type.shape) + ) + # TODO FIXME: The presence of this keyword is a strong + # assumption. Find something that's guaranteed by the/a + # confirmed interface. + output_var_t = input_t.type.clone(shape=new_shape)() + output_var_f = input_f.type.clone(shape=new_shape)() + else: + output_var_t = input_t.type() + output_var_f = input_f.type() + + input_t_ = output_var_f.type.filter_variable(input_t) + input_f_ = output_var_t.type.filter_variable(input_f) + + new_inputs_true_branch.append(input_t_) + new_inputs_false_branch.append(input_f_) + output_vars.append(output_var_t) + + return Apply( + self, + [condition] + new_inputs_true_branch + new_inputs_false_branch, + output_vars, + ) def R_op(self, inputs, eval_points): return self(inputs[0], *eval_points[1:], return_list=True) def grad(self, ins, grads): - aes = ins[1:][: self.n_outs] - fs = ins[1:][self.n_outs :] + + condition = ins[0] + inputs_true_branch = ins[1:][: self.n_outs] + inputs_false_branch = ins[1:][self.n_outs :] + if self.name is not None: nw_name_t = self.name + "_grad_t" nw_name_f = self.name + "_grad_f" else: nw_name_t = None nw_name_f = None - if_true_op = IfElse(n_outs=self.n_outs, as_view=self.as_view, name=nw_name_t) + if_true_op = IfElse(n_outs=self.n_outs, as_view=self.as_view, name=nw_name_t) if_false_op = IfElse(n_outs=self.n_outs, as_view=self.as_view, name=nw_name_f) - # The grads can have a different dtype then the inputs. - # As inputs true/false pair must have the same dtype, - # we must cast the zeros to the corresponding grad dtype - # and not the input dtype. - if_true = ( - [ins[0]] + # The `grads` can have different dtypes than the `inputs`. + # Since input true/false entries must have the same dtypes, we need to + # cast the zeros to the corresponding `grads` dtypes and not the input + # dtypes. + inputs_true_grad = ( + [condition] + grads - + [at.basic.zeros_like(t, dtype=grads[i].dtype) for i, t in enumerate(aes)] + + [ + at.basic.zeros_like(t, dtype=grads[i].dtype) + for i, t in enumerate(inputs_true_branch) + ] ) - if_false = ( - [ins[0]] - + [at.basic.zeros_like(f, dtype=grads[i].dtype) for i, f in enumerate(fs)] + inputs_false_grad = ( + [condition] + + [ + at.basic.zeros_like(f, dtype=grads[i].dtype) + for i, f in enumerate(inputs_false_branch) + ] + grads ) - condition = ins[0] - # condition does affect the elements of the output so it is connected. + # `condition` does affect the elements of the output so it is connected. # For the sake of making the gradient convenient we assume that # condition + epsilon always triggers the same branch as condition condition_grad = condition.zeros_like().astype(config.floatX) + return ( [condition_grad] - + if_true_op(*if_true, return_list=True) - + if_false_op(*if_false, return_list=True) + + if_true_op(*inputs_true_grad, return_list=True) + + if_false_op(*inputs_false_grad, return_list=True) ) def make_thunk(self, node, storage_map, compute_map, no_recycling, impl=None): cond = node.inputs[0] - aes = node.inputs[1:][: self.n_outs] - fs = node.inputs[1:][self.n_outs :] + input_true_branch = node.inputs[1:][: self.n_outs] + inputs_false_branch = node.inputs[1:][self.n_outs :] outputs = node.outputs def thunk(): @@ -249,12 +297,12 @@ def thunk(): ls = [ idx + 1 for idx in range(self.n_outs) - if not compute_map[aes[idx]][0] + if not compute_map[input_true_branch[idx]][0] ] if len(ls) > 0: return ls else: - for out, t in zip(outputs, aes): + for out, t in zip(outputs, input_true_branch): compute_map[out][0] = 1 val = storage_map[t][0] if self.as_view: @@ -269,12 +317,12 @@ def thunk(): ls = [ 1 + idx + self.n_outs for idx in range(self.n_outs) - if not compute_map[fs[idx]][0] + if not compute_map[inputs_false_branch[idx]][0] ] if len(ls) > 0: return ls else: - for out, f in zip(outputs, fs): + for out, f in zip(outputs, inputs_false_branch): compute_map[out][0] = 1 # can't view both outputs unless destroyhandler # improves @@ -293,46 +341,42 @@ def thunk(): def ifelse( - condition: Variable, - then_branch: Union[Variable, List[Variable]], - else_branch: Union[Variable, List[Variable]], - name: str = None, + condition: "TensorLike", + then_branch: Union[Any, Sequence[Any]], + else_branch: Union[Any, Sequence[Any]], + name: Optional[str] = None, ) -> Union[Variable, Sequence[Variable]]: - """ - This function corresponds to an if statement, returning (and evaluating) - inputs in the ``then_branch`` if ``condition`` evaluates to True or - inputs in the ``else_branch`` if ``condition`` evaluates to False. + """Construct a graph for an ``if`` statement. Parameters - ========== + ---------- condition - ``condition`` should be a tensor scalar representing the condition. - If it evaluates to 0 it corresponds to False, anything else stands - for True. + `condition` should be a tensor scalar representing the condition. + If it evaluates to ``0`` it corresponds to ``False``, anything else + stands for ``True``. then_branch - A single aesara variable or a list of aesara variables that the - function should return as the output if ``condition`` evaluates to + A single variable or a list of variables that the + function should return as the output if `condition` evaluates to true. The number of variables should match those in the - ``else_branch``, and there should be a one to one correspondence - (type wise) with the tensors provided in the else branch + `else_branch`, as well as the dtypes and numbers of dimensions of each + tensor. else_branch - A single aesara variable or a list of aesara variables that the - function should return as the output if ``condition`` evaluates to - false. The number of variables should match those in the then branch, - and there should be a one to one correspondence (type wise) with the - tensors provided in the then branch. + A single variable or a list of variables that the function should + return as the output if `condition` evaluates to false. The number of + variables should match those in `then_branch`, as well as the dtypes + and numbers of dimensions of each tensor. Returns - ======= - A sequence of aesara variables or a single variable (depending on the - nature of the ``then_branch`` and ``else_branch``). More exactly if - ``then_branch`` and ``else_branch`` is a tensor, then - the return variable will be just a single variable, otherwise a - sequence. The value returns correspond either to the values in the - ``then_branch`` or in the ``else_branch`` depending on the value of - ``condition``. + ------- + A sequence of variables or a single variable, depending on the + nature of `then_branch` and `else_branch`. More exactly, if + `then_branch` and `else_branch` is are single variables, then + the return variable will also be a single variable; otherwise, it will + be a sequence. The value returned correspond to either the values in + the `then_branch` or in the `else_branch` depending on the value of + `condition`. """ rval_type = None @@ -344,35 +388,17 @@ def ifelse( if not isinstance(else_branch, (list, tuple)): else_branch = [else_branch] - # Some of the elements might be converted into another type, - # we will store them in these new_... lists. - new_then_branch = [] - new_else_branch = [] - for then_branch_elem, else_branch_elem in zip(then_branch, else_branch): - if not isinstance(then_branch_elem, Variable): - then_branch_elem = at.basic.as_tensor_variable(then_branch_elem) - if not isinstance(else_branch_elem, Variable): - else_branch_elem = at.basic.as_tensor_variable(else_branch_elem) - - # Make sure the types are compatible - else_branch_elem = then_branch_elem.type.filter_variable(else_branch_elem) - then_branch_elem = else_branch_elem.type.filter_variable(then_branch_elem) - - new_then_branch.append(then_branch_elem) - new_else_branch.append(else_branch_elem) - if len(then_branch) != len(else_branch): raise ValueError( - "The number of values on the `then` branch" - " should have the same number of variables as " - "the `else` branch : (variables on `then` " - f"{len(then_branch)}, variables on `else` " - f"{len(else_branch)})" + "The number of values on the `then` branch " + "must match the `else` branch: got " + f"{len(then_branch)} for `then`, and " + f"{len(else_branch)} for `else`." ) new_ifelse = IfElse(n_outs=len(then_branch), as_view=False, name=name) - ins = [condition] + list(new_then_branch) + list(new_else_branch) + ins = [condition] + list(then_branch) + list(else_branch) rval = new_ifelse(*ins, return_list=True) if rval_type is None: @@ -383,7 +409,7 @@ def ifelse( return tuple(rval) -@local_optimizer([IfElse]) +@node_rewriter([IfElse]) def cond_make_inplace(fgraph, node): op = node.op if ( @@ -414,8 +440,8 @@ def cond_make_inplace(fgraph, node): # XXX: Optimizations commented pending further debugging (certain optimizations # make computation less lazy than it should be currently). # -# ifelse_equilibrium = graph.optdb.EquilibriumDB() -# ifelse_seqopt = graph.optdb.SequenceDB() +# ifelse_equilibrium = graph.rewriting.db.EquilibriumDB() +# ifelse_seqopt = graph.rewriting.db.SequenceDB() # ifelse_equilibrium.register('seq_ifelse', ifelse_seqopt, 'fast_run', # 'ifelse') """ Comments: @@ -461,7 +487,7 @@ def cond_make_inplace(fgraph, node): ) -@local_optimizer(acceptable_ops) +@node_rewriter(acceptable_ops) def ifelse_lift_single_if_through_acceptable_ops(fgraph, main_node): """This optimization lifts up certain ifelse instances. @@ -508,7 +534,7 @@ def ifelse_lift_single_if_through_acceptable_ops(fgraph, main_node): return nw_outs -@local_optimizer([IfElse]) +@node_rewriter([IfElse]) def cond_merge_ifs_true(fgraph, node): op = node.op if not isinstance(op, IfElse): @@ -535,7 +561,7 @@ def cond_merge_ifs_true(fgraph, node): return op(*old_ins, return_list=True) -@local_optimizer([IfElse]) +@node_rewriter([IfElse]) def cond_merge_ifs_false(fgraph, node): op = node.op if not isinstance(op, IfElse): @@ -562,7 +588,7 @@ def cond_merge_ifs_false(fgraph, node): return op(*old_ins, return_list=True) -class CondMerge(GlobalOptimizer): +class CondMerge(GraphRewriter): """Graph Optimizer that merges different cond ops""" def add_requirements(self, fgraph): @@ -614,7 +640,7 @@ def apply(self, fgraph): fgraph.replace_all_validate(pairs, reason="cond_merge") -@local_optimizer([IfElse]) +@node_rewriter([IfElse]) def cond_remove_identical(fgraph, node): op = node.op @@ -660,7 +686,7 @@ def cond_remove_identical(fgraph, node): return rval -@local_optimizer([IfElse]) +@node_rewriter([IfElse]) def cond_merge_random_op(fgraph, main_node): if isinstance(main_node.op, IfElse): return False @@ -717,7 +743,7 @@ def cond_merge_random_op(fgraph, main_node): # XXX: Optimizations commented pending further debugging (certain optimizations # make computation less lazy than it should be currently). # -# pushout_equilibrium = graph.optdb.EquilibriumDB() +# pushout_equilibrium = graph.rewriting.db.EquilibriumDB() # # XXX: This optimization doesn't seem to exist anymore? # pushout_equilibrium.register("cond_lift_single_if", diff --git a/aesara/link/basic.py b/aesara/link/basic.py index c89c8479d7..48ecb3e69c 100644 --- a/aesara/link/basic.py +++ b/aesara/link/basic.py @@ -641,7 +641,8 @@ def create_jitable_thunk( The JITed function that performs the computations. """ - output_nodes = [o.owner for o in self.fgraph.outputs] + # This is a bit hackish, but we only return one of the output nodes + output_nodes = [o.owner for o in self.fgraph.outputs if o.owner is not None][:1] converted_fgraph = self.fgraph_convert( self.fgraph, @@ -678,8 +679,7 @@ def thunk( thunks.append(thunk) - # This is a bit hackish, but we only return one of the output nodes - return thunks, output_nodes[:1], fgraph_jit + return thunks, output_nodes, fgraph_jit def make_all(self, input_storage=None, output_storage=None, storage_map=None): fgraph = self.fgraph diff --git a/aesara/link/c/basic.py b/aesara/link/c/basic.py index 411434a47a..8aed25cd13 100644 --- a/aesara/link/c/basic.py +++ b/aesara/link/c/basic.py @@ -3,12 +3,11 @@ """ import logging -import os import sys from collections import defaultdict from copy import copy from io import StringIO -from typing import Dict +from typing import TYPE_CHECKING, Any, Dict, Optional import numpy as np @@ -21,7 +20,6 @@ io_toposort, vars_between, ) -from aesara.graph.callcache import CallCache from aesara.link.basic import Container, Linker, LocalLinker, PerformLinker from aesara.link.c.cmodule import ( METH_VARARGS, @@ -36,13 +34,14 @@ from aesara.utils import difference, uniq -_logger = logging.getLogger("aesara.link.c.basic") - +if TYPE_CHECKING: + from aesara.graph.fg import FunctionGraph + from aesara.link.c.cmodule import ModuleCache -run_cthunk = None # Will be imported only when needed. +_logger = logging.getLogger("aesara.link.c.basic") -def get_module_cache(init_args=None): +def get_module_cache(init_args: Optional[Dict[str, Any]] = None) -> "ModuleCache": """ Parameters @@ -55,18 +54,6 @@ def get_module_cache(init_args=None): return _get_module_cache(config.compiledir, init_args=init_args) -_persistent_module_cache = None - - -def get_persistent_module_cache(): - global _persistent_module_cache - if _persistent_module_cache is None: - _persistent_module_cache = CallCache( - os.path.join(config.compiledir, "persistent_cache") - ) - return _persistent_module_cache - - class CodeBlock: """ Represents a computation unit composed of declare, behavior, and cleanup. @@ -594,7 +581,9 @@ def __init__(self, schedule=None): self.fgraph = None super().__init__(scheduler=schedule) - def accept(self, fgraph, no_recycling=None, profile=None): + def accept( + self, fgraph: "FunctionGraph", no_recycling=None, profile=None + ) -> "CLinker": r"""Associate this `Linker` with `fgraph`. The `no_recycling` argument can contain a list of `Variable`\s that @@ -1094,18 +1083,18 @@ def __compile__( input_storage=None, output_storage=None, storage_map=None, + cache: Optional["ModuleCache"] = None, ): - """ - Compiles this linker's fgraph. + """Compile `self.fgraph`. Parameters ---------- input_storage: list or None List of lists of length 1. In order to use the thunk returned - by __compile__, the inputs must be put in that storage. + by this method, the inputs must be put in that storage. If None, storage will be allocated. output_storage: list of lists of length 1 - The thunk returned by __compile__ will put the variables of the + The thunk returned by this method will put the variables of the computation in these lists. If None, storage will be allocated. Returns @@ -1135,6 +1124,7 @@ def __compile__( input_storage, output_storage, storage_map, + cache, ) return ( thunk, @@ -1166,37 +1156,51 @@ def get_init_tasks(self): id += 2 return init_tasks, tasks - def make_thunk(self, input_storage=None, output_storage=None, storage_map=None): - """ - Compiles this linker's fgraph and returns a function to perform the - computations, as well as lists of storage cells for both the inputs - and outputs. + def make_thunk( + self, + input_storage=None, + output_storage=None, + storage_map=None, + cache: Optional["ModuleCache"] = None, + **kwargs, + ): + """Compile this linker's `self.fgraph` and return a function that performs the computations. + + The return values can be used as follows: + + .. code-block:: + + f, istor, ostor = clinker.make_thunk() + istor[0].data = first_input + istor[1].data = second_input + f() + first_output = ostor[0].data + Parameters ---------- input_storage: list or None List of lists of length 1. In order to use - the thunk returned by __compile__, the inputs must be put in + the thunk returned by `CLinker.__compile__`, the inputs must be put in that storage. If None, storage will be allocated. output_storage: list of lists of length 1. - The thunk returned by __compile__ will put the variables + The thunk returned by `CLinker.__compile__` will put the variables of the computation in these lists. If None, storage will be allocated. storage_map: dict that map variables to storages. This is used when you need to customize the storage of this thunk - Returns: thunk, input_storage, output_storage + cache + A cache in which to store the compilation results. + + Returns + ------- + thunk, input_storage, output_storage - The return values can be used as follows: - f, istor, ostor = clinker.make_thunk() - istor[0].data = first_input - istor[1].data = second_input - f() - first_output = ostor[0].data """ init_tasks, tasks = self.get_init_tasks() cthunk, module, in_storage, out_storage, error_storage = self.__compile__( - input_storage, output_storage, storage_map + input_storage, output_storage, storage_map, cache ) res = _CThunk(cthunk, init_tasks, tasks, error_storage, module) @@ -1456,8 +1460,12 @@ def in_sig(i, topological_pos, i_idx): for node_pos, node in enumerate(order): if hasattr(node.op, "c_code_cache_version_apply"): version.append(node.op.c_code_cache_version_apply(node)) - if hasattr(node.op, "__props__"): - version.append(node.op.__props__) + + props = getattr(node.op, "__props__", None) + + if props: + version.append(props) + for i in node.inputs: if isinstance(i.type, CLinkerObject): version.append(i.type.c_code_cache_version()) @@ -1598,7 +1606,14 @@ def get_dynamic_module(self): self._mod = mod return self._mod - def cthunk_factory(self, error_storage, in_storage, out_storage, storage_map=None): + def cthunk_factory( + self, + error_storage, + in_storage, + out_storage, + storage_map=None, + cache: Optional["ModuleCache"] = None, + ): """ Returns a thunk that points to an instance of a C struct that can carry on the computation of this linker's fgraph @@ -1619,6 +1634,7 @@ def cthunk_factory(self, error_storage, in_storage, out_storage, storage_map=Non key = self.cmodule_key() except KeyError: key = None + if key is None: # If we can't get a key, then forget the cache mechanism. module = self.compile_cmodule() @@ -1626,7 +1642,9 @@ def cthunk_factory(self, error_storage, in_storage, out_storage, storage_map=Non # Set compute_map as None as clinker do not support lazy evaluation for node in self.node_order: node.op.prepare_node(node, storage_map, None, "c") - module = get_module_cache().module_from_key(key=key, lnk=self) + if cache is None: + cache = get_module_cache() + module = cache.module_from_key(key=key, lnk=self) vars = self.inputs + self.outputs + self.orphans # List of indices that should be ignored when passing the arguments @@ -1703,7 +1721,7 @@ class _CThunk: Parameters ---------- cthunk - The CObject pointer used by run_cthunk. + A CObject pointer that is used to run the thunk. init_tasks WRITEME tasks @@ -1717,15 +1735,16 @@ class _CThunk: """ def __init__(self, cthunk, init_tasks, tasks, error_storage, module): - global run_cthunk - if run_cthunk is None: - # Lazy import to avoid compilation when importing aesara. - from aesara.link.c.cutils import run_cthunk # noqa + # Lazy import to avoid compilation when importing aesara. + from aesara.link.c.cutils import run_cthunk # noqa + + self.run_cthunk = run_cthunk self.cthunk = cthunk self.init_tasks = init_tasks self.tasks = tasks self.error_storage = error_storage self.module = module + self.nodes = None def find_task(self, failure_code): """ @@ -1741,7 +1760,7 @@ def find_task(self, failure_code): return self.tasks[failure_code - n] def __call__(self): - failure = run_cthunk(self.cthunk) + failure = self.run_cthunk(self.cthunk) if failure: task, taskname, id = self.find_task(failure) try: diff --git a/aesara/link/c/cmodule.py b/aesara/link/c/cmodule.py index fad7bf9dff..58102b7303 100644 --- a/aesara/link/c/cmodule.py +++ b/aesara/link/c/cmodule.py @@ -19,7 +19,7 @@ import time import warnings from io import BytesIO, StringIO -from typing import Callable, Dict, List, Optional, Set, Tuple, cast +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set, Tuple, cast import numpy as np from setuptools._distutils.sysconfig import ( @@ -46,6 +46,10 @@ ) +if TYPE_CHECKING: + from aesara.link.c.basic import CLinker + + class StdLibDirsAndLibsType(Protocol): data: Optional[Tuple[List[str], ...]] __call__: Callable[[], Optional[Tuple[List[str], ...]]] @@ -555,7 +559,7 @@ def save_pkl(self): with open(self.key_pkl, "wb") as f: pickle.dump(self, f, protocol=pickle.HIGHEST_PROTOCOL) except pickle.PicklingError: - _logger.warning(f"Cache leak due to unpickle-able key data {self.keys}") + warnings.warn(f"Cache leak due to unpickle-able key data: {self.keys}") os.remove(self.key_pkl) raise @@ -616,38 +620,38 @@ class ModuleCache: can import. The cache contains one directory for each module, containing: - - the dynamic library file itself (.so/.pyd), - - an empty __init__.py file, so Python can import it, - - a file containing the source code for the module (mod.cpp/mod.cu), - - a key.pkl file, containing a KeyData object with all the keys + - the dynamic library file itself (e.g. ``.so/.pyd``), + - an empty ``__init__.py`` file, so Python can import it, + - a file containing the source code for the module (e.g. ``mod.cpp/mod.cu``), + - a ``key.pkl`` file, containing a KeyData object with all the keys associated with that module, - - possibly a delete.me file, meaning this directory has been marked + - possibly a ``delete.me`` file, meaning this directory has been marked for deletion. - Keys should be tuples of length 2: (version, rest). The - ``rest`` can be anything hashable and picklable, that uniquely + Keys should be tuples of length two: ``(version, rest)``. The + rest can be anything hashable and picklable, that uniquely identifies the computation in the module. The key is returned by - ``CLinker.cmodule_key_``. + `CLinker.cmodule_key_`. - The ``version`` should be a hierarchy of tuples of integers. - If the ``version`` is either 0 or (), then the key is unversioned, and its - corresponding module will be marked for deletion in an atexit() handler. - If the ``version`` is neither 0 nor (), then the module will be kept in the - cache between processes. + The ``version`` value should be a hierarchy of tuples of integers. If the + ``version`` value is either ``0`` or ``()``, then the key is unversioned, + and its corresponding module will be marked for deletion in an `atexit` + handler. If the ``version`` value is neither ``0`` nor ``()``, then the + module will be kept in the cache between processes. An unversioned module is not always deleted by the process that creates it. Deleting such modules may not work on NFS filesystems because the tmpdir in which the library resides is in use until the end of the process' lifetime. In this case, unversioned modules - are left in their tmpdirs without corresponding .pkl files. These - modules and their directories are erased by subsequent processes' - refresh() functions. + are left in their temporary directories without corresponding ``.pkl`` + files. These modules and their directories are erased by subsequent + processes' `ModuleCache.refresh`functions. Two different keys are mapped to the same module when all conditions below are met: - They have the same version. - They share the same compilation options in their ``rest`` part (see - ``CLinker.cmodule_key_`` for how this part is built). + `CLinker.cmodule_key_` for how this part is built). - They share the same C code. These three elements uniquely identify a module, and are summarized in a single "module hash". @@ -655,12 +659,12 @@ class ModuleCache: Parameters ---------- check_for_broken_eq - A bad __eq__ implementation can break this cache mechanism. + A bad `object.__eq__` implementation can break this cache mechanism. This option turns on a not-too-expensive sanity check every time a new key is added to the cache. do_refresh : bool - If True, then the ``refresh`` method will be called + If ``True``, then the `ModuleCache.refresh` method will be called in the constructor. """ @@ -677,7 +681,7 @@ class ModuleCache: """ entry_from_key: Dict = {} """ - Maps keys to the filename of a .so/.pyd. + Maps keys to the filename of a ``.so/.pyd``. """ similar_keys: Dict = {} @@ -687,18 +691,18 @@ class ModuleCache: """ module_hash_to_key_data: Dict = {} """ - Maps a module hash to its corresponding KeyData object. + Maps a module hash to its corresponding `KeyData` object. """ stats: List = [] """ A list with counters for the number of hits, loads, compiles issued by - module_from_key(). + `ModuleCache.module_from_key`. """ loaded_key_pkl: Set = set() """ - Set of all key.pkl files that have been loaded. + Set of all ``key.pkl`` files that have been loaded. """ @@ -824,8 +828,8 @@ def rmtree_empty(*args, **kwargs): # os. So it is normal that this happens from time # to time. _logger.warning( - "ModuleCache.refresh() Found key " - f"without dll in cache, deleting it. {key_pkl}", + "`ModuleCache.refresh` Found a key " + f"without a cached shared library ({key_pkl}); deleting it." ) rmtree( root, @@ -852,7 +856,7 @@ def unpickle_failure(): rmtree( root, ignore_nocleanup=True, - msg="broken cache directory [EOF]", + msg="Broken cache directory [EOF]", level=logging.WARNING, ) continue @@ -886,9 +890,9 @@ def unpickle_failure(): root, ignore_nocleanup=True, msg=( - "invalid cache entry format -- this " - "should not happen unless your cache " - "was really old" + "Invalid cache entry format. This " + "should not happen unless the cache " + "is outdated." ), level=logging.WARN, ) @@ -921,14 +925,14 @@ def unpickle_failure(): # TODO: check if this can happen at all to_del = [key for key in key_data.keys if not key[0]] if to_del: - _logger.warning( - "ModuleCache.refresh() Found unversioned " - f"key in cache, removing it. {key_pkl}", + warnings.warn( + "`ModuleCache.refresh` found an unversioned " + f"key in the cache ({key_pkl}); removing it." ) # Since the version is in the module hash, all # keys should be unversioned. if len(to_del) != len(key_data.keys): - _logger.warning( + warnings.warn( "Found a mix of unversioned and " "versioned keys for the same " f"module {key_pkl}", @@ -986,12 +990,11 @@ def unpickle_failure(): else: dir1 = os.path.dirname(self.entry_from_key[key]) dir2 = os.path.dirname(entry) - _logger.warning( - "The same cache key is associated to " + warnings.warn( + "The same cache key is associated with " f"different modules ({dir1} and {dir2}). This " - "is not supposed to happen! You may " - "need to manually delete your cache " - "directory to fix this.", + "is not supposed to happen. The cache directory " + "may need to be manually delete in order to fix this." ) # Clean up the name space to prevent bug. if key_data.keys: @@ -1025,10 +1028,10 @@ def unpickle_failure(): # considered a failure of the OTHER process, that deleted # it. if entry in self.module_from_name: - _logger.warning( + warnings.warn( "A module that was loaded by this " "ModuleCache can no longer be read from file " - f"{entry}... this could lead to problems.", + f"{entry}. This could lead to problems.", ) del self.module_from_name[entry] @@ -1045,7 +1048,7 @@ def unpickle_failure(): # Under /tmp, file are removed periodically by the # os. So it is normal that this happen from time to # time. - _logger.warning( + warnings.warn( f"Removing key file {pkl_file_to_remove} because the " "corresponding module is gone from the " "file system." @@ -1158,16 +1161,14 @@ def _add_to_cache(self, module, key, module_hash): elif config.cmodule__warn_no_version: key_flat = flatten(key) ops = [k for k in key_flat if isinstance(k, Op)] - _logger.warning( - "not all the" - " following op(s) implement" - " c_code_cache_version(). This makes them" - " recompiled for each process." + str(ops) + warnings.warn( + f"The following `Op`(s) do not implement `COp.c_code_cache_version`: {ops}. " + "They will be recompiled across processes/Python sessions" ) self._update_mappings(key, key_data, module.__file__, not key_broken) return key_data - def module_from_key(self, key, lnk=None): + def module_from_key(self, key, lnk: "CLinker"): """ Return a module from the cache, compiling it if necessary. @@ -1322,7 +1323,7 @@ def check_key(self, key, key_pkl): age_thresh_del = config.cmodule__age_thresh_use + 60 * 60 * 24 * 7 age_thresh_del_unversioned = 60 * 60 * 24 * 7 # 7 days """ - The default age threshold for `clear_old` (in seconds). + The default age threshold for `ModuleCache.clear_old` (in seconds). """ @@ -1350,7 +1351,7 @@ def clear_old(self, age_thresh_del=None, delete_if_problem=False): # contain all modules older than age_thresh_del. if age_thresh_del < self.age_thresh_use: if age_thresh_del > 0: - _logger.warning( + _logger.info( "Clearing modules that were not deemed " f"too old to use: age_thresh_del={age_thresh_del}, " f"self.age_thresh_use={self.age_thresh_use}" @@ -1434,14 +1435,14 @@ def clear_base_files(self): shutil.rmtree(to_delete) _logger.debug(f"Deleted: {to_delete}") except Exception: - _logger.warning(f"Could not delete {to_delete}") + warnings.warn(f"Could not delete {to_delete}") continue to_rename = os.path.join(self.dirname, base_dir) if os.path.isdir(to_rename): try: shutil.move(to_rename, to_delete) except Exception: - _logger.warning(f"Could not move {to_rename} to {to_delete}") + warnings.warn(f"Could not move {to_rename} to {to_delete}") def clear_unversioned(self, min_age=None): """Delete unversioned dynamic modules. @@ -1607,23 +1608,23 @@ def _rmtree( with open(os.path.join(parent, "delete.me"), "w"): pass except Exception as ee: - _logger.warning( + warnings.warn( f"Failed to remove or mark cache directory {parent} for removal {ee}" ) -_module_cache = None +_module_cache: Optional[ModuleCache] = None -def get_module_cache(dirname, init_args=None): - """ - Create a new module_cache with the (k, v) pairs in this dictionary +def get_module_cache(dirname: str, init_args=None) -> ModuleCache: + """Create a new module_cache. Parameters ---------- + dirname + The name of the directory used by the cache. init_args - If not None, the (k, v) pairs in this dictionary will be forwarded to - the ModuleCache constructor as keyword arguments. + Keyword arguments passed to the `ModuleCache` constructor. """ global _module_cache @@ -1633,14 +1634,14 @@ def get_module_cache(dirname, init_args=None): _module_cache = ModuleCache(dirname, **init_args) atexit.register(_module_cache._on_atexit) elif init_args: - _logger.warning( + warnings.warn( "Ignoring init arguments for module cache because it " "was created prior to this call" ) if _module_cache.dirname != dirname: - _logger.warning( + warnings.warn( "Returning module cache instance with different " - f"dirname ({_module_cache.dirname}) than you requested ({dirname})" + f"dirname ({_module_cache.dirname}) than requested ({dirname})" ) return _module_cache @@ -2073,13 +2074,11 @@ def compile_args(march_flags=True): and "clang-omp++" not in config.cxx and "icpc" not in config.cxx ): - _logger.warning( - "Your Aesara flag `cxx` seems not to be" - " the g++ compiler. So we disable the compiler optimization" - " specific to g++ that tell to compile for a specific CPU." - " At worst, this could cause slow down.\n" - " You can add those parameters to the compiler yourself" - " via the Aesara flag `gcc__cxxflags`." + warnings.warn( + "`aesara.config.cxx` is not an identifiable `g++` compiler. " + "Aesara will disable compiler optimizations specific to `g++`. " + "At worst, this could cause slow downs.\n" + "Those parameters can be added manually via the `cxxflags` setting." ) detect_march = False @@ -2142,7 +2141,7 @@ def get_lines(cmd, parse=True): ) else: reported_lines = native_lines - _logger.warning( + warnings.warn( "Aesara was not able to find the" " g++ parameters that tune the compilation to your " " specific CPU. This can slow down the execution of Aesara" @@ -2154,15 +2153,17 @@ def get_lines(cmd, parse=True): default_lines = get_lines(f"{config.cxx} -E -v -") _logger.info(f"g++ default lines: {default_lines}") if len(default_lines) < 1: - _logger.warning( - "Aesara was not able to find the" - " default g++ parameters. This is needed to tune" - " the compilation to your specific" - " CPU. This can slow down the execution of Aesara" - " functions. Please submit the following lines to" - " Aesara's mailing list so that we can fix this" - " problem:\n %s", - get_lines(f"{config.cxx} -E -v -", parse=False), + reported_lines = get_lines(f"{config.cxx} -E -v -", parse=False) + warnings.warn( + ( + "Aesara was not able to find the " + "default g++ parameters. This is needed to tune " + "the compilation to your specific " + "CPU. This can slow down the execution of Aesara " + "functions. Please submit the following lines to " + "Aesara's mailing list so that we can fix this " + f"problem:\n {reported_lines}" + ) ) else: # Some options are actually given as "-option value", diff --git a/aesara/link/c/interface.py b/aesara/link/c/interface.py index d230451cb5..84570c0996 100644 --- a/aesara/link/c/interface.py +++ b/aesara/link/c/interface.py @@ -1,3 +1,4 @@ +import warnings from abc import abstractmethod from typing import Callable, Dict, List, Text, Tuple, Union @@ -162,6 +163,9 @@ def c_code_cache_version(self) -> Union[Tuple[int, ...], Tuple]: c_code_cache_version_apply() """ + warnings.warn( + f"{type(self)} specifies no C code cache version and will not be cached." + ) return () diff --git a/aesara/link/c/op.py b/aesara/link/c/op.py index b15679324a..9c35968e4d 100644 --- a/aesara/link/c/op.py +++ b/aesara/link/c/op.py @@ -13,7 +13,6 @@ Optional, Pattern, Set, - Text, Tuple, Union, cast, @@ -28,6 +27,7 @@ from aesara.graph.utils import MethodNotDefined from aesara.link.c.interface import CLinkerOp from aesara.link.c.params_type import ParamsType +from aesara.utils import hash_from_code if TYPE_CHECKING: @@ -230,7 +230,7 @@ def prepare_node(self, node, storage_map, compute_map, impl): self.update_self_openmp() -def lquote_macro(txt: Text) -> Text: +def lquote_macro(txt: str) -> str: """Turn the last line of text into a ``\\``-commented line.""" res = [] spl = txt.split("\n") @@ -240,7 +240,7 @@ def lquote_macro(txt: Text) -> Text: return "\n".join(res) -def get_sub_macros(sub: Dict[Text, Text]) -> Union[Tuple[Text], Tuple[Text, Text]]: +def get_sub_macros(sub: Dict[str, str]) -> Union[Tuple[str], Tuple[str, str]]: define_macros = [] undef_macros = [] define_macros.append(f"#define FAIL {lquote_macro(sub['fail'])}") @@ -253,8 +253,8 @@ def get_sub_macros(sub: Dict[Text, Text]) -> Union[Tuple[Text], Tuple[Text, Text def get_io_macros( - inputs: List[Text], outputs: List[Text] -) -> Union[Tuple[List[Text]], Tuple[str, str]]: + inputs: List[str], outputs: List[str] +) -> Union[Tuple[List[str]], Tuple[str, str]]: define_macros = [] undef_macros = [] @@ -285,7 +285,7 @@ class ExternalCOp(COp): r"^AESARA_(APPLY|SUPPORT)_CODE_SECTION$", re.MULTILINE ) # This is the set of allowed markers - SECTIONS: ClassVar[Set[Text]] = { + SECTIONS: ClassVar[Set[str]] = { "init_code", "init_code_apply", "init_code_struct", @@ -296,9 +296,11 @@ class ExternalCOp(COp): "code", "code_cleanup", } + _cop_num_inputs: Optional[int] = None + _cop_num_outputs: Optional[int] = None @classmethod - def get_path(cls, f: Text) -> Text: + def get_path(cls, f: str) -> str: """Convert a path relative to the location of the class file into an absolute path. Paths that are already absolute are passed through unchanged. @@ -311,7 +313,7 @@ def get_path(cls, f: Text) -> Text: return f def __init__( - self, func_files: Union[Text, List[Text]], func_name: Optional[Text] = None + self, func_files: Union[str, List[str]], func_name: Optional[str] = None ): """ Sections are loaded from files in order with sections in later @@ -319,36 +321,37 @@ def __init__( """ if not isinstance(func_files, list): - func_files = [func_files] + self.func_files = [func_files] + else: + self.func_files = func_files - self.func_name = func_name + self.func_codes: List[str] = [] # Keep the original name. If we reload old pickle, we want to # find the new path and new version of the file in Aesara. - self.func_files = func_files - self.load_c_code(func_files) + self.func_name = func_name + self.code_sections: Dict[str, str] = dict() + + self.load_c_code(self.func_files) if len(self.code_sections) == 0: - raise ValueError("No sections where defined in C files") + raise ValueError("No sections where defined in the C files") if self.func_name is not None: if "op_code" in self.code_sections: # maybe a warning instead (and clearing the key) raise ValueError( - 'Cannot have an "op_code" section and ' "specify the func_name" + "Cannot have an `op_code` section and specify `func_name`" ) if "op_code_cleanup" in self.code_sections: # maybe a warning instead (and clearing the key) raise ValueError( - 'Cannot have an "op_code_cleanup" section ' - "and specify the func_name" + "Cannot have an `op_code_cleanup` section and specify `func_name`" ) - def load_c_code(self, func_files: List[Text]) -> None: + def load_c_code(self, func_files: List[str]) -> None: """Loads the C code to perform the `Op`.""" func_files = [self.get_path(f) for f in func_files] - self.func_codes = [] for func_file in func_files: - # U (universal) will convert all new lines format to \n. with open(func_file) as f: self.func_codes.append(f.read()) @@ -370,7 +373,6 @@ def load_c_code(self, func_files: List[Text]) -> None: "be used at the same time." ) - self.code_sections = dict() for i, code in enumerate(self.func_codes): if self.backward_re.search(code): # This is backward compat code that will go away in a while @@ -449,7 +451,7 @@ def __get_op_params(self) -> List[Tuple[str, Any]]: return params def c_code_cache_version(self): - version = (hash(tuple(self.func_codes)),) + version = (hash_from_code("\n".join(self.func_codes)),) if self.params_type is not None: version += (self.params_type.c_code_cache_version(),) return version @@ -502,24 +504,35 @@ def c_cleanup_code_struct(self, node, name): else: return super().c_cleanup_code_struct(node, name) - def format_c_function_args(self, inp: List[Text], out: List[Text]) -> Text: + def format_c_function_args(self, inp: List[str], out: List[str]) -> str: """Generate a string containing the arguments sent to the external C function. The result will have the format: ``"input0, input1, input2, &output0, &output1"``. """ inp = list(inp) - numi = getattr(self, "_cop_num_inputs", len(inp)) + if self._cop_num_inputs is not None: + numi = self._cop_num_inputs + else: + numi = len(inp) + while len(inp) < numi: inp.append("NULL") + out = [f"&{o}" for o in out] - numo = getattr(self, "_cop_num_outputs", len(out)) + + if self._cop_num_outputs is not None: + numo = self._cop_num_outputs + else: + numo = len(out) + while len(out) < numo: out.append("NULL") + return ", ".join(inp + out) def get_c_macros( - self, node: Apply, name: Text, check_input: Optional[bool] = None + self, node: Apply, name: str, check_input: Optional[bool] = None ) -> Union[Tuple[str], Tuple[str, str]]: "Construct a pair of C ``#define`` and ``#undef`` code strings." define_template = "#define %s %s" diff --git a/aesara/link/jax/dispatch.py b/aesara/link/jax/dispatch.py deleted file mode 100644 index bb1e4ec584..0000000000 --- a/aesara/link/jax/dispatch.py +++ /dev/null @@ -1,1156 +0,0 @@ -import warnings -from functools import reduce, singledispatch -from warnings import warn - -import jax -import jax.numpy as jnp -import jax.scipy as jsp -import numpy as np -from numpy.random import Generator, RandomState -from numpy.random.bit_generator import _coerce_to_uint32_array - -from aesara.compile.ops import DeepCopyOp, ViewOp -from aesara.configdefaults import config -from aesara.graph.fg import FunctionGraph -from aesara.ifelse import IfElse -from aesara.link.utils import fgraph_to_python -from aesara.raise_op import CheckAndRaise -from aesara.scalar import Softplus -from aesara.scalar.basic import Cast, Clip, Composite, Identity, ScalarOp, Second -from aesara.scalar.math import Erf, Erfc, Erfinv, Log1mexp, Psi -from aesara.scan.op import Scan -from aesara.scan.utils import ScanArgs -from aesara.tensor.basic import ( - Alloc, - AllocDiag, - AllocEmpty, - ARange, - ExtractDiag, - Eye, - Join, - MakeVector, - ScalarFromTensor, - TensorFromScalar, -) -from aesara.tensor.blas import BatchedDot -from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise -from aesara.tensor.extra_ops import ( - Bartlett, - BroadcastTo, - CumOp, - FillDiagonal, - FillDiagonalOffset, - RavelMultiIndex, - Repeat, - Unique, - UnravelIndex, -) -from aesara.tensor.math import Dot, MaxAndArgmax -from aesara.tensor.nlinalg import SVD, Det, Eig, Eigh, MatrixInverse, QRFull -from aesara.tensor.nnet.basic import LogSoftmax, Softmax, SoftmaxGrad -from aesara.tensor.random.op import RandomVariable -from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, Unbroadcast -from aesara.tensor.slinalg import Cholesky, Solve, SolveTriangular -from aesara.tensor.subtensor import ( - AdvancedIncSubtensor, - AdvancedIncSubtensor1, - AdvancedSubtensor, - AdvancedSubtensor1, - IncSubtensor, - Subtensor, - indices_from_subtensor, -) -from aesara.tensor.type_other import MakeSlice - - -# For use with JAX since JAX doesn't support 'str' arguments -numpy_bit_gens = {"MT19937": 0, "PCG64": 1, "Philox": 2, "SFC64": 3} - - -if config.floatX == "float64": - jax.config.update("jax_enable_x64", True) -else: - jax.config.update("jax_enable_x64", False) - -# XXX: Enabling this will break some shape-based functionality, and severely -# limit the types of graphs that can be converted. -# See https://github.com/google/jax/blob/4d556837cc9003492f674c012689efc3d68fdf5f/design_notes/omnistaging.md -# Older versions < 0.2.0 do not have this flag so we don't need to set it. -try: - jax.config.disable_omnistaging() -except AttributeError: - pass -except Exception as e: - # The version might be >= 0.2.12, which means that omnistaging can't be - # disabled - warnings.warn(f"JAX omnistaging couldn't be disabled: {e}") - -subtensor_ops = (Subtensor, AdvancedSubtensor1, AdvancedSubtensor) -incsubtensor_ops = (IncSubtensor, AdvancedIncSubtensor1) - - -@singledispatch -def jax_typify(data, dtype=None, **kwargs): - r"""Convert instances of Aesara `Type`\s to JAX types.""" - if dtype is None: - return data - else: - return jnp.array(data, dtype=dtype) - - -@jax_typify.register(np.ndarray) -def jax_typify_ndarray(data, dtype=None, **kwargs): - return jnp.array(data, dtype=dtype) - - -@jax_typify.register(RandomState) -def jax_typify_RandomState(state, **kwargs): - state = state.get_state(legacy=False) - state["bit_generator"] = numpy_bit_gens[state["bit_generator"]] - # XXX: Is this a reasonable approach? - state["jax_state"] = state["state"]["key"][0:2] - return state - - -@jax_typify.register(Generator) -def jax_typify_Generator(rng, **kwargs): - state = rng.__getstate__() - state["bit_generator"] = numpy_bit_gens[state["bit_generator"]] - - # XXX: Is this a reasonable approach? - state["jax_state"] = _coerce_to_uint32_array(state["state"]["state"])[0:2] - - # The "state" and "inc" values in a NumPy `Generator` are 128 bits, which - # JAX can't handle, so we split these values into arrays of 32 bit integers - # and then combine the first two into a single 64 bit integers. - # - # XXX: Depending on how we expect these values to be used, is this approach - # reasonable? - # - # TODO: We might as well remove these altogether, since this conversion - # should only occur once (e.g. when the graph is converted/JAX-compiled), - # and, from then on, we use the custom "jax_state" value. - inc_32 = _coerce_to_uint32_array(state["state"]["inc"]) - state_32 = _coerce_to_uint32_array(state["state"]["state"]) - state["state"]["inc"] = inc_32[0] << 32 | inc_32[1] - state["state"]["state"] = state_32[0] << 32 | state_32[1] - return state - - -@singledispatch -def jax_funcify(op, node=None, storage_map=None, **kwargs): - """Create a JAX compatible function from an Aesara `Op`.""" - raise NotImplementedError(f"No JAX conversion for the given `Op`: {op}") - - -@jax_funcify.register(MakeSlice) -def jax_funcify_MakeSlice(op, **kwargs): - def makeslice(*x): - return slice(*x) - - return makeslice - - -@jax_funcify.register(ScalarOp) -def jax_funcify_ScalarOp(op, **kwargs): - func_name = op.nfunc_spec[0] - - if "." in func_name: - jnp_func = reduce(getattr, [jax] + func_name.split(".")) - else: - jnp_func = getattr(jnp, func_name) - - if hasattr(op, "nfunc_variadic"): - # These are special cases that handle invalid arities due to the broken - # Aesara `Op` type contract (e.g. binary `Op`s that also function as - # their own variadic counterparts--even when those counterparts already - # exist as independent `Op`s). - jax_variadic_func = getattr(jnp, op.nfunc_variadic) - - def elemwise(*args): - if len(args) > op.nfunc_spec[1]: - return jax_variadic_func( - jnp.stack(jnp.broadcast_arrays(*args), axis=0), axis=0 - ) - else: - return jnp_func(*args) - - return elemwise - else: - return jnp_func - - -@jax_funcify.register(Clip) -def jax_funcify_Clip(op, **kwargs): - def clip(x, min, max): - return jnp.where(x < min, min, jnp.where(x > max, max, x)) - - return clip - - -@jax_funcify.register(Identity) -def jax_funcify_Identity(op, **kwargs): - def identity(x): - return x - - return identity - - -@jax_funcify.register(Softmax) -def jax_funcify_Softmax(op, **kwargs): - axis = op.axis - - def softmax(x): - return jax.nn.softmax(x, axis=axis) - - return softmax - - -@jax_funcify.register(SoftmaxGrad) -def jax_funcify_SoftmaxGrad(op, **kwargs): - axis = op.axis - - def softmax_grad(dy, sm): - dy_times_sm = dy * sm - return dy_times_sm - jnp.sum(dy_times_sm, axis=axis, keepdims=True) * sm - - return softmax_grad - - -@jax_funcify.register(LogSoftmax) -def jax_funcify_LogSoftmax(op, **kwargs): - axis = op.axis - - def log_softmax(x): - return jax.nn.log_softmax(x, axis=axis) - - return log_softmax - - -@jax_funcify.register(Softplus) -def jax_funcify_Softplus(op, **kwargs): - def softplus(x): - # This expression is numerically equivalent to the Aesara one - # It just contains one "speed" optimization less than the Aesara counterpart - return jnp.where( - x < -37.0, jnp.exp(x), jnp.where(x > 33.3, x, jnp.log1p(jnp.exp(x))) - ) - - return softplus - - -@jax_funcify.register(Second) -def jax_funcify_Second(op, **kwargs): - def second(x, y): - return jnp.broadcast_to(y, x.shape) - - return second - - -@jax_funcify.register(AllocDiag) -def jax_funcify_AllocDiag(op, **kwargs): - offset = op.offset - - def allocdiag(v, offset=offset): - return jnp.diag(v, k=offset) - - return allocdiag - - -@jax_funcify.register(AllocEmpty) -def jax_funcify_AllocEmpty(op, **kwargs): - def allocempty(*shape): - return jnp.empty(shape, dtype=op.dtype) - - return allocempty - - -@jax_funcify.register(Alloc) -def jax_funcify_Alloc(op, **kwargs): - def alloc(x, *shape): - res = jnp.broadcast_to(x, shape) - return res - - return alloc - - -@jax_funcify.register(Dot) -def jax_funcify_Dot(op, **kwargs): - def dot(x, y): - return jnp.dot(x, y) - - return dot - - -@jax_funcify.register(ARange) -def jax_funcify_ARange(op, **kwargs): - # XXX: This currently requires concrete arguments. - def arange(start, stop, step): - return jnp.arange(start, stop, step, dtype=op.dtype) - - return arange - - -def jnp_safe_copy(x): - try: - res = jnp.copy(x) - except NotImplementedError: - warn("`jnp.copy` is not implemented yet. " "Using the object's `copy` method.") - if hasattr(x, "copy"): - res = jnp.array(x.copy()) - else: - warn(f"Object has no `copy` method: {x}") - res = x - - return res - - -@jax_funcify.register(DeepCopyOp) -def jax_funcify_DeepCopyOp(op, **kwargs): - def deepcopyop(x): - return jnp_safe_copy(x) - - return deepcopyop - - -@jax_funcify.register(Shape) -def jax_funcify_Shape(op, **kwargs): - def shape(x): - return jnp.shape(x) - - return shape - - -@jax_funcify.register(Shape_i) -def jax_funcify_Shape_i(op, **kwargs): - i = op.i - - def shape_i(x): - return jnp.shape(x)[i] - - return shape_i - - -@jax_funcify.register(SpecifyShape) -def jax_funcify_SpecifyShape(op, **kwargs): - def specifyshape(x, *shape): - assert x.ndim == len(shape) - assert jnp.all(x.shape == tuple(shape)), ( - "got shape", - x.shape, - "expected", - shape, - ) - return x - - return specifyshape - - -@jax_funcify.register(Unbroadcast) -def jax_funcify_Unbroadcast(op, **kwargs): - def unbroadcast(x): - return x - - return unbroadcast - - -@jax_funcify.register(ViewOp) -def jax_funcify_ViewOp(op, **kwargs): - def viewop(x): - return x - - return viewop - - -@jax_funcify.register(Cast) -def jax_funcify_Cast(op, **kwargs): - def cast(x): - return jnp.array(x).astype(op.o_type.dtype) - - return cast - - -@jax_funcify.register(TensorFromScalar) -def jax_funcify_TensorFromScalar(op, **kwargs): - def tensor_from_scalar(x): - return jnp.array(x) - - return tensor_from_scalar - - -@jax_funcify.register(ScalarFromTensor) -def jax_funcify_ScalarFromTensor(op, **kwargs): - def scalar_from_tensor(x): - return jnp.array(x).flatten()[0] - - return scalar_from_tensor - - -@jax_funcify.register(Elemwise) -def jax_funcify_Elemwise(op, **kwargs): - scalar_op = op.scalar_op - return jax_funcify(scalar_op, **kwargs) - - -@jax_funcify.register(Composite) -def jax_funcify_Composite(op, vectorize=True, **kwargs): - jax_impl = jax_funcify(op.fgraph) - - def composite(*args): - return jax_impl(*args)[0] - - return jnp.vectorize(composite) - - -@jax_funcify.register(Scan) -def jax_funcify_Scan(op, **kwargs): - inner_fg = FunctionGraph(op.inputs, op.outputs) - jax_at_inner_func = jax_funcify(inner_fg, **kwargs) - - def scan(*outer_inputs): - scan_args = ScanArgs( - list(outer_inputs), [None] * op.info.n_outs, op.inputs, op.outputs, op.info - ) - - # `outer_inputs` is a list with the following composite form: - # [n_steps] - # + outer_in_seqs - # + outer_in_mit_mot - # + outer_in_mit_sot - # + outer_in_sit_sot - # + outer_in_shared - # + outer_in_nit_sot - # + outer_in_non_seqs - n_steps = scan_args.n_steps - seqs = scan_args.outer_in_seqs - - # TODO: mit_mots - mit_mot_in_slices = [] - - mit_sot_in_slices = [] - for tap, seq in zip(scan_args.mit_sot_in_slices, scan_args.outer_in_mit_sot): - neg_taps = [abs(t) for t in tap if t < 0] - pos_taps = [abs(t) for t in tap if t > 0] - max_neg = max(neg_taps) if neg_taps else 0 - max_pos = max(pos_taps) if pos_taps else 0 - init_slice = seq[: max_neg + max_pos] - mit_sot_in_slices.append(init_slice) - - sit_sot_in_slices = [seq[0] for seq in scan_args.outer_in_sit_sot] - - init_carry = ( - mit_mot_in_slices, - mit_sot_in_slices, - sit_sot_in_slices, - scan_args.outer_in_shared, - scan_args.outer_in_non_seqs, - ) - - def jax_args_to_inner_scan(op, carry, x): - # `carry` contains all inner-output taps, non_seqs, and shared - # terms - ( - inner_in_mit_mot, - inner_in_mit_sot, - inner_in_sit_sot, - inner_in_shared, - inner_in_non_seqs, - ) = carry - - # `x` contains the in_seqs - inner_in_seqs = x - - # `inner_scan_inputs` is a list with the following composite form: - # inner_in_seqs - # + sum(inner_in_mit_mot, []) - # + sum(inner_in_mit_sot, []) - # + inner_in_sit_sot - # + inner_in_shared - # + inner_in_non_seqs - inner_in_mit_sot_flatten = [] - for array, index in zip(inner_in_mit_sot, scan_args.mit_sot_in_slices): - inner_in_mit_sot_flatten.extend(array[jnp.array(index)]) - - inner_scan_inputs = sum( - [ - inner_in_seqs, - inner_in_mit_mot, - inner_in_mit_sot_flatten, - inner_in_sit_sot, - inner_in_shared, - inner_in_non_seqs, - ], - [], - ) - - return inner_scan_inputs - - def inner_scan_outs_to_jax_outs( - op, - old_carry, - inner_scan_outs, - ): - ( - inner_in_mit_mot, - inner_in_mit_sot, - inner_in_sit_sot, - inner_in_shared, - inner_in_non_seqs, - ) = old_carry - - def update_mit_sot(mit_sot, new_val): - return jnp.concatenate([mit_sot[1:], new_val[None, ...]], axis=0) - - inner_out_mit_sot = [ - update_mit_sot(mit_sot, new_val) - for mit_sot, new_val in zip(inner_in_mit_sot, inner_scan_outs) - ] - - # This should contain all inner-output taps, non_seqs, and shared - # terms - if not inner_in_sit_sot: - inner_out_sit_sot = [] - else: - inner_out_sit_sot = inner_scan_outs - new_carry = ( - inner_in_mit_mot, - inner_out_mit_sot, - inner_out_sit_sot, - inner_in_shared, - inner_in_non_seqs, - ) - - return new_carry - - def jax_inner_func(carry, x): - inner_args = jax_args_to_inner_scan(op, carry, x) - inner_scan_outs = list(jax_at_inner_func(*inner_args)) - new_carry = inner_scan_outs_to_jax_outs(op, carry, inner_scan_outs) - return new_carry, inner_scan_outs - - _, scan_out = jax.lax.scan(jax_inner_func, init_carry, seqs, length=n_steps) - - # We need to prepend the initial values so that the JAX output will - # match the raw `Scan` `Op` output and, thus, work with a downstream - # `Subtensor` `Op` introduced by the `scan` helper function. - def append_scan_out(scan_in_part, scan_out_part): - return jnp.concatenate([scan_in_part[:-n_steps], scan_out_part], axis=0) - - if scan_args.outer_in_mit_sot: - scan_out_final = [ - append_scan_out(init, out) - for init, out in zip(scan_args.outer_in_mit_sot, scan_out) - ] - elif scan_args.outer_in_sit_sot: - scan_out_final = [ - append_scan_out(init, out) - for init, out in zip(scan_args.outer_in_sit_sot, scan_out) - ] - - if len(scan_out_final) == 1: - scan_out_final = scan_out_final[0] - return scan_out_final - - return scan - - -@jax_funcify.register(IfElse) -def jax_funcify_IfElse(op, **kwargs): - n_outs = op.n_outs - - def ifelse(cond, *args, n_outs=n_outs): - res = jax.lax.cond( - cond, lambda _: args[:n_outs], lambda _: args[n_outs:], operand=None - ) - return res if n_outs > 1 else res[0] - - return ifelse - - -@jax_funcify.register(CheckAndRaise) -def jax_funcify_CheckAndRaise(op, **kwargs): - - raise NotImplementedError( - f"""This exception is raised because you tried to convert an aesara graph with a `CheckAndRaise` Op (message: {op.msg}) to JAX. - - JAX uses tracing to jit-compile functions, and assertions typically - don't do well with tracing. The appropriate workaround depends on what - you intended to do with the assertions in the first place. - - Note that all assertions can be removed from the graph by adding - `local_remove_all_assert` to the rewrites.""" - ) - - -@jax_funcify.register(Subtensor) -def jax_funcify_Subtensor(op, **kwargs): - - idx_list = getattr(op, "idx_list", None) - - def subtensor(x, *ilists): - - indices = indices_from_subtensor(ilists, idx_list) - - if len(indices) == 1: - indices = indices[0] - - return x.__getitem__(indices) - - return subtensor - - -_ = [jax_funcify.register(op, jax_funcify_Subtensor) for op in subtensor_ops] - - -def jax_funcify_IncSubtensor(op, **kwargs): - - idx_list = getattr(op, "idx_list", None) - - if getattr(op, "set_instead_of_inc", False): - jax_fn = getattr(jax.ops, "index_update", None) - - if jax_fn is None: - - def jax_fn(x, indices, y): - return x.at[indices].set(y) - - else: - jax_fn = getattr(jax.ops, "index_add", None) - - if jax_fn is None: - - def jax_fn(x, indices, y): - return x.at[indices].add(y) - - def incsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=idx_list): - indices = indices_from_subtensor(ilist, idx_list) - if len(indices) == 1: - indices = indices[0] - - return jax_fn(x, indices, y) - - return incsubtensor - - -_ = [jax_funcify.register(op, jax_funcify_IncSubtensor) for op in incsubtensor_ops] - - -@jax_funcify.register(AdvancedIncSubtensor) -def jax_funcify_AdvancedIncSubtensor(op, **kwargs): - - if getattr(op, "set_instead_of_inc", False): - jax_fn = getattr(jax.ops, "index_update", None) - - if jax_fn is None: - - def jax_fn(x, indices, y): - return x.at[indices].set(y) - - else: - jax_fn = getattr(jax.ops, "index_add", None) - - if jax_fn is None: - - def jax_fn(x, indices, y): - return x.at[indices].add(y) - - def advancedincsubtensor(x, y, *ilist, jax_fn=jax_fn): - return jax_fn(x, ilist, y) - - return advancedincsubtensor - - -@jax_funcify.register(FunctionGraph) -def jax_funcify_FunctionGraph( - fgraph, - node=None, - fgraph_name="jax_funcified_fgraph", - **kwargs, -): - return fgraph_to_python( - fgraph, - jax_funcify, - type_conversion_fn=jax_typify, - fgraph_name=fgraph_name, - **kwargs, - ) - - -@jax_funcify.register(CAReduce) -def jax_funcify_CAReduce(op, **kwargs): - axis = op.axis - op_nfunc_spec = getattr(op, "nfunc_spec", None) - scalar_nfunc_spec = getattr(op.scalar_op, "nfunc_spec", None) - scalar_op_name = getattr(op.scalar_op, "name", None) - scalar_op_identity = getattr(op.scalar_op, "identity", None) - acc_dtype = getattr(op, "acc_dtype", None) - - def careduce(x): - nonlocal axis, op_nfunc_spec, scalar_nfunc_spec, scalar_op_name, scalar_op_identity, acc_dtype - - if axis is None: - axis = list(range(x.ndim)) - - if acc_dtype is None: - acc_dtype = x.dtype.type - - if op_nfunc_spec: - jax_op = getattr(jnp, op_nfunc_spec[0]) - return jax_op(x, axis=axis).astype(acc_dtype) - - # The Aesara `Op` didn't tell us which NumPy equivalent to use (or - # there isn't one), so we use this fallback approach - if scalar_nfunc_spec: - scalar_fn_name = scalar_nfunc_spec[0] - elif scalar_op_name: - scalar_fn_name = scalar_op_name - - to_reduce = reversed(sorted(axis)) - - if to_reduce: - # In this case, we need to use the `jax.lax` function (if there - # is one), and not the `jnp` version. - jax_op = getattr(jax.lax, scalar_fn_name) - init_value = jnp.array(scalar_op_identity, dtype=acc_dtype) - return jax.lax.reduce(x, init_value, jax_op, to_reduce).astype(acc_dtype) - else: - return x - - return careduce - - -@jax_funcify.register(MakeVector) -def jax_funcify_MakeVector(op, **kwargs): - def makevector(*x): - return jnp.array(x, dtype=op.dtype) - - return makevector - - -@jax_funcify.register(Reshape) -def jax_funcify_Reshape(op, **kwargs): - def reshape(x, shape): - return jnp.reshape(x, shape) - - return reshape - - -@jax_funcify.register(DimShuffle) -def jax_funcify_DimShuffle(op, **kwargs): - def dimshuffle(x): - - res = jnp.transpose(x, op.transposition) - - shape = list(res.shape[: len(op.shuffle)]) - - for augm in op.augment: - shape.insert(augm, 1) - - res = jnp.reshape(res, shape) - - if not op.inplace: - res = jnp_safe_copy(res) - - return res - - return dimshuffle - - -@jax_funcify.register(Join) -def jax_funcify_Join(op, **kwargs): - def join(axis, *tensors): - # tensors could also be tuples, and in this case they don't have a ndim - tensors = [jnp.asarray(tensor) for tensor in tensors] - view = op.view - if (view != -1) and all( - tensor.shape[axis] == 0 for tensor in tensors[0:view] + tensors[view + 1 :] - ): - return tensors[view] - - else: - return jnp.concatenate(tensors, axis=axis) - - return join - - -@jax_funcify.register(MaxAndArgmax) -def jax_funcify_MaxAndArgmax(op, **kwargs): - axis = op.axis - - def maxandargmax(x, axis=axis): - if axis is None: - axes = tuple(range(x.ndim)) - else: - axes = tuple(int(ax) for ax in axis) - - max_res = jnp.max(x, axis) - - # NumPy does not support multiple axes for argmax; this is a - # work-around - keep_axes = jnp.array( - [i for i in range(x.ndim) if i not in axes], dtype="int64" - ) - # Not-reduced axes in front - transposed_x = jnp.transpose( - x, jnp.concatenate((keep_axes, jnp.array(axes, dtype="int64"))) - ) - kept_shape = transposed_x.shape[: len(keep_axes)] - reduced_shape = transposed_x.shape[len(keep_axes) :] - - # Numpy.prod returns 1.0 when arg is empty, so we cast it to int64 - # Otherwise reshape would complain citing float arg - new_shape = kept_shape + ( - jnp.prod(jnp.array(reduced_shape, dtype="int64"), dtype="int64"), - ) - reshaped_x = transposed_x.reshape(new_shape) - - max_idx_res = jnp.argmax(reshaped_x, axis=-1).astype("int64") - - return max_res, max_idx_res - - return maxandargmax - - -@jax_funcify.register(ExtractDiag) -def jax_funcify_ExtractDiag(op, **kwargs): - offset = op.offset - axis1 = op.axis1 - axis2 = op.axis2 - - def extract_diag(x, offset=offset, axis1=axis1, axis2=axis2): - return jnp.diagonal(x, offset=offset, axis1=axis1, axis2=axis2) - - return extract_diag - - -@jax_funcify.register(Cholesky) -def jax_funcify_Cholesky(op, **kwargs): - lower = op.lower - - def cholesky(a, lower=lower): - return jsp.linalg.cholesky(a, lower=lower).astype(a.dtype) - - return cholesky - - -@jax_funcify.register(Solve) -def jax_funcify_Solve(op, **kwargs): - - if op.assume_a != "gen" and op.lower: - lower = True - else: - lower = False - - def solve(a, b, lower=lower): - return jsp.linalg.solve(a, b, lower=lower) - - return solve - - -@jax_funcify.register(SolveTriangular) -def jax_funcify_SolveTriangular(op, **kwargs): - lower = op.lower - trans = op.trans - unit_diagonal = op.unit_diagonal - check_finite = op.check_finite - - def solve_triangular(A, b): - return jsp.linalg.solve_triangular( - A, - b, - lower=lower, - trans=trans, - unit_diagonal=unit_diagonal, - check_finite=check_finite, - ) - - return solve_triangular - - -@jax_funcify.register(Det) -def jax_funcify_Det(op, **kwargs): - def det(x): - return jnp.linalg.det(x) - - return det - - -@jax_funcify.register(Eig) -def jax_funcify_Eig(op, **kwargs): - def eig(x): - return jnp.linalg.eig(x) - - return eig - - -@jax_funcify.register(Eigh) -def jax_funcify_Eigh(op, **kwargs): - uplo = op.UPLO - - def eigh(x, uplo=uplo): - return jnp.linalg.eigh(x, UPLO=uplo) - - return eigh - - -@jax_funcify.register(MatrixInverse) -def jax_funcify_MatrixInverse(op, **kwargs): - def matrix_inverse(x): - return jnp.linalg.inv(x) - - return matrix_inverse - - -@jax_funcify.register(QRFull) -def jax_funcify_QRFull(op, **kwargs): - mode = op.mode - - def qr_full(x, mode=mode): - return jnp.linalg.qr(x, mode=mode) - - return qr_full - - -@jax_funcify.register(SVD) -def jax_funcify_SVD(op, **kwargs): - full_matrices = op.full_matrices - compute_uv = op.compute_uv - - def svd(x, full_matrices=full_matrices, compute_uv=compute_uv): - return jnp.linalg.svd(x, full_matrices=full_matrices, compute_uv=compute_uv) - - return svd - - -@jax_funcify.register(CumOp) -def jax_funcify_CumOp(op, **kwargs): - axis = op.axis - mode = op.mode - - def cumop(x, axis=axis, mode=mode): - if mode == "add": - return jnp.cumsum(x, axis=axis) - else: - return jnp.cumprod(x, axis=axis) - - return cumop - - -@jax_funcify.register(Repeat) -def jax_funcify_Repeat(op, **kwargs): - axis = op.axis - - def repeatop(x, repeats, axis=axis): - return jnp.repeat(x, repeats, axis=axis) - - return repeatop - - -@jax_funcify.register(Bartlett) -def jax_funcify_Bartlett(op, **kwargs): - def bartlett(x): - return jnp.bartlett(x) - - return bartlett - - -@jax_funcify.register(FillDiagonal) -def jax_funcify_FillDiagonal(op, **kwargs): - def filldiagonal(value, diagonal): - i, j = jnp.diag_indices(min(value.shape[-2:])) - return value.at[..., i, j].set(diagonal) - - return filldiagonal - - -@jax_funcify.register(FillDiagonalOffset) -def jax_funcify_FillDiagonalOffset(op, **kwargs): - - # def filldiagonaloffset(a, val, offset): - # height, width = a.shape - # - # if offset >= 0: - # start = offset - # num_of_step = min(min(width, height), width - offset) - # else: - # start = -offset * a.shape[1] - # num_of_step = min(min(width, height), height + offset) - # - # step = a.shape[1] + 1 - # end = start + step * num_of_step - # a.flat[start:end:step] = val - # - # return a - # - # return filldiagonaloffset - - raise NotImplementedError("flatiter not implemented in JAX") - - -@jax_funcify.register(Unique) -def jax_funcify_Unique(op, **kwargs): - axis = op.axis - - if axis is not None: - raise NotImplementedError( - "jax.numpy.unique is not implemented for the axis argument" - ) - - return_index = op.return_index - return_inverse = op.return_inverse - return_counts = op.return_counts - - def unique( - x, - return_index=return_index, - return_inverse=return_inverse, - return_counts=return_counts, - axis=axis, - ): - ret = jnp.lax_numpy._unique1d(x, return_index, return_inverse, return_counts) - if len(ret) == 1: - return ret[0] - else: - return ret - - return unique - - -@jax_funcify.register(UnravelIndex) -def jax_funcify_UnravelIndex(op, **kwargs): - order = op.order - - warn("JAX ignores the `order` parameter in `unravel_index`.") - - def unravelindex(indices, dims, order=order): - return jnp.unravel_index(indices, dims) - - return unravelindex - - -@jax_funcify.register(RavelMultiIndex) -def jax_funcify_RavelMultiIndex(op, **kwargs): - mode = op.mode - order = op.order - - def ravelmultiindex(*inp, mode=mode, order=order): - multi_index, dims = inp[:-1], inp[-1] - return jnp.ravel_multi_index(multi_index, dims, mode=mode, order=order) - - return ravelmultiindex - - -@jax_funcify.register(Eye) -def jax_funcify_Eye(op, **kwargs): - dtype = op.dtype - - def eye(N, M, k): - return jnp.eye(N, M, k, dtype=dtype) - - return eye - - -@jax_funcify.register(BatchedDot) -def jax_funcify_BatchedDot(op, **kwargs): - def batched_dot(a, b): - if a.shape[0] != b.shape[0]: - raise TypeError("Shapes must match in the 0-th dimension") - if a.ndim == 2 or b.ndim == 2: - return jnp.einsum("n...j,nj...->n...", a, b) - return jnp.einsum("nij,njk->nik", a, b) - - return batched_dot - - -@jax_funcify.register(RandomVariable) -def jax_funcify_RandomVariable(op, node, **kwargs): - name = op.name - - if not hasattr(jax.random, name): - raise NotImplementedError( - f"No JAX conversion for the given distribution: {name}" - ) - - dtype = node.outputs[1].dtype - - def random_variable(rng, size, dtype_num, *args): - if not op.inplace: - rng = rng.copy() - prng = rng["jax_state"] - data = getattr(jax.random, name)(key=prng, shape=size) - smpl_value = jnp.array(data, dtype=dtype) - rng["jax_state"] = jax.random.split(prng, num=1)[0] - return (rng, smpl_value) - - return random_variable - - -@jax_funcify.register(Erf) -def jax_funcify_Erf(op, node, **kwargs): - def erf(x): - return jax.scipy.special.erf(x) - - return erf - - -@jax_funcify.register(Erfc) -def jax_funcify_Erfc(op, **kwargs): - def erfc(x): - return jax.scipy.special.erfc(x) - - return erfc - - -@jax_funcify.register(Log1mexp) -def jax_funcify_Log1mexp(op, node, **kwargs): - def log1mexp(x): - return jnp.where( - x < jnp.log(0.5), jnp.log1p(-jnp.exp(x)), jnp.log(-jnp.expm1(x)) - ) - - return log1mexp - - -# Commented out because jax.scipy does not have erfcx, -# but leaving the implementation in here just in case we ever see -# a JAX implementation of Erfcx. -# See https://github.com/google/jax/issues/1987 for context. -# @jax_funcify.register(Erfcx) -# def jax_funcify_Erfcx(op, **kwargs): -# def erfcx(x): -# return jax.scipy.special.erfcx(x) -# return erfcx - - -@jax_funcify.register(Erfinv) -def jax_funcify_Erfinv(op, **kwargs): - def erfinv(x): - return jax.scipy.special.erfinv(x) - - return erfinv - - -# Commented out because jax.scipy does not have Erfcinv, -# but leaving the implementation in here just in case we ever see -# a JAX implementation of Erfcinv. -# @jax_funcify.register(Erfcinv) -# def jax_funcify_Erfcinv(op, **kwargs): -# def erfcinv(x): -# return jax.scipy.special.erfcinv(x) -# return erfcinv - - -@jax_funcify.register(Psi) -def jax_funcify_Psi(op, node, **kwargs): - def psi(x): - return jax.scipy.special.digamma(x) - - return psi - - -@jax_funcify.register(BroadcastTo) -def jax_funcify_BroadcastTo(op, **kwargs): - def broadcast_to(x, *shape): - return jnp.broadcast_to(x, shape) - - return broadcast_to diff --git a/aesara/link/jax/dispatch/__init__.py b/aesara/link/jax/dispatch/__init__.py new file mode 100644 index 0000000000..fb4144bc74 --- /dev/null +++ b/aesara/link/jax/dispatch/__init__.py @@ -0,0 +1,16 @@ +# isort: off +from aesara.link.jax.dispatch.basic import jax_funcify, jax_typify + +# Load dispatch specializations +import aesara.link.jax.dispatch.scalar +import aesara.link.jax.dispatch.tensor_basic +import aesara.link.jax.dispatch.subtensor +import aesara.link.jax.dispatch.shape +import aesara.link.jax.dispatch.extra_ops +import aesara.link.jax.dispatch.nlinalg +import aesara.link.jax.dispatch.slinalg +import aesara.link.jax.dispatch.random +import aesara.link.jax.dispatch.elemwise +import aesara.link.jax.dispatch.scan + +# isort: on diff --git a/aesara/link/jax/dispatch/basic.py b/aesara/link/jax/dispatch/basic.py new file mode 100644 index 0000000000..5165ae6050 --- /dev/null +++ b/aesara/link/jax/dispatch/basic.py @@ -0,0 +1,114 @@ +import warnings +from functools import singledispatch + +import jax +import jax.numpy as jnp +import numpy as np + +from aesara.compile.ops import DeepCopyOp, ViewOp +from aesara.configdefaults import config +from aesara.graph.fg import FunctionGraph +from aesara.ifelse import IfElse +from aesara.link.utils import fgraph_to_python +from aesara.raise_op import Assert, CheckAndRaise + + +if config.floatX == "float64": + jax.config.update("jax_enable_x64", True) +else: + jax.config.update("jax_enable_x64", False) + + +@singledispatch +def jax_typify(data, dtype=None, **kwargs): + r"""Convert instances of Aesara `Type`\s to JAX types.""" + if dtype is None: + return data + else: + return jnp.array(data, dtype=dtype) + + +@jax_typify.register(np.ndarray) +def jax_typify_ndarray(data, dtype=None, **kwargs): + return jnp.array(data, dtype=dtype) + + +@singledispatch +def jax_funcify(op, node=None, storage_map=None, **kwargs): + """Create a JAX compatible function from an Aesara `Op`.""" + raise NotImplementedError(f"No JAX conversion for the given `Op`: {op}") + + +@jax_funcify.register(FunctionGraph) +def jax_funcify_FunctionGraph( + fgraph, + node=None, + fgraph_name="jax_funcified_fgraph", + **kwargs, +): + return fgraph_to_python( + fgraph, + jax_funcify, + type_conversion_fn=jax_typify, + fgraph_name=fgraph_name, + **kwargs, + ) + + +@jax_funcify.register(IfElse) +def jax_funcify_IfElse(op, **kwargs): + n_outs = op.n_outs + + def ifelse(cond, *args, n_outs=n_outs): + res = jax.lax.cond( + cond, lambda _: args[:n_outs], lambda _: args[n_outs:], operand=None + ) + return res if n_outs > 1 else res[0] + + return ifelse + + +@jax_funcify.register(Assert) +@jax_funcify.register(CheckAndRaise) +def jax_funcify_CheckAndRaise(op, **kwargs): + warnings.warn( + f"""Skipping `CheckAndRaise` Op (assertion: {op.msg}) as JAX tracing would remove it.""", + stacklevel=2, + ) + + def assert_fn(x, *inputs): + return x + + return assert_fn + + +def jnp_safe_copy(x): + try: + res = jnp.copy(x) + except NotImplementedError: + warnings.warn( + "`jnp.copy` is not implemented yet. " "Using the object's `copy` method." + ) + if hasattr(x, "copy"): + res = jnp.array(x.copy()) + else: + warnings.warn(f"Object has no `copy` method: {x}") + res = x + + return res + + +@jax_funcify.register(DeepCopyOp) +def jax_funcify_DeepCopyOp(op, **kwargs): + def deepcopyop(x): + return jnp_safe_copy(x) + + return deepcopyop + + +@jax_funcify.register(ViewOp) +def jax_funcify_ViewOp(op, **kwargs): + def viewop(x): + return x + + return viewop diff --git a/aesara/link/jax/dispatch/elemwise.py b/aesara/link/jax/dispatch/elemwise.py new file mode 100644 index 0000000000..90f3a40fdb --- /dev/null +++ b/aesara/link/jax/dispatch/elemwise.py @@ -0,0 +1,107 @@ +import jax +import jax.numpy as jnp + +from aesara.link.jax.dispatch.basic import jax_funcify, jnp_safe_copy +from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise +from aesara.tensor.nnet.basic import LogSoftmax, Softmax, SoftmaxGrad + + +@jax_funcify.register(Elemwise) +def jax_funcify_Elemwise(op, **kwargs): + scalar_op = op.scalar_op + return jax_funcify(scalar_op, **kwargs) + + +@jax_funcify.register(CAReduce) +def jax_funcify_CAReduce(op, **kwargs): + axis = op.axis + op_nfunc_spec = getattr(op, "nfunc_spec", None) + scalar_nfunc_spec = getattr(op.scalar_op, "nfunc_spec", None) + scalar_op_name = getattr(op.scalar_op, "name", None) + scalar_op_identity = getattr(op.scalar_op, "identity", None) + acc_dtype = getattr(op, "acc_dtype", None) + + def careduce(x): + nonlocal axis, op_nfunc_spec, scalar_nfunc_spec, scalar_op_name, scalar_op_identity, acc_dtype + + if axis is None: + axis = list(range(x.ndim)) + + if acc_dtype is None: + acc_dtype = x.dtype.type + + if op_nfunc_spec: + jax_op = getattr(jnp, op_nfunc_spec[0]) + return jax_op(x, axis=axis).astype(acc_dtype) + + # The Aesara `Op` didn't tell us which NumPy equivalent to use (or + # there isn't one), so we use this fallback approach + if scalar_nfunc_spec: + scalar_fn_name = scalar_nfunc_spec[0] + elif scalar_op_name: + scalar_fn_name = scalar_op_name + + to_reduce = reversed(sorted(axis)) + + if to_reduce: + # In this case, we need to use the `jax.lax` function (if there + # is one), and not the `jnp` version. + jax_op = getattr(jax.lax, scalar_fn_name) + init_value = jnp.array(scalar_op_identity, dtype=acc_dtype) + return jax.lax.reduce(x, init_value, jax_op, to_reduce).astype(acc_dtype) + else: + return x + + return careduce + + +@jax_funcify.register(DimShuffle) +def jax_funcify_DimShuffle(op, **kwargs): + def dimshuffle(x): + + res = jnp.transpose(x, op.transposition) + + shape = list(res.shape[: len(op.shuffle)]) + + for augm in op.augment: + shape.insert(augm, 1) + + res = jnp.reshape(res, shape) + + if not op.inplace: + res = jnp_safe_copy(res) + + return res + + return dimshuffle + + +@jax_funcify.register(Softmax) +def jax_funcify_Softmax(op, **kwargs): + axis = op.axis + + def softmax(x): + return jax.nn.softmax(x, axis=axis) + + return softmax + + +@jax_funcify.register(SoftmaxGrad) +def jax_funcify_SoftmaxGrad(op, **kwargs): + axis = op.axis + + def softmax_grad(dy, sm): + dy_times_sm = dy * sm + return dy_times_sm - jnp.sum(dy_times_sm, axis=axis, keepdims=True) * sm + + return softmax_grad + + +@jax_funcify.register(LogSoftmax) +def jax_funcify_LogSoftmax(op, **kwargs): + axis = op.axis + + def log_softmax(x): + return jax.nn.log_softmax(x, axis=axis) + + return log_softmax diff --git a/aesara/link/jax/dispatch/extra_ops.py b/aesara/link/jax/dispatch/extra_ops.py new file mode 100644 index 0000000000..2c70499afd --- /dev/null +++ b/aesara/link/jax/dispatch/extra_ops.py @@ -0,0 +1,142 @@ +import warnings + +import jax.numpy as jnp + +from aesara.link.jax.dispatch.basic import jax_funcify +from aesara.tensor.extra_ops import ( + Bartlett, + BroadcastTo, + CumOp, + FillDiagonal, + FillDiagonalOffset, + RavelMultiIndex, + Repeat, + Unique, + UnravelIndex, +) + + +@jax_funcify.register(Bartlett) +def jax_funcify_Bartlett(op, **kwargs): + def bartlett(x): + return jnp.bartlett(x) + + return bartlett + + +@jax_funcify.register(CumOp) +def jax_funcify_CumOp(op, **kwargs): + axis = op.axis + mode = op.mode + + def cumop(x, axis=axis, mode=mode): + if mode == "add": + return jnp.cumsum(x, axis=axis) + else: + return jnp.cumprod(x, axis=axis) + + return cumop + + +@jax_funcify.register(Repeat) +def jax_funcify_Repeat(op, **kwargs): + axis = op.axis + + def repeatop(x, repeats, axis=axis): + return jnp.repeat(x, repeats, axis=axis) + + return repeatop + + +@jax_funcify.register(Unique) +def jax_funcify_Unique(op, **kwargs): + axis = op.axis + + if axis is not None: + raise NotImplementedError( + "jax.numpy.unique is not implemented for the axis argument" + ) + + return_index = op.return_index + return_inverse = op.return_inverse + return_counts = op.return_counts + + def unique( + x, + return_index=return_index, + return_inverse=return_inverse, + return_counts=return_counts, + axis=axis, + ): + ret = jnp.lax_numpy._unique1d(x, return_index, return_inverse, return_counts) + if len(ret) == 1: + return ret[0] + else: + return ret + + return unique + + +@jax_funcify.register(UnravelIndex) +def jax_funcify_UnravelIndex(op, **kwargs): + order = op.order + + warnings.warn("JAX ignores the `order` parameter in `unravel_index`.") + + def unravelindex(indices, dims, order=order): + return jnp.unravel_index(indices, dims) + + return unravelindex + + +@jax_funcify.register(RavelMultiIndex) +def jax_funcify_RavelMultiIndex(op, **kwargs): + mode = op.mode + order = op.order + + def ravelmultiindex(*inp, mode=mode, order=order): + multi_index, dims = inp[:-1], inp[-1] + return jnp.ravel_multi_index(multi_index, dims, mode=mode, order=order) + + return ravelmultiindex + + +@jax_funcify.register(BroadcastTo) +def jax_funcify_BroadcastTo(op, **kwargs): + def broadcast_to(x, *shape): + return jnp.broadcast_to(x, shape) + + return broadcast_to + + +@jax_funcify.register(FillDiagonal) +def jax_funcify_FillDiagonal(op, **kwargs): + def filldiagonal(value, diagonal): + i, j = jnp.diag_indices(min(value.shape[-2:])) + return value.at[..., i, j].set(diagonal) + + return filldiagonal + + +@jax_funcify.register(FillDiagonalOffset) +def jax_funcify_FillDiagonalOffset(op, **kwargs): + + # def filldiagonaloffset(a, val, offset): + # height, width = a.shape + # + # if offset >= 0: + # start = offset + # num_of_step = min(min(width, height), width - offset) + # else: + # start = -offset * a.shape[1] + # num_of_step = min(min(width, height), height + offset) + # + # step = a.shape[1] + 1 + # end = start + step * num_of_step + # a.flat[start:end:step] = val + # + # return a + # + # return filldiagonaloffset + + raise NotImplementedError("flatiter not implemented in JAX") diff --git a/aesara/link/jax/dispatch/nlinalg.py b/aesara/link/jax/dispatch/nlinalg.py new file mode 100644 index 0000000000..bd50cba0cd --- /dev/null +++ b/aesara/link/jax/dispatch/nlinalg.py @@ -0,0 +1,119 @@ +import jax.numpy as jnp + +from aesara.link.jax.dispatch import jax_funcify +from aesara.tensor.blas import BatchedDot +from aesara.tensor.math import Dot, MaxAndArgmax +from aesara.tensor.nlinalg import SVD, Det, Eig, Eigh, MatrixInverse, QRFull + + +@jax_funcify.register(SVD) +def jax_funcify_SVD(op, **kwargs): + full_matrices = op.full_matrices + compute_uv = op.compute_uv + + def svd(x, full_matrices=full_matrices, compute_uv=compute_uv): + return jnp.linalg.svd(x, full_matrices=full_matrices, compute_uv=compute_uv) + + return svd + + +@jax_funcify.register(Det) +def jax_funcify_Det(op, **kwargs): + def det(x): + return jnp.linalg.det(x) + + return det + + +@jax_funcify.register(Eig) +def jax_funcify_Eig(op, **kwargs): + def eig(x): + return jnp.linalg.eig(x) + + return eig + + +@jax_funcify.register(Eigh) +def jax_funcify_Eigh(op, **kwargs): + uplo = op.UPLO + + def eigh(x, uplo=uplo): + return jnp.linalg.eigh(x, UPLO=uplo) + + return eigh + + +@jax_funcify.register(MatrixInverse) +def jax_funcify_MatrixInverse(op, **kwargs): + def matrix_inverse(x): + return jnp.linalg.inv(x) + + return matrix_inverse + + +@jax_funcify.register(QRFull) +def jax_funcify_QRFull(op, **kwargs): + mode = op.mode + + def qr_full(x, mode=mode): + return jnp.linalg.qr(x, mode=mode) + + return qr_full + + +@jax_funcify.register(Dot) +def jax_funcify_Dot(op, **kwargs): + def dot(x, y): + return jnp.dot(x, y) + + return dot + + +@jax_funcify.register(BatchedDot) +def jax_funcify_BatchedDot(op, **kwargs): + def batched_dot(a, b): + if a.shape[0] != b.shape[0]: + raise TypeError("Shapes must match in the 0-th dimension") + if a.ndim == 2 or b.ndim == 2: + return jnp.einsum("n...j,nj...->n...", a, b) + return jnp.einsum("nij,njk->nik", a, b) + + return batched_dot + + +@jax_funcify.register(MaxAndArgmax) +def jax_funcify_MaxAndArgmax(op, **kwargs): + axis = op.axis + + def maxandargmax(x, axis=axis): + if axis is None: + axes = tuple(range(x.ndim)) + else: + axes = tuple(int(ax) for ax in axis) + + max_res = jnp.max(x, axis) + + # NumPy does not support multiple axes for argmax; this is a + # work-around + keep_axes = jnp.array( + [i for i in range(x.ndim) if i not in axes], dtype="int64" + ) + # Not-reduced axes in front + transposed_x = jnp.transpose( + x, jnp.concatenate((keep_axes, jnp.array(axes, dtype="int64"))) + ) + kept_shape = transposed_x.shape[: len(keep_axes)] + reduced_shape = transposed_x.shape[len(keep_axes) :] + + # Numpy.prod returns 1.0 when arg is empty, so we cast it to int64 + # Otherwise reshape would complain citing float arg + new_shape = kept_shape + ( + jnp.prod(jnp.array(reduced_shape, dtype="int64"), dtype="int64"), + ) + reshaped_x = transposed_x.reshape(new_shape) + + max_idx_res = jnp.argmax(reshaped_x, axis=-1).astype("int64") + + return max_res, max_idx_res + + return maxandargmax diff --git a/aesara/link/jax/dispatch/random.py b/aesara/link/jax/dispatch/random.py new file mode 100644 index 0000000000..f72f8e1346 --- /dev/null +++ b/aesara/link/jax/dispatch/random.py @@ -0,0 +1,70 @@ +import jax +import jax.numpy as jnp +from numpy.random import Generator, RandomState +from numpy.random.bit_generator import ( # type: ignore[attr-defined] + _coerce_to_uint32_array, +) + +from aesara.link.jax.dispatch.basic import jax_funcify, jax_typify +from aesara.tensor.random.op import RandomVariable + + +numpy_bit_gens = {"MT19937": 0, "PCG64": 1, "Philox": 2, "SFC64": 3} + + +@jax_typify.register(RandomState) +def jax_typify_RandomState(state, **kwargs): + state = state.get_state(legacy=False) + state["bit_generator"] = numpy_bit_gens[state["bit_generator"]] + # XXX: Is this a reasonable approach? + state["jax_state"] = state["state"]["key"][0:2] + return state + + +@jax_typify.register(Generator) +def jax_typify_Generator(rng, **kwargs): + state = rng.__getstate__() + state["bit_generator"] = numpy_bit_gens[state["bit_generator"]] + + # XXX: Is this a reasonable approach? + state["jax_state"] = _coerce_to_uint32_array(state["state"]["state"])[0:2] + + # The "state" and "inc" values in a NumPy `Generator` are 128 bits, which + # JAX can't handle, so we split these values into arrays of 32 bit integers + # and then combine the first two into a single 64 bit integers. + # + # XXX: Depending on how we expect these values to be used, is this approach + # reasonable? + # + # TODO: We might as well remove these altogether, since this conversion + # should only occur once (e.g. when the graph is converted/JAX-compiled), + # and, from then on, we use the custom "jax_state" value. + inc_32 = _coerce_to_uint32_array(state["state"]["inc"]) + state_32 = _coerce_to_uint32_array(state["state"]["state"]) + state["state"]["inc"] = inc_32[0] << 32 | inc_32[1] + state["state"]["state"] = state_32[0] << 32 | state_32[1] + return state + + +@jax_funcify.register(RandomVariable) +def jax_funcify_RandomVariable(op, node, **kwargs): + name = op.name + + # TODO Make sure there's a 1-to-1 correspondance with names + if not hasattr(jax.random, name): + raise NotImplementedError( + f"No JAX conversion for the given distribution: {name}" + ) + + dtype = node.outputs[1].dtype + + def random_variable(rng, size, dtype_num, *args): + if not op.inplace: + rng = rng.copy() + prng = rng["jax_state"] + data = getattr(jax.random, name)(key=prng, shape=size) + smpl_value = jnp.array(data, dtype=dtype) + rng["jax_state"] = jax.random.split(prng, num=1)[0] + return (rng, smpl_value) + + return random_variable diff --git a/aesara/link/jax/dispatch/scalar.py b/aesara/link/jax/dispatch/scalar.py new file mode 100644 index 0000000000..8f7fdf6951 --- /dev/null +++ b/aesara/link/jax/dispatch/scalar.py @@ -0,0 +1,134 @@ +import functools + +import jax +import jax.numpy as jnp + +from aesara.link.jax.dispatch.basic import jax_funcify +from aesara.scalar import Softplus +from aesara.scalar.basic import Cast, Clip, Composite, Identity, ScalarOp, Second +from aesara.scalar.math import Erf, Erfc, Erfinv, Log1mexp, Psi + + +@jax_funcify.register(ScalarOp) +def jax_funcify_ScalarOp(op, **kwargs): + func_name = op.nfunc_spec[0] + + if "." in func_name: + jnp_func = functools.reduce(getattr, [jax] + func_name.split(".")) + else: + jnp_func = getattr(jnp, func_name) + + if hasattr(op, "nfunc_variadic"): + # These are special cases that handle invalid arities due to the broken + # Aesara `Op` type contract (e.g. binary `Op`s that also function as + # their own variadic counterparts--even when those counterparts already + # exist as independent `Op`s). + jax_variadic_func = getattr(jnp, op.nfunc_variadic) + + def elemwise(*args): + if len(args) > op.nfunc_spec[1]: + return jax_variadic_func( + jnp.stack(jnp.broadcast_arrays(*args), axis=0), axis=0 + ) + else: + return jnp_func(*args) + + return elemwise + else: + return jnp_func + + +@jax_funcify.register(Cast) +def jax_funcify_Cast(op, **kwargs): + def cast(x): + return jnp.array(x).astype(op.o_type.dtype) + + return cast + + +@jax_funcify.register(Identity) +def jax_funcify_Identity(op, **kwargs): + def identity(x): + return x + + return identity + + +@jax_funcify.register(Clip) +def jax_funcify_Clip(op, **kwargs): + def clip(x, min, max): + return jnp.where(x < min, min, jnp.where(x > max, max, x)) + + return clip + + +@jax_funcify.register(Composite) +def jax_funcify_Composite(op, vectorize=True, **kwargs): + jax_impl = jax_funcify(op.fgraph) + + def composite(*args): + return jax_impl(*args)[0] + + return jnp.vectorize(composite) + + +@jax_funcify.register(Second) +def jax_funcify_Second(op, **kwargs): + def second(x, y): + return jnp.broadcast_to(y, x.shape) + + return second + + +@jax_funcify.register(Erf) +def jax_funcify_Erf(op, node, **kwargs): + def erf(x): + return jax.scipy.special.erf(x) + + return erf + + +@jax_funcify.register(Erfc) +def jax_funcify_Erfc(op, **kwargs): + def erfc(x): + return jax.scipy.special.erfc(x) + + return erfc + + +@jax_funcify.register(Erfinv) +def jax_funcify_Erfinv(op, **kwargs): + def erfinv(x): + return jax.scipy.special.erfinv(x) + + return erfinv + + +@jax_funcify.register(Log1mexp) +def jax_funcify_Log1mexp(op, node, **kwargs): + def log1mexp(x): + return jnp.where( + x < jnp.log(0.5), jnp.log1p(-jnp.exp(x)), jnp.log(-jnp.expm1(x)) + ) + + return log1mexp + + +@jax_funcify.register(Psi) +def jax_funcify_Psi(op, node, **kwargs): + def psi(x): + return jax.scipy.special.digamma(x) + + return psi + + +@jax_funcify.register(Softplus) +def jax_funcify_Softplus(op, **kwargs): + def softplus(x): + # This expression is numerically equivalent to the Aesara one + # It just contains one "speed" optimization less than the Aesara counterpart + return jnp.where( + x < -37.0, jnp.exp(x), jnp.where(x > 33.3, x, jnp.log1p(jnp.exp(x))) + ) + + return softplus diff --git a/aesara/link/jax/dispatch/scan.py b/aesara/link/jax/dispatch/scan.py new file mode 100644 index 0000000000..12c588c0d6 --- /dev/null +++ b/aesara/link/jax/dispatch/scan.py @@ -0,0 +1,159 @@ +import jax +import jax.numpy as jnp + +from aesara.graph.fg import FunctionGraph +from aesara.link.jax.dispatch.basic import jax_funcify +from aesara.scan.op import Scan +from aesara.scan.utils import ScanArgs + + +@jax_funcify.register(Scan) +def jax_funcify_Scan(op, **kwargs): + inner_fg = FunctionGraph(op.inputs, op.outputs) + jax_at_inner_func = jax_funcify(inner_fg, **kwargs) + + def scan(*outer_inputs): + scan_args = ScanArgs( + list(outer_inputs), [None] * op.info.n_outs, op.inputs, op.outputs, op.info + ) + + # `outer_inputs` is a list with the following composite form: + # [n_steps] + # + outer_in_seqs + # + outer_in_mit_mot + # + outer_in_mit_sot + # + outer_in_sit_sot + # + outer_in_shared + # + outer_in_nit_sot + # + outer_in_non_seqs + n_steps = scan_args.n_steps + seqs = scan_args.outer_in_seqs + + # TODO: mit_mots + mit_mot_in_slices = [] + + mit_sot_in_slices = [] + for tap, seq in zip(scan_args.mit_sot_in_slices, scan_args.outer_in_mit_sot): + neg_taps = [abs(t) for t in tap if t < 0] + pos_taps = [abs(t) for t in tap if t > 0] + max_neg = max(neg_taps) if neg_taps else 0 + max_pos = max(pos_taps) if pos_taps else 0 + init_slice = seq[: max_neg + max_pos] + mit_sot_in_slices.append(init_slice) + + sit_sot_in_slices = [seq[0] for seq in scan_args.outer_in_sit_sot] + + init_carry = ( + mit_mot_in_slices, + mit_sot_in_slices, + sit_sot_in_slices, + scan_args.outer_in_shared, + scan_args.outer_in_non_seqs, + ) + + def jax_args_to_inner_scan(op, carry, x): + # `carry` contains all inner-output taps, non_seqs, and shared + # terms + ( + inner_in_mit_mot, + inner_in_mit_sot, + inner_in_sit_sot, + inner_in_shared, + inner_in_non_seqs, + ) = carry + + # `x` contains the in_seqs + inner_in_seqs = x + + # `inner_scan_inputs` is a list with the following composite form: + # inner_in_seqs + # + sum(inner_in_mit_mot, []) + # + sum(inner_in_mit_sot, []) + # + inner_in_sit_sot + # + inner_in_shared + # + inner_in_non_seqs + inner_in_mit_sot_flatten = [] + for array, index in zip(inner_in_mit_sot, scan_args.mit_sot_in_slices): + inner_in_mit_sot_flatten.extend(array[jnp.array(index)]) + + inner_scan_inputs = sum( + [ + inner_in_seqs, + inner_in_mit_mot, + inner_in_mit_sot_flatten, + inner_in_sit_sot, + inner_in_shared, + inner_in_non_seqs, + ], + [], + ) + + return inner_scan_inputs + + def inner_scan_outs_to_jax_outs( + op, + old_carry, + inner_scan_outs, + ): + ( + inner_in_mit_mot, + inner_in_mit_sot, + inner_in_sit_sot, + inner_in_shared, + inner_in_non_seqs, + ) = old_carry + + def update_mit_sot(mit_sot, new_val): + return jnp.concatenate([mit_sot[1:], new_val[None, ...]], axis=0) + + inner_out_mit_sot = [ + update_mit_sot(mit_sot, new_val) + for mit_sot, new_val in zip(inner_in_mit_sot, inner_scan_outs) + ] + + # This should contain all inner-output taps, non_seqs, and shared + # terms + if not inner_in_sit_sot: + inner_out_sit_sot = [] + else: + inner_out_sit_sot = inner_scan_outs + new_carry = ( + inner_in_mit_mot, + inner_out_mit_sot, + inner_out_sit_sot, + inner_in_shared, + inner_in_non_seqs, + ) + + return new_carry + + def jax_inner_func(carry, x): + inner_args = jax_args_to_inner_scan(op, carry, x) + inner_scan_outs = list(jax_at_inner_func(*inner_args)) + new_carry = inner_scan_outs_to_jax_outs(op, carry, inner_scan_outs) + return new_carry, inner_scan_outs + + _, scan_out = jax.lax.scan(jax_inner_func, init_carry, seqs, length=n_steps) + + # We need to prepend the initial values so that the JAX output will + # match the raw `Scan` `Op` output and, thus, work with a downstream + # `Subtensor` `Op` introduced by the `scan` helper function. + def append_scan_out(scan_in_part, scan_out_part): + return jnp.concatenate([scan_in_part[:-n_steps], scan_out_part], axis=0) + + if scan_args.outer_in_mit_sot: + scan_out_final = [ + append_scan_out(init, out) + for init, out in zip(scan_args.outer_in_mit_sot, scan_out) + ] + elif scan_args.outer_in_sit_sot: + scan_out_final = [ + append_scan_out(init, out) + for init, out in zip(scan_args.outer_in_sit_sot, scan_out) + ] + + if len(scan_out_final) == 1: + scan_out_final = scan_out_final[0] + return scan_out_final + + return scan diff --git a/aesara/link/jax/dispatch/shape.py b/aesara/link/jax/dispatch/shape.py new file mode 100644 index 0000000000..77da552713 --- /dev/null +++ b/aesara/link/jax/dispatch/shape.py @@ -0,0 +1,65 @@ +import jax.numpy as jnp + +from aesara.graph import Constant +from aesara.link.jax.dispatch.basic import jax_funcify +from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, Unbroadcast + + +@jax_funcify.register(Reshape) +def jax_funcify_Reshape(op, node, **kwargs): + + # JAX reshape only works with constant inputs, otherwise JIT fails + shape = node.inputs[1] + if isinstance(shape, Constant): + constant_shape = shape.data + + def reshape(x, shape): + return jnp.reshape(x, constant_shape) + + else: + + def reshape(x, shape): + return jnp.reshape(x, shape) + + return reshape + + +@jax_funcify.register(Shape) +def jax_funcify_Shape(op, **kwargs): + def shape(x): + return jnp.shape(x) + + return shape + + +@jax_funcify.register(Shape_i) +def jax_funcify_Shape_i(op, **kwargs): + i = op.i + + def shape_i(x): + return jnp.shape(x)[i] + + return shape_i + + +@jax_funcify.register(SpecifyShape) +def jax_funcify_SpecifyShape(op, **kwargs): + def specifyshape(x, *shape): + assert x.ndim == len(shape) + assert jnp.all(x.shape == tuple(shape)), ( + "got shape", + x.shape, + "expected", + shape, + ) + return x + + return specifyshape + + +@jax_funcify.register(Unbroadcast) +def jax_funcify_Unbroadcast(op, **kwargs): + def unbroadcast(x): + return x + + return unbroadcast diff --git a/aesara/link/jax/dispatch/slinalg.py b/aesara/link/jax/dispatch/slinalg.py new file mode 100644 index 0000000000..f2cf27e60c --- /dev/null +++ b/aesara/link/jax/dispatch/slinalg.py @@ -0,0 +1,48 @@ +import jax + +from aesara.link.jax.dispatch.basic import jax_funcify +from aesara.tensor.slinalg import Cholesky, Solve, SolveTriangular + + +@jax_funcify.register(Cholesky) +def jax_funcify_Cholesky(op, **kwargs): + lower = op.lower + + def cholesky(a, lower=lower): + return jax.scipy.linalg.cholesky(a, lower=lower).astype(a.dtype) + + return cholesky + + +@jax_funcify.register(Solve) +def jax_funcify_Solve(op, **kwargs): + + if op.assume_a != "gen" and op.lower: + lower = True + else: + lower = False + + def solve(a, b, lower=lower): + return jax.scipy.linalg.solve(a, b, lower=lower) + + return solve + + +@jax_funcify.register(SolveTriangular) +def jax_funcify_SolveTriangular(op, **kwargs): + lower = op.lower + trans = op.trans + unit_diagonal = op.unit_diagonal + check_finite = op.check_finite + + def solve_triangular(A, b): + return jax.scipy.linalg.solve_triangular( + A, + b, + lower=lower, + trans=trans, + unit_diagonal=unit_diagonal, + check_finite=check_finite, + ) + + return solve_triangular diff --git a/aesara/link/jax/dispatch/subtensor.py b/aesara/link/jax/dispatch/subtensor.py new file mode 100644 index 0000000000..822d78a6fa --- /dev/null +++ b/aesara/link/jax/dispatch/subtensor.py @@ -0,0 +1,97 @@ +import jax + +from aesara.link.jax.dispatch.basic import jax_funcify +from aesara.tensor.subtensor import ( + AdvancedIncSubtensor, + AdvancedIncSubtensor1, + AdvancedSubtensor, + AdvancedSubtensor1, + IncSubtensor, + Subtensor, + indices_from_subtensor, +) +from aesara.tensor.type_other import MakeSlice + + +@jax_funcify.register(Subtensor) +@jax_funcify.register(AdvancedSubtensor) +@jax_funcify.register(AdvancedSubtensor1) +def jax_funcify_Subtensor(op, **kwargs): + + idx_list = getattr(op, "idx_list", None) + + def subtensor(x, *ilists): + + indices = indices_from_subtensor(ilists, idx_list) + + if len(indices) == 1: + indices = indices[0] + + return x.__getitem__(indices) + + return subtensor + + +@jax_funcify.register(IncSubtensor) +@jax_funcify.register(AdvancedIncSubtensor1) +def jax_funcify_IncSubtensor(op, **kwargs): + + idx_list = getattr(op, "idx_list", None) + + if getattr(op, "set_instead_of_inc", False): + jax_fn = getattr(jax.ops, "index_update", None) + + if jax_fn is None: + + def jax_fn(x, indices, y): + return x.at[indices].set(y) + + else: + jax_fn = getattr(jax.ops, "index_add", None) + + if jax_fn is None: + + def jax_fn(x, indices, y): + return x.at[indices].add(y) + + def incsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=idx_list): + indices = indices_from_subtensor(ilist, idx_list) + if len(indices) == 1: + indices = indices[0] + + return jax_fn(x, indices, y) + + return incsubtensor + + +@jax_funcify.register(AdvancedIncSubtensor) +def jax_funcify_AdvancedIncSubtensor(op, **kwargs): + + if getattr(op, "set_instead_of_inc", False): + jax_fn = getattr(jax.ops, "index_update", None) + + if jax_fn is None: + + def jax_fn(x, indices, y): + return x.at[indices].set(y) + + else: + jax_fn = getattr(jax.ops, "index_add", None) + + if jax_fn is None: + + def jax_fn(x, indices, y): + return x.at[indices].add(y) + + def advancedincsubtensor(x, y, *ilist, jax_fn=jax_fn): + return jax_fn(x, ilist, y) + + return advancedincsubtensor + + +@jax_funcify.register(MakeSlice) +def jax_funcify_MakeSlice(op, **kwargs): + def makeslice(*x): + return slice(*x) + + return makeslice diff --git a/aesara/link/jax/dispatch/tensor_basic.py b/aesara/link/jax/dispatch/tensor_basic.py new file mode 100644 index 0000000000..c15233175f --- /dev/null +++ b/aesara/link/jax/dispatch/tensor_basic.py @@ -0,0 +1,114 @@ +import jax.numpy as jnp + +from aesara.link.jax.dispatch.basic import jax_funcify +from aesara.tensor.basic import ( + Alloc, + AllocDiag, + AllocEmpty, + ARange, + ExtractDiag, + Eye, + Join, + MakeVector, + ScalarFromTensor, + TensorFromScalar, +) + + +@jax_funcify.register(AllocDiag) +def jax_funcify_AllocDiag(op, **kwargs): + offset = op.offset + + def allocdiag(v, offset=offset): + return jnp.diag(v, k=offset) + + return allocdiag + + +@jax_funcify.register(AllocEmpty) +def jax_funcify_AllocEmpty(op, **kwargs): + def allocempty(*shape): + return jnp.empty(shape, dtype=op.dtype) + + return allocempty + + +@jax_funcify.register(Alloc) +def jax_funcify_Alloc(op, **kwargs): + def alloc(x, *shape): + res = jnp.broadcast_to(x, shape) + return res + + return alloc + + +@jax_funcify.register(ARange) +def jax_funcify_ARange(op, **kwargs): + # XXX: This currently requires concrete arguments. + def arange(start, stop, step): + return jnp.arange(start, stop, step, dtype=op.dtype) + + return arange + + +@jax_funcify.register(Join) +def jax_funcify_Join(op, **kwargs): + def join(axis, *tensors): + # tensors could also be tuples, and in this case they don't have a ndim + tensors = [jnp.asarray(tensor) for tensor in tensors] + view = op.view + if (view != -1) and all( + tensor.shape[axis] == 0 for tensor in tensors[0:view] + tensors[view + 1 :] + ): + return tensors[view] + + else: + return jnp.concatenate(tensors, axis=axis) + + return join + + +@jax_funcify.register(ExtractDiag) +def jax_funcify_ExtractDiag(op, **kwargs): + offset = op.offset + axis1 = op.axis1 + axis2 = op.axis2 + + def extract_diag(x, offset=offset, axis1=axis1, axis2=axis2): + return jnp.diagonal(x, offset=offset, axis1=axis1, axis2=axis2) + + return extract_diag + + +@jax_funcify.register(Eye) +def jax_funcify_Eye(op, **kwargs): + dtype = op.dtype + + def eye(N, M, k): + return jnp.eye(N, M, k, dtype=dtype) + + return eye + + +@jax_funcify.register(MakeVector) +def jax_funcify_MakeVector(op, **kwargs): + def makevector(*x): + return jnp.array(x, dtype=op.dtype) + + return makevector + + +@jax_funcify.register(TensorFromScalar) +def jax_funcify_TensorFromScalar(op, **kwargs): + def tensor_from_scalar(x): + return jnp.array(x) + + return tensor_from_scalar + + +@jax_funcify.register(ScalarFromTensor) +def jax_funcify_ScalarFromTensor(op, **kwargs): + def scalar_from_tensor(x): + return jnp.array(x).flatten()[0] + + return scalar_from_tensor diff --git a/aesara/link/jax/dispatch/test_subtensor.py b/aesara/link/jax/dispatch/test_subtensor.py new file mode 100644 index 0000000000..22cc492402 --- /dev/null +++ b/aesara/link/jax/dispatch/test_subtensor.py @@ -0,0 +1,186 @@ +import jax +import numpy as np +import pytest +from jax._src.errors import NonConcreteBooleanIndexError +from packaging.version import parse as version_parse + +import aesara.tensor as at +from aesara.configdefaults import config +from aesara.graph.fg import FunctionGraph +from aesara.tensor import subtensor as at_subtensor +from tests.link.jax.test_basic import compare_jax_and_py + + +def test_jax_Subtensors(): + # Basic indices + x_at = at.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))) + out_at = x_at[1, 2, 0] + assert isinstance(out_at.owner.op, at_subtensor.Subtensor) + out_fg = FunctionGraph([], [out_at]) + compare_jax_and_py(out_fg, []) + + out_at = x_at[1:2, 1, :] + assert isinstance(out_at.owner.op, at_subtensor.Subtensor) + out_fg = FunctionGraph([], [out_at]) + compare_jax_and_py(out_fg, []) + + # Advanced indexing + out_at = at_subtensor.advanced_subtensor1(x_at, [1, 2]) + assert isinstance(out_at.owner.op, at_subtensor.AdvancedSubtensor1) + out_fg = FunctionGraph([], [out_at]) + compare_jax_and_py(out_fg, []) + + out_at = x_at[[1, 2], [2, 3]] + assert isinstance(out_at.owner.op, at_subtensor.AdvancedSubtensor) + out_fg = FunctionGraph([], [out_at]) + compare_jax_and_py(out_fg, []) + + # Advanced and basic indexing + out_at = x_at[[1, 2], :] + assert isinstance(out_at.owner.op, at_subtensor.AdvancedSubtensor) + out_fg = FunctionGraph([], [out_at]) + compare_jax_and_py(out_fg, []) + + out_at = x_at[[1, 2], :, [3, 4]] + assert isinstance(out_at.owner.op, at_subtensor.AdvancedSubtensor) + out_fg = FunctionGraph([], [out_at]) + compare_jax_and_py(out_fg, []) + + +@pytest.mark.xfail( + version_parse(jax.__version__) >= version_parse("0.2.12"), + reason="Omnistaging cannot be disabled", +) +def test_jax_Subtensors_omni(): + x_at = at.arange(3 * 4 * 5).reshape((3, 4, 5)) + + # Boolean indices + out_at = x_at[x_at < 0] + assert isinstance(out_at.owner.op, at_subtensor.AdvancedSubtensor) + out_fg = FunctionGraph([], [out_at]) + compare_jax_and_py(out_fg, []) + + +def test_jax_IncSubtensor(): + rng = np.random.default_rng(213234) + + x_np = rng.uniform(-1, 1, size=(3, 4, 5)).astype(config.floatX) + x_at = at.constant(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(config.floatX)) + + # "Set" basic indices + st_at = at.as_tensor_variable(np.array(-10.0, dtype=config.floatX)) + out_at = at_subtensor.set_subtensor(x_at[1, 2, 3], st_at) + assert isinstance(out_at.owner.op, at_subtensor.IncSubtensor) + out_fg = FunctionGraph([], [out_at]) + compare_jax_and_py(out_fg, []) + + st_at = at.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX)) + out_at = at_subtensor.set_subtensor(x_at[:2, 0, 0], st_at) + assert isinstance(out_at.owner.op, at_subtensor.IncSubtensor) + out_fg = FunctionGraph([], [out_at]) + compare_jax_and_py(out_fg, []) + + out_at = at_subtensor.set_subtensor(x_at[0, 1:3, 0], st_at) + assert isinstance(out_at.owner.op, at_subtensor.IncSubtensor) + out_fg = FunctionGraph([], [out_at]) + compare_jax_and_py(out_fg, []) + + # "Set" advanced indices + st_at = at.as_tensor_variable( + rng.uniform(-1, 1, size=(2, 4, 5)).astype(config.floatX) + ) + out_at = at_subtensor.set_subtensor(x_at[np.r_[0, 2]], st_at) + assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor) + out_fg = FunctionGraph([], [out_at]) + compare_jax_and_py(out_fg, []) + + st_at = at.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX)) + out_at = at_subtensor.set_subtensor(x_at[[0, 2], 0, 0], st_at) + assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor) + out_fg = FunctionGraph([], [out_at]) + compare_jax_and_py(out_fg, []) + + # "Set" boolean indices + mask_at = at.constant(x_np > 0) + out_at = at_subtensor.set_subtensor(x_at[mask_at], 0.0) + assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor) + out_fg = FunctionGraph([], [out_at]) + compare_jax_and_py(out_fg, []) + + # "Increment" basic indices + st_at = at.as_tensor_variable(np.array(-10.0, dtype=config.floatX)) + out_at = at_subtensor.inc_subtensor(x_at[1, 2, 3], st_at) + assert isinstance(out_at.owner.op, at_subtensor.IncSubtensor) + out_fg = FunctionGraph([], [out_at]) + compare_jax_and_py(out_fg, []) + + st_at = at.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX)) + out_at = at_subtensor.inc_subtensor(x_at[:2, 0, 0], st_at) + assert isinstance(out_at.owner.op, at_subtensor.IncSubtensor) + out_fg = FunctionGraph([], [out_at]) + compare_jax_and_py(out_fg, []) + + out_at = at_subtensor.set_subtensor(x_at[0, 1:3, 0], st_at) + assert isinstance(out_at.owner.op, at_subtensor.IncSubtensor) + out_fg = FunctionGraph([], [out_at]) + compare_jax_and_py(out_fg, []) + + # "Increment" advanced indices + st_at = at.as_tensor_variable( + rng.uniform(-1, 1, size=(2, 4, 5)).astype(config.floatX) + ) + out_at = at_subtensor.inc_subtensor(x_at[np.r_[0, 2]], st_at) + assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor) + out_fg = FunctionGraph([], [out_at]) + compare_jax_and_py(out_fg, []) + + st_at = at.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX)) + out_at = at_subtensor.inc_subtensor(x_at[[0, 2], 0, 0], st_at) + assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor) + out_fg = FunctionGraph([], [out_at]) + compare_jax_and_py(out_fg, []) + + # "Increment" boolean indices + mask_at = at.constant(x_np > 0) + out_at = at_subtensor.set_subtensor(x_at[mask_at], 1.0) + assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor) + out_fg = FunctionGraph([], [out_at]) + compare_jax_and_py(out_fg, []) + + +def test_jax_IncSubtensors_unsupported(): + rng = np.random.default_rng(213234) + x_np = rng.uniform(-1, 1, size=(3, 4, 5)).astype(config.floatX) + x_at = at.constant(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(config.floatX)) + + mask_at = at.as_tensor(x_np) > 0 + out_at = at_subtensor.set_subtensor(x_at[mask_at], 0.0) + assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor) + out_fg = FunctionGraph([], [out_at]) + with pytest.raises( + NonConcreteBooleanIndexError, match="Array boolean indices must be concrete" + ): + compare_jax_and_py(out_fg, []) + + mask_at = at.as_tensor_variable(x_np) > 0 + out_at = at_subtensor.set_subtensor(x_at[mask_at], 1.0) + assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor) + out_fg = FunctionGraph([], [out_at]) + with pytest.raises( + NonConcreteBooleanIndexError, match="Array boolean indices must be concrete" + ): + compare_jax_and_py(out_fg, []) + + st_at = at.as_tensor_variable(x_np[[0, 2], 0, :3]) + out_at = at_subtensor.set_subtensor(x_at[[0, 2], 0, :3], st_at) + assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor) + out_fg = FunctionGraph([], [out_at]) + with pytest.raises(IndexError, match="Array slice indices must have static"): + compare_jax_and_py(out_fg, []) + + st_at = at.as_tensor_variable(x_np[[0, 2], 0, :3]) + out_at = at_subtensor.inc_subtensor(x_at[[0, 2], 0, :3], st_at) + assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor) + out_fg = FunctionGraph([], [out_at]) + with pytest.raises(IndexError, match="Array slice indices must have static"): + compare_jax_and_py(out_fg, []) diff --git a/aesara/link/numba/dispatch/basic.py b/aesara/link/numba/dispatch/basic.py index f36b752a5c..3c579637b5 100644 --- a/aesara/link/numba/dispatch/basic.py +++ b/aesara/link/numba/dispatch/basic.py @@ -499,7 +499,6 @@ def numba_funcify_Subtensor(op, node, **kwargs): @numba_funcify.register(IncSubtensor) @numba_funcify.register(AdvancedIncSubtensor) -@numba_funcify.register(AdvancedIncSubtensor1) def numba_funcify_IncSubtensor(op, node, **kwargs): incsubtensor_def_src = create_index_func( @@ -515,6 +514,39 @@ def numba_funcify_IncSubtensor(op, node, **kwargs): return numba_njit(incsubtensor_fn) +@numba_funcify.register(AdvancedIncSubtensor1) +def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs): + inplace = op.inplace + set_instead_of_inc = op.set_instead_of_inc + + if set_instead_of_inc: + + @numba_njit + def advancedincsubtensor1_inplace(x, vals, idxs): + for idx, val in zip(idxs, vals): + x[idx] = val + return x + + else: + + @numba_njit + def advancedincsubtensor1_inplace(x, vals, idxs): + for idx, val in zip(idxs, vals): + x[idx] += val + return x + + if inplace: + return advancedincsubtensor1_inplace + else: + + @numba_njit + def advancedincsubtensor1(x, vals, idxs): + x = x.copy() + return advancedincsubtensor1_inplace(x, vals, idxs) + + return advancedincsubtensor1 + + @numba_funcify.register(DeepCopyOp) def numba_funcify_DeepCopyOp(op, node, **kwargs): diff --git a/aesara/link/numba/dispatch/random.py b/aesara/link/numba/dispatch/random.py index 9d2891dca9..b41f83e221 100644 --- a/aesara/link/numba/dispatch/random.py +++ b/aesara/link/numba/dispatch/random.py @@ -312,33 +312,21 @@ def body_fn(a): @numba_funcify.register(aer.CategoricalRV) def numba_funcify_CategoricalRV(op, node, **kwargs): out_dtype = node.outputs[1].type.numpy_dtype - - ind_shape_len = node.inputs[3].type.ndim - 1 - neg_ind_shape_len = -ind_shape_len - size_len = int(get_vector_length(node.inputs[1])) @numba_basic.numba_njit def categorical_rv(rng, size, dtype, p): - - size_tpl = numba_ndarray.to_fixed_tuple(size, size_len) - ind_shape = p.shape[:-1] - - if ind_shape_len > 0: - if size_len > 0 and size_tpl[neg_ind_shape_len:] != ind_shape: - raise ValueError("Parameters shape and size do not match.") - - samples_shape = size_tpl[:neg_ind_shape_len] + ind_shape - p_bcast = np.broadcast_to(p, size_tpl[:neg_ind_shape_len] + p.shape) + if not size_len: + size_tpl = p.shape[:-1] else: - samples_shape = size_tpl - p_bcast = p + size_tpl = numba_ndarray.to_fixed_tuple(size, size_len) + p = np.broadcast_to(p, size_tpl + p.shape[-1:]) - unif_samples = np.random.uniform(0, 1, samples_shape) + unif_samples = np.random.uniform(0, 1, size_tpl) - res = np.empty(samples_shape, dtype=out_dtype) - for idx in np.ndindex(*samples_shape): - res[idx] = np.searchsorted(np.cumsum(p_bcast[idx]), unif_samples[idx]) + res = np.empty(size_tpl, dtype=out_dtype) + for idx in np.ndindex(*size_tpl): + res[idx] = np.searchsorted(np.cumsum(p[idx]), unif_samples[idx]) return (rng, res) diff --git a/aesara/link/numba/dispatch/scalar.py b/aesara/link/numba/dispatch/scalar.py index d561876254..28031ea988 100644 --- a/aesara/link/numba/dispatch/scalar.py +++ b/aesara/link/numba/dispatch/scalar.py @@ -22,8 +22,8 @@ Clip, Composite, Identity, - Inv, Mul, + Reciprocal, ScalarOp, Second, Switch, @@ -236,13 +236,15 @@ def second(x, y): return second -@numba_funcify.register(Inv) -def numba_funcify_Inv(op, node, **kwargs): +@numba_funcify.register(Reciprocal) +def numba_funcify_Reciprocal(op, node, **kwargs): @numba_basic.numba_njit(inline="always") - def inv(x): + def reciprocal(x): + # TODO FIXME: This isn't really the behavior or `numpy.reciprocal` when + # `x` is an `int` return 1 / x - return inv + return reciprocal @numba_funcify.register(Sigmoid) diff --git a/aesara/link/numba/dispatch/scan.py b/aesara/link/numba/dispatch/scan.py index cd8218c166..2a1a6c2735 100644 --- a/aesara/link/numba/dispatch/scan.py +++ b/aesara/link/numba/dispatch/scan.py @@ -54,7 +54,7 @@ def numba_funcify_Scan(op, node, **kwargs): p_outer_in_nit_sot = p_outer_in_shared + n_shared_outs p_outer_in_non_seqs = p_outer_in_nit_sot + n_nit_sot - input_names = [n.auto_name for n in node.inputs[1:]] + input_names = [f"{n.auto_name}_{i}" for i, n in enumerate(node.inputs[1:])] outer_in_seqs_names = input_names[:n_seqs] outer_in_mit_mot_names = input_names[p_in_mit_mot : p_in_mit_mot + n_mit_mot] outer_in_mit_sot_names = input_names[p_in_mit_sot : p_in_mit_sot + n_mit_sot] diff --git a/aesara/link/utils.py b/aesara/link/utils.py index ce68b2dded..ba118ef170 100644 --- a/aesara/link/utils.py +++ b/aesara/link/utils.py @@ -243,7 +243,7 @@ def gc_helper(node_list: List[Apply]): ------- 2-tuple FIRST, the set of Variable instances which are computed by node_list, - and SECOND a dictionary that maps each Variable instance to a the last + and SECOND a dictionary that maps each Variable instance to the last node to use Variable as an input. Extended Summary diff --git a/aesara/raise_op.py b/aesara/raise_op.py index 9853b06ceb..766f0df534 100644 --- a/aesara/raise_op.py +++ b/aesara/raise_op.py @@ -10,6 +10,7 @@ from aesara.link.c.op import COp from aesara.link.c.params_type import ParamsType from aesara.link.c.type import Generic +from aesara.scalar.basic import ScalarType from aesara.tensor.type import DenseTensorType @@ -78,7 +79,10 @@ def make_node(self, value: Variable, *conds: Tuple[Variable]): if not isinstance(value, Variable): value = at.as_tensor_variable(value) - conds = [at.as_tensor_variable(c) for c in conds] + conds = [ + at.as_tensor_variable(c) if not isinstance(c, Variable) else c + for c in conds + ] assert all(c.type.ndim == 0 for c in conds) @@ -102,7 +106,7 @@ def connection_pattern(self, node): return [[1]] + [[0]] * (len(node.inputs) - 1) def c_code(self, node, name, inames, onames, props): - if not isinstance(node.inputs[0].type, DenseTensorType): + if not isinstance(node.inputs[0].type, (DenseTensorType, ScalarType)): raise NotImplementedError( f"CheckAndRaise c_code not implemented for input type {node.inputs[0].type}" ) @@ -112,25 +116,47 @@ def c_code(self, node, name, inames, onames, props): fail_code = props["fail"] param_struct_name = props["params"] msg = self.msg.replace('"', '\\"').replace("\n", "\\n") + for idx, cond_name in enumerate(cond_names): - check.append( - f""" - if(PyObject_IsTrue((PyObject *){cond_name}) == 0) {{ - PyObject * exc_type = {param_struct_name}->exc_type; - Py_INCREF(exc_type); - PyErr_SetString(exc_type, "{msg}"); - Py_XDECREF(exc_type); - {indent(fail_code, " " * 4)} - }} - """ - ) + if isinstance(node.inputs[0].type, DenseTensorType): + check.append( + f""" + if(PyObject_IsTrue((PyObject *){cond_name}) == 0) {{ + PyObject * exc_type = {param_struct_name}->exc_type; + Py_INCREF(exc_type); + PyErr_SetString(exc_type, "{msg}"); + Py_XDECREF(exc_type); + {indent(fail_code, " " * 4)} + }} + """ + ) + else: + check.append( + f""" + if({cond_name} == 0) {{ + PyObject * exc_type = {param_struct_name}->exc_type; + Py_INCREF(exc_type); + PyErr_SetString(exc_type, "{msg}"); + Py_XDECREF(exc_type); + {indent(fail_code, " " * 4)} + }} + """ + ) + check = "\n".join(check) - res = f""" - {check} - Py_XDECREF({out_name}); - {out_name} = {value_name}; - Py_INCREF({value_name}); - """ + + if isinstance(node.inputs[0].type, DenseTensorType): + res = f""" + {check} + Py_XDECREF({out_name}); + {out_name} = {value_name}; + Py_INCREF({value_name}); + """ + else: + res = f""" + {check} + {out_name} = {value_name}; + """ return res def c_code_cache_version(self): diff --git a/aesara/sandbox/linalg/ops.py b/aesara/sandbox/linalg/ops.py index b16e7605c2..c92fdc6b8d 100644 --- a/aesara/sandbox/linalg/ops.py +++ b/aesara/sandbox/linalg/ops.py @@ -1,18 +1,18 @@ import logging -from aesara.graph.opt import local_optimizer +from aesara.graph.rewriting.basic import node_rewriter from aesara.tensor import basic as at -from aesara.tensor.basic_opt import ( - register_canonicalize, - register_specialize, - register_stabilize, -) from aesara.tensor.blas import Dot22 from aesara.tensor.elemwise import DimShuffle from aesara.tensor.math import Dot, Prod, dot, log from aesara.tensor.math import pow as at_pow from aesara.tensor.math import prod from aesara.tensor.nlinalg import Det, MatrixInverse, trace +from aesara.tensor.rewriting.basic import ( + register_canonicalize, + register_specialize, + register_stabilize, +) from aesara.tensor.slinalg import Cholesky, Solve, cholesky, solve @@ -20,7 +20,7 @@ @register_canonicalize -@local_optimizer([DimShuffle]) +@node_rewriter([DimShuffle]) def transinv_to_invtrans(fgraph, node): if isinstance(node.op, DimShuffle): if node.op.new_order == (1, 0): @@ -32,7 +32,7 @@ def transinv_to_invtrans(fgraph, node): @register_stabilize -@local_optimizer([Dot, Dot22]) +@node_rewriter([Dot, Dot22]) def inv_as_solve(fgraph, node): """ This utilizes a boolean `symmetric` tag on the matrices. @@ -51,7 +51,7 @@ def inv_as_solve(fgraph, node): @register_stabilize @register_canonicalize -@local_optimizer([Solve]) +@node_rewriter([Solve]) def tag_solve_triangular(fgraph, node): """ If a general solve() is applied to the output of a cholesky op, then @@ -82,7 +82,7 @@ def tag_solve_triangular(fgraph, node): @register_canonicalize @register_stabilize @register_specialize -@local_optimizer([DimShuffle]) +@node_rewriter([DimShuffle]) def no_transpose_symmetric(fgraph, node): if isinstance(node.op, DimShuffle): x = node.inputs[0] @@ -92,7 +92,7 @@ def no_transpose_symmetric(fgraph, node): @register_stabilize -@local_optimizer([Solve]) +@node_rewriter([Solve]) def psd_solve_with_chol(fgraph, node): """ This utilizes a boolean `psd` tag on matrices. @@ -111,7 +111,7 @@ def psd_solve_with_chol(fgraph, node): @register_stabilize @register_specialize -@local_optimizer([Det]) +@node_rewriter([Det]) def local_det_chol(fgraph, node): """ If we have det(X) and there is already an L=cholesky(X) @@ -129,7 +129,7 @@ def local_det_chol(fgraph, node): @register_canonicalize @register_stabilize @register_specialize -@local_optimizer([log]) +@node_rewriter([log]) def local_log_prod_sqr(fgraph, node): """ This utilizes a boolean `positive` tag on matrices. diff --git a/aesara/sandbox/multinomial.py b/aesara/sandbox/multinomial.py index 78d3de6fd8..fc72ca8c6d 100644 --- a/aesara/sandbox/multinomial.py +++ b/aesara/sandbox/multinomial.py @@ -1,5 +1,4 @@ import copy -import warnings from typing import Tuple, Union import numpy as np @@ -435,14 +434,3 @@ def perform(self, node, ins, outs): pvals[n, m] = 0.0 pvals[n] /= pvals[n].sum() break - - -class MultinomialWOReplacementFromUniform(ChoiceFromUniform): - def __init__(self, *args, **kwargs): - warnings.warn( - "MultinomialWOReplacementFromUniform is deprecated, " - "use ChoiceFromUniform instead.", - DeprecationWarning, - stacklevel=2, - ) - super().__init__(*args, **kwargs) diff --git a/aesara/sandbox/rng_mrg.py b/aesara/sandbox/rng_mrg.py index 3e006f01a1..a1afeb3d5d 100644 --- a/aesara/sandbox/rng_mrg.py +++ b/aesara/sandbox/rng_mrg.py @@ -25,7 +25,7 @@ from aesara.configdefaults import config from aesara.gradient import undefined_grad from aesara.graph.basic import Apply, Constant, Variable -from aesara.graph.opt import in2out, local_optimizer +from aesara.graph.rewriting.basic import in2out, node_rewriter from aesara.link.c.op import COp, Op from aesara.link.c.params_type import ParamsType from aesara.sandbox import multinomial @@ -37,6 +37,14 @@ from aesara.tensor.type import TensorType, iscalar, ivector, lmatrix +warnings.warn( + "The module `aesara.sandbox.rng_mrg` is deprecated. " + "Use the module `aesara.tensor.random` for random variables instead.", + DeprecationWarning, + stacklevel=2, +) + + def matVecModM(A, s, m): # TODO : need description for method, parameter and return assert A.dtype == "int64" @@ -1107,10 +1115,10 @@ def multinomial_wo_replacement( **kwargs, ): warnings.warn( - "MRG_RandomStream.multinomial_wo_replacement is " - "deprecated and will be removed in the next release of " - "Aesara. Please use MRG_RandomStream.choice instead.", + "`MRG_RandomStream.multinomial_wo_replacement` is " + "deprecated; use `MRG_RandomStream.choice` instead.", DeprecationWarning, + stacklevel=2, ) assert size is None return self.choice( @@ -1343,7 +1351,7 @@ def _check_size(size): return at.as_tensor_variable(size, ndim=1) -@local_optimizer((mrg_uniform_base,)) +@node_rewriter((mrg_uniform_base,)) def mrg_random_make_inplace(fgraph, node): op = node.op diff --git a/aesara/scalar/basic.py b/aesara/scalar/basic.py index 0215973179..6764f4a16e 100644 --- a/aesara/scalar/basic.py +++ b/aesara/scalar/basic.py @@ -27,7 +27,7 @@ from aesara.gradient import DisconnectedType, grad_undefined from aesara.graph.basic import Apply, Constant, Variable, clone, list_of_nodes from aesara.graph.fg import FunctionGraph -from aesara.graph.opt import MergeOptimizer +from aesara.graph.rewriting.basic import MergeOptimizer from aesara.graph.type import HasDataType, HasShape from aesara.graph.utils import MetaObject, MethodNotDefined from aesara.link.c.op import COp @@ -670,10 +670,6 @@ def get_size(self, shape_info): return shape_info -# Deprecated alias for backward compatibility -Scalar = ScalarType - - def get_scalar_type(dtype, cache: Dict[str, ScalarType] = {}) -> ScalarType: """ Return a ScalarType(dtype) object. @@ -733,6 +729,12 @@ class _scalar_py_operators: dtype = property(lambda self: self.type.dtype) """The dtype of this scalar.""" + @property + def shape(self): + from aesara.tensor.basic import as_tensor_variable + + return as_tensor_variable([], ndim=1, dtype=np.int64) + # UNARY def __abs__(self): return abs(self) @@ -2903,10 +2905,6 @@ def c_code(self, node, name, inputs, outputs, sub): reciprocal = Reciprocal(upgrade_to_float, name="reciprocal") -# These are deprecated and will be removed -Inv = Reciprocal -inv = reciprocal - class Log(UnaryScalarOp): """ @@ -4160,7 +4158,7 @@ def init_fgraph(self): # the fgraph to be set to the variable as we need to pickle # them for the cache of c module to work. fgraph = FunctionGraph(self.inputs, self.outputs) - MergeOptimizer().optimize(fgraph) + MergeOptimizer().rewrite(fgraph) for node in fgraph.apply_nodes: if not isinstance(node.op, ScalarOp): raise ValueError( @@ -4455,3 +4453,26 @@ def handle_composite(node, mapping): Compositef32.special[Composite] = handle_composite + + +DEPRECATED_NAMES = [ + ("Inv", "`Inv` is deprecated; use `Reciprocal` instead.", Reciprocal), + ("inv", "`inv` is deprecated; use `reciprocal` instead.", reciprocal), + ("Scalar", "`Scalar` is deprecated; use `ScalarType` instead.", ScalarType), +] + + +def __getattr__(name): + """Intercept module-level attribute access of deprecated symbols. + + Adapted from https://stackoverflow.com/a/55139609/3006474. + + """ + from warnings import warn + + for old_name, msg, old_object in DEPRECATED_NAMES: + if name == old_name: + warn(msg, DeprecationWarning, stacklevel=2) + return old_object + + raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/aesara/scalar/math.py b/aesara/scalar/math.py index 989abe89db..43ee662ebf 100644 --- a/aesara/scalar/math.py +++ b/aesara/scalar/math.py @@ -250,6 +250,35 @@ def c_code(self, node, name, inp, out, sub): erfcinv = Erfcinv(upgrade_to_float_no_complex, name="erfcinv") +class Owens_t(BinaryScalarOp): + nfunc_spec = ("scipy.special.owens_t", 2, 1) + + @staticmethod + def st_impl(h, a): + return scipy.special.owens_t(h, a) + + def impl(self, h, a): + return Owens_t.st_impl(h, a) + + def grad(self, inputs, grads): + (h, a) = inputs + (gz,) = grads + return [ + gz + * (-1) + * exp(-(h**2) / 2) + * erf(a * h / np.sqrt(2)) + / (2 * np.sqrt(2 * np.pi)), + gz * exp(-0.5 * (a**2 + 1) * h**2) / (2 * np.pi * (a**2 + 1)), + ] + + def c_code(self, *args, **kwargs): + raise NotImplementedError() + + +owens_t = Owens_t(upgrade_to_float, name="owens_t") + + class Gamma(UnaryScalarOp): nfunc_spec = ("scipy.special.gamma", 1, 1) diff --git a/aesara/scan/__init__.py b/aesara/scan/__init__.py index a108c317d4..54e102cda7 100644 --- a/aesara/scan/__init__.py +++ b/aesara/scan/__init__.py @@ -51,7 +51,7 @@ configdefaults.add_scan_configvars() -from aesara.scan import opt +from aesara.scan import rewriting from aesara.scan.basic import scan from aesara.scan.checkpoints import scan_checkpoints from aesara.scan.utils import until diff --git a/aesara/scan/op.py b/aesara/scan/op.py index 05e6e7dc59..f8a8ad01ef 100644 --- a/aesara/scan/op.py +++ b/aesara/scan/op.py @@ -3440,7 +3440,7 @@ def profile_printer( ) -@op_debug_information.register(Scan) # type: ignore +@op_debug_information.register(Scan) def _op_debug_information_Scan(op, node): from typing import Sequence diff --git a/aesara/scan/opt.py b/aesara/scan/opt.py index 9eea791a54..417c70a785 100644 --- a/aesara/scan/opt.py +++ b/aesara/scan/opt.py @@ -1,2482 +1,10 @@ -"""This module provides optimizations for the `Scan` `Op`.""" +import warnings -import copy -import dataclasses -from itertools import chain -from sys import maxsize -from typing import Dict, List, Optional, Tuple, cast -import numpy as np - -import aesara -from aesara import scalar as aes -from aesara import tensor as at -from aesara.compile import optdb -from aesara.compile.function.types import deep_copy_op -from aesara.configdefaults import config -from aesara.graph.basic import ( - Apply, - Constant, - Variable, - clone_replace, - equal_computations, - graph_inputs, - io_toposort, - is_in_ancestors, -) -from aesara.graph.destroyhandler import DestroyHandler -from aesara.graph.features import ReplaceValidate -from aesara.graph.fg import FunctionGraph -from aesara.graph.op import compute_test_value -from aesara.graph.opt import GlobalOptimizer, in2out, local_optimizer -from aesara.graph.optdb import EquilibriumDB, SequenceDB -from aesara.graph.type import HasShape -from aesara.graph.utils import InconsistencyError -from aesara.scan.op import Scan, ScanInfo -from aesara.scan.utils import ( - ScanArgs, - compress_outs, - expand_empty, - reconstruct_graph, - safe_new, - scan_can_remove_outs, -) -from aesara.tensor import basic_opt, math_opt -from aesara.tensor.basic import Alloc, AllocEmpty, get_scalar_constant_value -from aesara.tensor.elemwise import DimShuffle, Elemwise -from aesara.tensor.exceptions import NotScalarConstantError -from aesara.tensor.math import Dot, dot, maximum, minimum -from aesara.tensor.shape import shape -from aesara.tensor.subtensor import ( - IncSubtensor, - Subtensor, - get_canonical_form_slice, - get_idx_list, - get_slice_elements, - set_subtensor, +warnings.warn( + "The module `aesara.scan.opt` is deprecated; use `aesara.scan.rewriting` instead.", + DeprecationWarning, + stacklevel=2, ) -from aesara.tensor.var import TensorConstant, get_unique_value - - -list_opt_slice = [ - math_opt.local_abs_merge, - math_opt.local_mul_switch_sink, - basic_opt.local_upcast_elemwise_constant_inputs, - basic_opt.local_useless_switch, - basic_opt.constant_folding, -] - - -@local_optimizer([Scan]) -def remove_constants_and_unused_inputs_scan(fgraph, node): - """Move constants into the inner graph, and remove unused inputs. - - Constants that are in the outer graph are represented by a free symbolic - variable in the inner graph. If we move them into the inner graph, - constant-folding can happen in the inner graph. - This is applied only on sequences and non-sequences, - not on initial states. - - """ - if not isinstance(node.op, Scan): - return False - op = node.op - op_info = op.info - # We only need to take care of sequences and other arguments - st = op_info.n_seqs - st += int( - sum(len(x) for x in chain(op_info.mit_mot_in_slices, op_info.mit_sot_in_slices)) - ) - st += op_info.n_sit_sot - st += op_info.n_shared_outs - - op_ins = op.inner_inputs - op_outs = op.inner_outputs - - # Corresponds to the initial states, which should stay untouched. - # We put those variables aside, and put them back at the end. - out_stuff_inner = op_ins[op_info.n_seqs : st] - - non_seqs = op_ins[st:] - st = ( - op_info.n_seqs - + op_info.n_mit_mot - + op_info.n_mit_sot - + op_info.n_sit_sot - + op_info.n_nit_sot - + op_info.n_shared_outs - + 1 - ) - outer_non_seqs = node.inputs[st:] - out_stuff_outer = node.inputs[1 + op_info.n_seqs : st] - - # To replace constants in the outer graph by clones in the inner graph - givens = {} - # All the inputs of the inner graph of the new scan - nw_inner = [] - # Same for the outer graph, initialized w/ number of steps - nw_outer = [node.inputs[0]] - - all_ins = list(graph_inputs(op_outs)) - for idx in range(op_info.n_seqs): - node_inp = node.inputs[idx + 1] - if ( - isinstance(node_inp, TensorConstant) - and get_unique_value(node_inp) is not None - ): - try: - # This works if input is a constant that has all entries - # equal - givens[op_ins[idx]] = node_inp[0] - except TypeError: - pass - elif op_ins[idx] in all_ins: - # Check for identical other sequence - identical_seqs = [ - x for x in nw_outer if equal_computations([x], [node_inp]) - ] - if identical_seqs: - index = node.inputs.index(identical_seqs[0]) - 1 - givens[op_ins[idx]] = op_ins[index] - else: - nw_inner.append(op_ins[idx]) - nw_outer.append(node_inp) - - nw_n_seqs = len(nw_inner) - # Add outputs stuff - nw_inner += out_stuff_inner - nw_outer += out_stuff_outer - - # Look through non sequences - nw_inner_nonseq = [] - nw_outer_nonseq = [] - for idx, (nw_in, nw_out) in enumerate(zip(non_seqs, outer_non_seqs)): - if isinstance(nw_out, Constant): - givens[nw_in] = nw_out - elif nw_in in all_ins: - # Indices of elements of nw_outer_nonseq that are equivalent - # to nw_out. - identical_nonseq_idx = [ - i - for (i, x) in enumerate(nw_outer_nonseq) - if equal_computations([x], [nw_out]) - ] - if identical_nonseq_idx: - givens[nw_in] = nw_inner_nonseq[identical_nonseq_idx[0]] - else: - nw_inner_nonseq.append(nw_in) - nw_outer_nonseq.append(nw_out) - - nw_inner.extend(nw_inner_nonseq) - nw_outer.extend(nw_outer_nonseq) - - if len(nw_inner) != len(op_ins): - op_outs = clone_replace(op_outs, replace=givens) - nw_info = dataclasses.replace( - op_info, n_seqs=nw_n_seqs, n_non_seqs=len(nw_inner_nonseq) - ) - nwScan = Scan( - nw_inner, - op_outs, - nw_info, - mode=op.mode, - profile=op.profile, - truncate_gradient=op.truncate_gradient, - # TODO: This seems questionable - name=op.name, - allow_gc=op.allow_gc, - ) - nw_outs = nwScan(*nw_outer, return_list=True) - return dict([("remove", [node])] + list(zip(node.outputs, nw_outs))) - else: - return False - - -@local_optimizer([Scan]) -def push_out_non_seq_scan(fgraph, node): - r"""Push out the variables inside the `Scan` that depend only on non-sequences. - - This optimizations pushes, out of `Scan`'s inner function and into the outer - function, computation that depends only on non-sequence inputs. Such - computation ends up being done every iteration on the same values so moving - it to the outer function to be executed only once, before the `Scan` `Op`, - reduces the amount of computation that needs to be performed. - """ - if not isinstance(node.op, Scan): - return False - - node_inputs, node_outputs = node.op.inner_inputs, node.op.inner_outputs - - local_fgraph_topo = io_toposort(node_inputs, node_outputs) - local_fgraph_outs_set = set(node_outputs) - local_fgraph_outs_map = {v: k for k, v in enumerate(node_outputs)} - - to_remove_set = set() - to_replace_set = set() - to_replace_map = {} - - def add_to_replace(y): - to_replace_set.add(y) - to_replace_map[y] = add_to_replace.n - add_to_replace.n += 1 - - add_to_replace.n = 0 - - # The variables that will replace the variables pushed-out of the - # inner-graph - replace_with_in = [] - # The variables that have been pushed-out of the graph - replace_with_out = [] - - op = node.op - # Construct the list of non_sequences to simplify a few things - inner_non_seqs = op.inner_non_seqs(node_inputs) - inner_non_seqs_set = set(inner_non_seqs) - inner_non_seqs_map = {v: k for k, v in enumerate(inner_non_seqs)} - - outer_non_seqs = op.outer_non_seqs(node.inputs) - - inner_seqs = op.inner_seqs(node_inputs) - outer_seqs = op.outer_seqs(node.inputs) - - assert len(inner_non_seqs) == len(outer_non_seqs) - assert len(inner_seqs) == len(outer_seqs) - - for nd in local_fgraph_topo: - if ( # we haven't already looked at this node - nd not in to_remove_set - and all( - ( - (x in inner_non_seqs_set) - or (x.owner in to_remove_set) - or isinstance(x, Constant) - ) - for x in nd.inputs - ) - # We can (supposedly) do this because the assumption is that a - # `ViewOp` or `DeepCopyOp` will be just at the end of the - # function and not somewhere in the middle - and not isinstance(nd.op, aesara.compile.ViewOp) - and not isinstance(nd.op, aesara.compile.DeepCopyOp) - ): - # We have a candidate node to remove from the inner-graph - - # Step 1. Reconstruct the node using the relevant outer-inputs. - # - # More specifically, the node's current inputs are either - # a) inner-graph input place-holders for non-sequences, - # b) the outputs of other nodes being pushed out of the inner-graph, - # c) or constants. - to_remove_set.add(nd) - new_inputs = [] - for old_input in nd.inputs: - if old_input in inner_non_seqs_set: - # This is case a), so we want to use the corresponding - # outer-graph input as the input to our new pushed-out node - _idx = inner_non_seqs_map[old_input] - new_input = outer_non_seqs[_idx] - elif old_input in to_replace_set: - # This is case b), so we want to use the new pushed-out node - # as the input to this new pushed-out node - new_input = replace_with_out[to_replace_map[old_input]] - else: - assert isinstance(old_input, Constant) - new_input = old_input - - new_input = old_input.type.filter_variable(new_input) - new_inputs.append(new_input) - - pushed_out_node = nd.op.make_node(*new_inputs) - - if config.compute_test_value != "off": - compute_test_value(pushed_out_node) - - # Step 2. Create variables to replace the old outputs of the node - # that we're pushing out of the inner-graph - for idx, y in enumerate(nd.outputs): - y_place_holder = y.clone() - # y_place_holder = safe_new(y, "_replace") - add_to_replace(y) - replace_with_in.append(y_place_holder) - assert isinstance(y, type(pushed_out_node.outputs[idx])) - replace_with_out.append(pushed_out_node.outputs[idx]) - - # We need to check all candidate replacements and choose those that - # make sense for us - # Step 1. which elements of `to_replace` are used by remaining - # components of the inner function - clean_to_replace = [] - clean_replace_with_in = [] - clean_replace_with_out = [] - existent_nodes = [nd for nd in local_fgraph_topo if nd not in to_remove_set] - existent_nodes_set = set(existent_nodes) - - to_keep_set = set() - for nd in existent_nodes: - to_keep_set.update(nd.inputs) - - for out, idx in to_replace_map.items(): - if ( # If types are different, conversion Op will be inserted, - # and it may trigger an infinite loop. - out.type.is_super(replace_with_in[idx].type) - and out in to_keep_set - and out.owner not in existent_nodes_set - ): - clean_to_replace.append(out) - clean_replace_with_in.append(replace_with_in[idx]) - clean_replace_with_out.append(replace_with_out[idx]) - - if len(clean_to_replace) > 0: - # We can finally put an end to all this madness - givens = {} - nw_outer = [] - nw_inner = [] - for to_repl, repl_in, repl_out in zip( - clean_to_replace, clean_replace_with_in, clean_replace_with_out - ): - if isinstance(repl_out, Constant): - repl_in = repl_out - else: - nw_inner.append(repl_in) - nw_outer.append(repl_out) - givens[to_repl] = repl_in - - op_outs = clone_replace(node_outputs, replace=givens) - op_ins = node_inputs + nw_inner - - new_info = dataclasses.replace( - op.info, n_non_seqs=op.info.n_non_seqs + len(nw_outer) - ) - - # Reconstruct node - nwScan = Scan( - op_ins, - op_outs, - new_info, - mode=op.mode, - profile=op.profile, - truncate_gradient=op.truncate_gradient, - # TODO: This seems questionable - name=op.name, - allow_gc=op.allow_gc, - ) - - # Do not call make_node for test_value - nw_node = nwScan(*(node.inputs + nw_outer), return_list=True)[0].owner - - replacements = dict(zip(node.outputs, nw_node.outputs)) - replacements["remove"] = [node] - return replacements - elif not to_keep_set: - # Nothing in the inner graph should be kept - replace_with = {} - for out, idx in to_replace_map.items(): - if out in local_fgraph_outs_set: - x = node.outputs[local_fgraph_outs_map[out]] - y = replace_with_out[idx] - y_shape = [shp for shp in y.shape] - replace_with[x] = at.alloc(y, node.inputs[0], *y_shape) - - # We need to add one extra dimension to the outputs - # because the scan op expects for a tensor3, to which an - # subtensor is applied that takes only the last element - if replace_with: - if len(node.outputs) == len(replace_with): - # Every output of the node has a replacement, the Scan - # node can be removed from the graph - replace_with["remove"] = [node] - return replace_with - else: - # The node has some outputs for which no replacement has - # been established. This can occur for outputs that are - # not produced by apply nodes (since the optimizations - # only visits apply nodes) such as constants or inputs - # passed directly as outputs. The replacements can be - # performed but the Scan node can't be removed at this - # point. - return replace_with - - else: - return False - - -@local_optimizer([Scan]) -def push_out_seq_scan(fgraph, node): - r"""Push out the variables inside the `Scan` that depend only on constants and sequences. - - This optimization resembles `push_out_non_seq_scan` but it tries to push--out of - the inner function--the computation that only relies on sequence and - non-sequence inputs. The idea behind this optimization is that, when it is - possible to do so, it is generally more computationally efficient to perform - a single operation on a large tensor rather then perform that same operation - many times on many smaller tensors. In many cases, this optimization can - increase memory usage but, in some specific cases, it can also decrease it. - """ - if not isinstance(node.op, Scan): - return False - - node_inputs, node_outputs = node.op.inner_inputs, node.op.inner_outputs - - local_fgraph_topo = io_toposort(node_inputs, node_outputs) - local_fgraph_outs_set = set(node_outputs) - local_fgraph_outs_map = {v: k for k, v in enumerate(node_outputs)} - - to_remove_set = set() - to_replace_set = set() - to_replace_map = {} - - def add_to_replace(y): - to_replace_set.add(y) - to_replace_map[y] = add_to_replace.n - add_to_replace.n += 1 - - add_to_replace.n = 0 - - replace_with_in = [] - replace_with_out = [] - - op = node.op - # Construct the list of non_sequences to simplify a few things - inner_non_seqs = op.inner_non_seqs(node_inputs) - inner_non_seqs_set = set(inner_non_seqs) - inner_non_seqs_map = {v: k for k, v in enumerate(inner_non_seqs)} - - outer_non_seqs = op.outer_non_seqs(node.inputs) - inner_seqs = op.inner_seqs(node_inputs) - inner_seqs_set = set(inner_seqs) - inner_seqs_map = {v: k for k, v in enumerate(inner_seqs)} - - outer_seqs = op.outer_seqs(node.inputs) - assert len(inner_non_seqs) == len(outer_non_seqs) - assert len(inner_seqs) == len(outer_seqs) - - for nd in local_fgraph_topo: - if ( - nd not in to_remove_set - and all( - (x in inner_non_seqs_set) - or (x.owner in to_remove_set) - or isinstance(x, Constant) - or (x in inner_seqs_set) - for x in nd.inputs - ) - and isinstance(nd.op, Elemwise) - ): - - outside_ins = [] - depends_on_seqs = False - - for x in nd.inputs: - if x in inner_non_seqs_set: - _idx = inner_non_seqs_map[x] - new_input = outer_non_seqs[_idx] - elif x in inner_seqs_set: - new_input = outer_seqs[inner_seqs_map[x]] - depends_on_seqs = True - elif x in to_replace_set: - new_input = replace_with_out[to_replace_map[x]] - depends_on_seqs = True - else: - assert isinstance(x, Constant) - new_input = x - - outside_ins.append(new_input) - - if not depends_on_seqs: - # Removing this node from the inner graph of scan - # should be handled by the PushOutNonSeqScan - # optimization. The current optimization only tries - # to pull sequence-dependant computation out of - # scan. - continue - - to_remove_set.add(nd) - - # Do not call make_node for test_value - nw_outer_node = nd.op.make_node(*outside_ins) - - if config.compute_test_value != "off": - compute_test_value(nw_outer_node) - - # Step 2. Create variables for replacements - for idx, y in enumerate(nd.outputs): - y_place_holder = safe_new(y, "_replace") - add_to_replace(y) - replace_with_in.append(y_place_holder) - replace_with_out.append(nw_outer_node.outputs[idx]) - - elif ( - nd not in to_remove_set - and isinstance(nd.op, DimShuffle) - and (nd.inputs[0] in inner_seqs_set or nd.inputs[0].owner in to_remove_set) - ): - - to_remove_set.add(nd) - x = nd.inputs[0] - if x in inner_seqs_set: - outside_ins = outer_seqs[inner_seqs_map[x]] - elif x in to_replace_set: - outside_ins = replace_with_out[to_replace_map[x]] - new_ord = (0,) - for old_ord in nd.op.new_order: - if old_ord == "x": - new_ord += (old_ord,) - else: - new_ord += (old_ord + 1,) - new_outer = outside_ins.dimshuffle(new_ord) - y = nd.outputs[0] - y_place_holder = safe_new(y, "_replace") - add_to_replace(y) - replace_with_in.append(y_place_holder) - replace_with_out.append(new_outer) - - if hasattr(new_outer.tag, "test_value"): - new_sh = new_outer.tag.test_value.shape - ref_sh = (outside_ins.tag.test_value.shape[0],) - ref_sh += nd.outputs[0].tag.test_value.shape - assert new_sh == ref_sh - - # We need to check all candidate replacements and choose those that - # make sense for us - # Step 1. which elements of `to_replace` are used by remaining - # components of the inner function - clean_to_replace = [] - clean_replace_with_in = [] - clean_replace_with_out = [] - - existent_nodes = [nd for nd in local_fgraph_topo if nd not in to_remove_set] - existent_nodes_set = set(existent_nodes) - - to_keep_set = set() - for nd in existent_nodes: - to_keep_set.update(nd.inputs) - - for out, idx in to_replace_map.items(): - if ( - out in to_keep_set - and out.owner not in existent_nodes_set - and - # If types are different, conversion Op will be inserted, - # and it may trigger an infinite loop. - out.type.is_super(replace_with_in[idx].type) - ): - - clean_to_replace.append(out) - clean_replace_with_in.append(replace_with_in[idx]) - clean_replace_with_out.append(replace_with_out[idx]) - - if len(clean_to_replace) > 0: - # We can finally put an end to all this madness - givens = {} - nw_outer = [] - nw_inner = [] - for to_repl, repl_in, repl_out in zip( - clean_to_replace, clean_replace_with_in, clean_replace_with_out - ): - if isinstance(repl_out, Constant): - repl_in = repl_out - else: - nw_inner.append(repl_in) - nw_outer.append(repl_out) - - givens[to_repl] = repl_in - - op_outs = clone_replace(node_outputs, replace=givens) - op_ins = nw_inner + node_inputs - - # Reconstruct node - nw_info = dataclasses.replace(op.info, n_seqs=op.info.n_seqs + len(nw_inner)) - nwScan = Scan( - op_ins, - op_outs, - nw_info, - mode=op.mode, - profile=op.profile, - truncate_gradient=op.truncate_gradient, - # TODO: This seems questionable - name=op.name, - allow_gc=op.allow_gc, - ) - # Do not call make_node for test_value - nw_node = nwScan( - *(node.inputs[:1] + nw_outer + node.inputs[1:]), - return_list=True, - )[0].owner - - replacements = dict(zip(node.outputs, nw_node.outputs)) - replacements["remove"] = [node] - return replacements - - elif not to_keep_set and not op.info.as_while and not op.outer_mitmot(node.inputs): - # Nothing in the inner graph should be kept - replace_with = {} - for out, idx in to_replace_map.items(): - if out in local_fgraph_outs_set: - x = node.outputs[local_fgraph_outs_map[out]] - _y = replace_with_out[idx] - ls = node_outputs - if out in op.inner_mitsot_outs(ls): - odx = op.inner_mitsot_outs(ls).index(out) - inp = op.outer_mitsot(node.inputs)[odx] - st = abs(np.min(op.info.mit_sot_in_slices)) - y = set_subtensor(inp[st:], _y) - elif out in op.inner_sitsot_outs(ls): - odx = op.inner_sitsot_outs(ls).index(out) - inp = op.outer_sitsot(node.inputs)[odx] - y = set_subtensor(inp[1:], _y) - elif out in op.inner_nitsot_outs(ls): - y = _y - else: - y = _y[-1] - replace_with[x] = y - - # We need to add one extra dimension to the outputs - if replace_with and len(replace_with) == len(node.outputs): - replacements = dict(replace_with.items()) - replacements["remove"] = [node] - return replacements - else: - return False - - -def inner_sitsot_only_last_step_used( - fgraph: FunctionGraph, var: Variable, scan_args: ScanArgs -) -> bool: - """ - Given a inner nit-sot output of `Scan`, return ``True`` iff the outer - nit-sot output has only one client and that client is a `Subtensor` - instance that takes only the last step (last element along the first - axis). - """ - idx = scan_args.inner_out_sit_sot.index(var) - outer_var = scan_args.outer_out_sit_sot[idx] - - if len(fgraph.clients[outer_var]) == 1: - client = fgraph.clients[outer_var][0][0] - if isinstance(client, Apply) and isinstance(client.op, Subtensor): - lst = get_idx_list(client.inputs, client.op.idx_list) - if len(lst) == 1 and at.extract_constant(lst[0]) == -1: - return True - - return False - - -def get_outer_ndim(var: Variable, scan_args: ScanArgs) -> int: - """Determine the number of dimension a variable would have if it was pushed out of a `Scan`.""" - assert isinstance(var.type, HasShape) - - if var in scan_args.inner_in_non_seqs or isinstance(var, Constant): - outer_ndim = var.type.ndim - else: - outer_ndim = var.type.ndim + 1 - - return outer_ndim - - -def push_out_inner_vars( - fgraph: FunctionGraph, - inner_vars: List[Variable], - old_scan_node: Apply, - old_scan_args: ScanArgs, -) -> Tuple[List[Variable], ScanArgs, Dict[Variable, Variable]]: - - tmp_outer_vars: List[Optional[Variable]] = [] - new_scan_node = old_scan_node - new_scan_args = old_scan_args - replacements: Dict[Variable, Variable] = {} - - # For the inner_vars that already exist in the outer graph, - # simply obtain a reference to them - for idx in range(len(inner_vars)): - - var = inner_vars[idx] - - new_outer_var: Optional[Variable] = None - - if var in old_scan_args.inner_in_seqs: - idx_seq = old_scan_args.inner_in_seqs.index(var) - new_outer_var = old_scan_args.outer_in_seqs[idx_seq] - - elif var in old_scan_args.inner_in_non_seqs: - idx_non_seq = old_scan_args.inner_in_non_seqs.index(var) - new_outer_var = old_scan_args.outer_in_non_seqs[idx_non_seq] - - elif isinstance(var, Constant): - new_outer_var = var - - elif var in old_scan_args.inner_out_nit_sot: - idx_nitsot = old_scan_args.inner_out_nit_sot.index(var) - new_outer_var = old_scan_args.outer_out_nit_sot[idx_nitsot] - - tmp_outer_vars.append(new_outer_var) - - # For the inner_vars that don't already exist in the outer graph, add - # them as new nitsot outputs to the scan node. - idx_add_as_nitsots = [i for i, v in enumerate(tmp_outer_vars) if v is None] - add_as_nitsots = [inner_vars[idx] for idx in idx_add_as_nitsots] - - new_outs: List[Variable] = [] - - if len(add_as_nitsots) > 0: - - new_scan_node, replacements = add_nitsot_outputs( - fgraph, old_scan_node, old_scan_args, add_as_nitsots - ) - - assert isinstance(new_scan_node.op, Scan) - - new_scan_args = ScanArgs( - new_scan_node.inputs, - new_scan_node.outputs, - new_scan_node.op.inner_inputs, - new_scan_node.op.inner_outputs, - new_scan_node.op.info, - ) - - new_outs = new_scan_args.outer_out_nit_sot[-len(add_as_nitsots) :] - - outer_vars: List[Variable] = [] - - for i, v in enumerate(tmp_outer_vars): - if i in idx_add_as_nitsots: - outer_vars.append(new_outs.pop(0)) - else: - assert v is not None - outer_vars.append(v) - - return outer_vars, new_scan_args, replacements - - -def add_nitsot_outputs( - fgraph: FunctionGraph, - old_scan_node: Apply, - old_scan_args: ScanArgs, - new_outputs_inner, -) -> Tuple[Apply, Dict[Variable, Variable]]: - - assert isinstance(old_scan_node.op, Scan) - - nb_new_outs = len(new_outputs_inner) - - # Create the initial values for the new nitsot outputs - # (the initial value is the nb of steps to store. For a nistot, - # it should be the number of steps performed by scan) - new_nitsots_initial_value = [old_scan_node.inputs[0] for i in range(nb_new_outs)] - - # Create the `ScanArgs` corresponding to the new `Scan` `Op` to create - new_scan_args = copy.copy(old_scan_args) - new_scan_args.inner_out_nit_sot.extend(new_outputs_inner) - new_scan_args.outer_in_nit_sot.extend(new_nitsots_initial_value) - - assert isinstance(old_scan_node.op, Scan) - - # Create the `Scan` `Op` from the `ScanArgs` - new_scan_op = Scan( - new_scan_args.inner_inputs, - new_scan_args.inner_outputs, - new_scan_args.info, - mode=old_scan_node.op.mode, - profile=old_scan_node.op.profile, - truncate_gradient=old_scan_node.op.truncate_gradient, - # TODO: This seems questionable - name=old_scan_node.op.name, - allow_gc=old_scan_node.op.allow_gc, - ) - - # Create the Apply node for the scan op - new_scan_outs = new_scan_op(*new_scan_args.outer_inputs, return_list=True) - assert isinstance(new_scan_outs, list) - new_scan_node = new_scan_outs[0].owner - assert new_scan_node is not None - - # Modify the outer graph to make sure the outputs of the new scan are - # used instead of the outputs of the old scan - new_node_new_outputs_idx = len(old_scan_args.outer_outputs) - len( - old_scan_args.outer_out_shared - ) - - new_node_old_outputs = ( - new_scan_node.outputs[:new_node_new_outputs_idx] - + new_scan_node.outputs[new_node_new_outputs_idx + nb_new_outs :] - ) - - # TODO FIXME: - # replacements = dict(zip(old_scan_node.outputs, new_node_old_outputs)) - # replacements["remove"] = [old_scan_node] - # return new_scan_node, replacements - fgraph.replace_all_validate_remove( # type: ignore - list(zip(old_scan_node.outputs, new_node_old_outputs)), - remove=[old_scan_node], - reason="scan_pushout_add", - ) - return new_scan_node, {} - - -@local_optimizer([Scan]) -def push_out_add_scan(fgraph, node): - r"""Push `Add` operations performed at the end of the inner graph to the outside. - - Like `push_out_seq_scan`, this optimization aims to replace many operations - on small tensors by few operations on large tensors. It can also lead to - increased memory usage. - """ - # Don't perform the optimization on `as_while` `Scan`s. Because these - # `Scan`s don't run for a predetermined number of steps, handling them is - # more complicated and this optimization doesn't support it at the moment. - if not (isinstance(node.op, Scan) and not node.op.info.as_while): - return False - - op = node.op - - # Use `ScanArgs` to parse the inputs and outputs of scan for ease of - # use - args = ScanArgs( - node.inputs, node.outputs, op.inner_inputs, op.inner_outputs, op.info - ) - - clients = {} - local_fgraph_topo = io_toposort( - args.inner_inputs, args.inner_outputs, clients=clients - ) - - for nd in local_fgraph_topo: - if ( - isinstance(nd.op, Elemwise) - and isinstance(nd.op.scalar_op, aes.Add) - and nd.out in args.inner_out_sit_sot - and inner_sitsot_only_last_step_used(fgraph, nd.out, args) - ): - - # Ensure that one of the input to the add is the output of - # the add from a previous iteration of the inner function - sitsot_idx = args.inner_out_sit_sot.index(nd.out) - if args.inner_in_sit_sot[sitsot_idx] in nd.inputs: - - # Ensure that the other input to the add is a dot product - # between 2 matrices which will become a tensor3 and a - # matrix if pushed outside of the scan. Also make sure - # that the output of the Dot is ONLY used by the 'add' - # otherwise doing a Dot in the outer graph will only - # duplicate computation. - - sitsot_in_idx = nd.inputs.index(args.inner_in_sit_sot[sitsot_idx]) - - # 0 if sitsot_in_idx==1, 1 if sitsot_in_idx==0 - dot_in_idx = 1 - sitsot_in_idx - - dot_input = nd.inputs[dot_in_idx] - - if ( - dot_input.owner is not None - and isinstance(dot_input.owner.op, Dot) - and len(clients[dot_input]) == 1 - and dot_input.owner.inputs[0].ndim == 2 - and dot_input.owner.inputs[1].ndim == 2 - and get_outer_ndim(dot_input.owner.inputs[0], args) == 3 - and get_outer_ndim(dot_input.owner.inputs[1], args) == 3 - ): - - # The optimization can be be applied in this case. - - # Move out of scan the two inputs to the Dot and - # perform a dot outside of scan on these two inputs - inner_dot_inputs = nd.inputs[dot_in_idx].owner.inputs - ( - outer_dot_inputs, - new_scan_args, - replacements, - ) = push_out_inner_vars(fgraph, inner_dot_inputs, node, args) - - # Collapse some of the dimensions of the tensors - # so that they become matrices. This is because a - # dot is usually faster on two large matrices than - # a bunch of small ones - outer_dot_inputs[0] = at.flatten( - outer_dot_inputs[0].dimshuffle(1, 0, 2), ndim=2 - ) - - shape_input1 = shape(outer_dot_inputs[1]) - outer_dot_inputs[1] = outer_dot_inputs[1].reshape( - (shape_input1[0] * shape_input1[1], shape_input1[2]) - ) - - # Perform the dot on the newly obtained matrices and - # add the initial value - outer_dot_output = dot(*outer_dot_inputs) - init_value = new_scan_args.outer_in_sit_sot[sitsot_idx][0] - replacement = outer_dot_output + init_value - - # Alter the outer graph to use the output of the - # external Dot instead of the output of scan - # Modify the outer graph to add the outer Dot - outer_sitsot = new_scan_args.outer_out_sit_sot[sitsot_idx] - subtensor_node = fgraph.clients[outer_sitsot][0][0] - outer_sitsot_last_step = subtensor_node.outputs[0] - - replacements[outer_sitsot_last_step] = replacement - return replacements - - return False - - -class ScanInplaceOptimizer(GlobalOptimizer): - """Make `Scan`s perform in-place. - - This optimization attempts to make `Scan` compute its recurrent outputs inplace - on the input tensors that contain their initial states. This optimization can - improve runtime performance as well as reduce memory usage. - - """ - - alloc_ops = (Alloc, AllocEmpty) - """ - Classes that represent operation that allocate new memory and that the - optimization should duplicate so it can operate inplace on them. - """ - - def add_requirements(self, fgraph): - fgraph.attach_feature(ReplaceValidate()) - fgraph.attach_feature(DestroyHandler()) - - def attempt_scan_inplace( - self, fgraph: FunctionGraph, node: Apply[Scan], output_indices: List[int] - ) -> Optional[Apply]: - """Attempt to replace a `Scan` node by one which computes the specified outputs inplace. - - Parameters - ---------- - fgraph - Function graph in which to attempt the replacement - node - Scan node to replace by an inplace version - output_indices - Indices of the outputs to attempt to compute inplace - """ - - op = node.op - - # inputs corresponding to sequences and n_steps - ls_begin = node.inputs[: 1 + op.info.n_seqs] - ls = op.outer_mitmot(node.inputs) - ls += op.outer_mitsot(node.inputs) - ls += op.outer_sitsot(node.inputs) - ls_end = op.outer_shared(node.inputs) - ls_end += op.outer_nitsot(node.inputs) - ls_end += op.outer_non_seqs(node.inputs) - - # In `ls`, duplicate any input which has more than one client and is - # the output of an eligible allocation op - for i in range(len(ls)): - inp = ls[i] - if ( - len(fgraph.clients[inp]) > 1 - and inp.owner - and isinstance(inp.owner.op, self.alloc_ops) - ): - new_lsi = inp.owner.op.make_node(*inp.owner.inputs) - - if config.compute_test_value != "off": - compute_test_value(new_lsi) - - new_lsi_out = new_lsi.outputs - - if len(new_lsi_out) == 1: - new_lsi_out = new_lsi_out[0] - - ls[i] = new_lsi_out - - n_outs = len(ls) - for idx in range(n_outs): - if ls[idx] in ls[:idx]: - ls[idx] = deep_copy_op(ls[idx]) - - inputs = ls_begin + ls + ls_end - - new_op = op.clone() - - destroy_map = op.destroy_map.copy() - for out_idx in output_indices: - destroy_map[out_idx] = [out_idx + 1 + op.info.n_seqs] - - new_op.destroy_map = destroy_map - - # Do not call make_node for test_value - new_outs = new_op(*inputs, return_list=True) - - assert isinstance(new_outs, list) - - try: - # TODO FIXME: We need to stop using this approach (i.e. attempt - # in-place replacements and wait for downstream failures to revert - # the changes). It prevents us from making smart, clear - # rewrites and it adds a lot of unnecessary overhead that - # involves dealing with inconsistent graphs. - # This whole rewrite should be a simple local rewrite, but, because - # of this awful approach, it can't be. - fgraph.replace_all_validate_remove( # type: ignore - list(zip(node.outputs, new_outs)), - remove=[node], - reason="scan_make_inplace", - ) - return cast(Apply[Scan], new_outs[0].owner) - except InconsistencyError: - # Failed moving output to be computed inplace - return None - - def apply(self, fgraph): - - for scan_idx, original_node in enumerate(reversed(fgraph.toposort())): - - if not isinstance(original_node.op, Scan): - continue - - # First attempt to make the Scan compute inplace every recurrent - # output that seems like it could be computed inplace. If that - # fails, go through these outputs individually, trying each of - # them. - op = original_node.op - n_outs = op.info.n_mit_mot + op.info.n_mit_sot + op.info.n_sit_sot - - # Generate a list of outputs on which the node could potentially - # operate inplace. - out_indices = [] - for out_idx in range(n_outs): - inp_idx = 1 + op.info.n_seqs + out_idx - inp = original_node.inputs[inp_idx] - - # If the input is from an eligible allocation node, attempt to - # be inplace on it, even if other nodes are modifying it - # inplace. - if inp.owner and isinstance(inp.owner.op, self.alloc_ops): - out_indices.append(out_idx) - continue - - # If the input is not from an eligible allocation node, only - # attempt to be inplace on it if nothing else is currently - # inplace on it. - input_used_inplace = False - for c in fgraph.clients[original_node.inputs[inp_idx]]: - client = c[0] - - # Get the indices of this client's inputs on which it - # operates inplace - if client.op.destroy_map: - # This flattens the content of destroy_map.values() - # which is a list of lists - inplace_inp_indices = sum(client.op.destroy_map.values(), []) - - inplace_inps = [client.inputs[i] for i in inplace_inp_indices] - if original_node.inputs[inp_idx] in inplace_inps: - input_used_inplace = True - break - - if not input_used_inplace: - out_indices.append(out_idx) - - if len(out_indices) == 0: - continue - - new_node = self.attempt_scan_inplace(fgraph, original_node, out_indices) - - if new_node is None: - # Making the scan compute all plausible recurrent outputs - # inplace has failed. Attempt all plausible recurrent outputs - # individually. - new_node = original_node - for pos in out_indices: - new_node = ( - self.attempt_scan_inplace(fgraph, new_node, [pos]) or new_node - ) - - -def select_min(x, y): - if x is None: - return y - if y is None: - return x - return minimum(x, y) - - -def select_max(x, y): - if x is None: - return y - if y is None: - return x - return maximum(x, y) - - -def sanitize(x): - if x is None: - return None - else: - return at.as_tensor_variable(x) - - -@local_optimizer([Scan]) -def save_mem_new_scan(fgraph, node): - r"""Graph optimizer that reduces scan memory consumption. - - This optimizations attempts to determine if a `Scan` node, during its execution, - for any of its outputs, can get away with allocating a memory buffer that is - large enough to contain some of the computed timesteps of that output but not - all of them. - - By default, during the execution of a `Scan` node, memory buffers will be - allocated to store the values computed for every output at every iteration. - However, in some cases, there are outputs for which there is only really a - need to store the most recent ``N`` values, not all of them. - - For instance, if a `Scan` node has a SITSOT output (last computed value is - fed back as an input at the next iteration) and only the last timestep of - that output is ever used in the outer function, the `ScanSaveMem` optimization - could determine that there is no need to store all computed timesteps for - that SITSOT output. Only the most recently computed timestep ever needs to - be kept in memory. - - """ - if not isinstance(node.op, Scan): - return False - - if hasattr(fgraph, "shape_feature"): - shape_of = fgraph.shape_feature.shape_of - else: - # Each access to shape_of is in a try..except block in order to - # use a default version when the variable is not in the shape_of - # dictionary. - shape_of = {} - # 1. Initialization of variables - # Note 1) We do not actually care about outputs representing shared - # variables (those have no intermediate values) so it is safer to - # ignore them and not change them in any way. To simplify the - # optimizations I construct the variable ``c_outs`` ( that counts - # outputs up to those we care) and the list ``init_l`` which for any - # output we care says the length of its initial state. Note that - # defining ``init_l`` for mit_mot sequences is a bit trickier but - # it is safe to set it to 0 - op = node.op - op_info = op.info - c_outs = ( - op_info.n_mit_mot + op_info.n_mit_sot + op_info.n_sit_sot + op_info.n_nit_sot - ) - - init_l = [0 for x in range(op_info.n_mit_mot)] - init_l += [ - abs(min(v)) for v in chain(op_info.mit_sot_in_slices, op_info.sit_sot_in_slices) - ] - init_l += [0 for x in range(op_info.n_nit_sot)] - # 2. Check the clients of each output and see for how many steps - # does scan need to run - - # This comparison checks if there is any uncounted output, which - # can only be an output corresponding to a shared variable - - # 2.1 Initialize - # global_nsteps is a dictionary having two fields ( 'real' deals - # with int values, 'sym' with symbolic ones) or None - # given that a scan op has k outputs o_1, .. o_k and each - # output has n_j clients c_1^1, c_1^2, .. c_1^{n_1}, c_2^1, .., - # global_nsteps is None if any of the clients is different - # from a subtensor or its real and sym field equal to - # max(c_i_j.idx_list[0].stop), meaning store up to which maximal - # index(step) for any output scan actually needs to compute - # In other words n_steps should be equal to this maximal ! - # Note: if we have a shared variable that gets updated at every step - # of the loop, reducing the number of steps will affect the the - # value of the shared variable after the loop so we need not to - # change the number of steps in that case. To do this we set - # global_nsteps to None which is seen as a flag that nothing needs - # to be done - assert len(node.outputs) >= c_outs - if len(node.outputs) == c_outs: - global_nsteps = {"real": -1, "sym": []} - else: - global_nsteps = None - - # Keeps track of the original slices that each client represent - slices = [None for o in node.outputs] - - # A list for each output indicating how many intermediate values - # should be stored. If negative it means none of the intermediate - # values (i.e. the output can be removed since it is not used - # afterwards in the computations), if 0 it means that all - # intermediate values are required, otherwise is up to that number - # of intermediate values - # Note that for mit_mot outputs and shared outputs we can not change - # the number of intermediate steps stored without affecting the - # result of the op - store_steps = [0 for o in range(op_info.n_mit_mot)] - store_steps += [-1 for o in node.outputs[op_info.n_mit_mot : c_outs]] - # Flag that says if an input has changed and we need to do something - # or not - flag_store = False - - # 2.2 Loop over the clients - for i, out in enumerate(node.outputs[:c_outs]): - # look at all its clients - slices[i] = [] - for cl, _ in fgraph.clients[out]: - - # 2.1 outputs of the function - # => output needs all its intermediate values - if isinstance(cl, str): - # if the node is actually an output, then - # we need to store the entire thing - global_nsteps = None - slices[i] = None - break - # 2.2 non-subtensor nodes - # => output needs all its intermediate values - elif not isinstance(cl.op, Subtensor): - global_nsteps = None - slices[i] = None - break - # 2.3 subtensor nodes - # => output might need to store just a subset of its values - else: - # 2.3.1 extract idx list of subtensor - this_slice = get_idx_list(cl.inputs, cl.op.idx_list) - if this_slice is None: - # if unable to extract idx_list - # => outputs needs all its intermediate values - global_nsteps = None - slices[i] = None - break - - # 2.3.2 extract the begin/end of the first dimension - if i >= op_info.n_mit_mot: - try: - length = shape_of[out][0] - except KeyError: - length = node.inputs[0] + init_l[i] - else: - try: - length = shape_of[out][0] - except KeyError: - length = out.shape[0] - cf_slice = get_canonical_form_slice(this_slice[0], length) - slices[i] += [(cf_slice, this_slice)] - - if isinstance(this_slice[0], slice) and this_slice[0].stop is None: - global_nsteps = None - if isinstance(cf_slice[0], slice): - stop = at.extract_constant(cf_slice[0].stop) - else: - stop = at.extract_constant(cf_slice[0]) + 1 - if stop == maxsize or stop == length: - stop = None - else: - # there is a **gotcha** here ! Namely, scan returns an - # array that contains the initial state of the output - # as well. Which means that if have a initial state of - # length 3, and you look for 5 steps you get an output - # y of length 8. If you only use y[:5], this does not - # mean that you only need to loop for 5 steps but - # actually only for 2 steps ( the first 3 are the - # initial state) - stop = stop - init_l[i] - - # 2.3.3 we might get away with less number of steps - if stop is not None and global_nsteps is not None: - # yes if it is a tensor - if isinstance(stop, Variable): - global_nsteps["sym"] += [stop] - # not if it is maxsize - elif isinstance(stop, int) and stop == maxsize: - global_nsteps = None - # yes if it is a int k, 0 < k < maxsize - elif isinstance(stop, int) and global_nsteps["real"] < stop: - global_nsteps["real"] = stop - # yes if it is a int k, 0 < k < maxsize - elif isinstance(stop, int) and stop > 0: - pass - # not otherwise - else: - global_nsteps = None - - # 2.3. Analyze global_nsteps to figure out for how many steps scan - # needs to iterate - if global_nsteps is not None: - nw_steps = node.inputs[0] - - # there are some symbolic tensors that limit the number of - # steps - if len(global_nsteps["sym"]) == 0: - sym_steps = None - else: - sym_steps = global_nsteps["sym"][0] - for c in global_nsteps["sym"][1:]: - sym_steps = maximum(sym_steps, c) - - if global_nsteps["real"] >= 0: - real_steps = global_nsteps["real"] - else: - real_steps = None - nw_steps = select_min(select_max(sym_steps, real_steps), node.inputs[0]) - - # Make sure the ScanSaveMem optimization never makes the new - # number of steps to be 0 (this could happen, for instance, if - # the optimization detects that the outputs of the Scan go through - # subtensor nodes that end up taking no elements) because Scan with - # 0 iterations are not supported. Make sure the new number of steps - # is at least 1. - nw_steps = select_max(nw_steps, 1) - else: - nw_steps = node.inputs[0] - global_nsteps = None - - # 2.4 Loop over the clients again now looking just to see how many - # intermediate steps to store - for i, out in enumerate(node.outputs[:c_outs]): - # look at all its clients - for cl, _ in fgraph.clients[out]: - if isinstance(cl, str): - store_steps[i] = 0 - break - elif not isinstance(cl.op, Subtensor): - store_steps[i] = 0 - break - else: - this_slice = get_idx_list(cl.inputs, cl.op.idx_list) - if this_slice is None: - store_steps[i] = 0 - break - - if isinstance(this_slice[0], slice) and this_slice[0].start is None: - store_steps[i] = 0 - break - - if i > op_info.n_mit_mot: - length = node.inputs[0] + init_l[i] - else: - try: - length = shape_of[out][0] - except KeyError: - length = out.shape[0] - cf_slice = get_canonical_form_slice(this_slice[0], length) - - if isinstance(cf_slice[0], slice): - start = at.extract_constant(cf_slice[0].start) - else: - start = at.extract_constant(cf_slice[0]) - if start == 0 or store_steps[i] == 0: - store_steps[i] = 0 - else: - # The "+ 1" is because of the memory pre-allocation - # mechanism used to in the Scan op to reduce overhead. - # To prevent aliasing between the inputs and outputs - # of recurrent states, it requires that the buffer be - # large enough to that, the new state and the oldest - # tap needed don't occupy the sample place in the - # circular buffer. For now, this only needs to be done - # for mitsots and sitsots (because mitmots are not - # currently supported by the mechanism) and only if - # the pre-allocation mechanism is activated. - prealloc_outs = config.scan__allow_output_prealloc - - first_mitsot_idx = op_info.n_mit_mot - last_sitsot_idx = ( - op_info.n_mit_mot + op_info.n_mit_sot + op_info.n_sit_sot - 1 - ) - preallocable_output = first_mitsot_idx <= i <= last_sitsot_idx - - if prealloc_outs and preallocable_output: - pval = select_max(nw_steps - start + init_l[i], init_l[i] + 1) - else: - pval = select_max(nw_steps - start + init_l[i], init_l[i]) - - if store_steps[i] != -1: - pval = select_max(pval, store_steps[i]) - - store_steps[i] = pval - flag_store = True - - orphane_outs = [ - i for i, x in enumerate(store_steps) if isinstance(x, int) and (x < 0) - ] - flag_store = flag_store or (len(orphane_outs) > 0) - # 3. is there anything to change ? - if flag_store or global_nsteps is not None: - # 3.1 initialize inputs for the new scan - old_outputs = [] - nw_inputs = list(node.inputs) - nw_inputs[0] = nw_steps - - # 3.2 check orphane outputs to see if we can eliminate any - required, not_required = scan_can_remove_outs(node.op, orphane_outs) - # 3.3. compose replace pairs for those nodes that need not - # to store everything in memory ( or ar orphane and required - # by the inner function .. ) - replaced_outs = [] - offset = 1 + op_info.n_seqs + op_info.n_mit_mot - for idx, _val in enumerate(store_steps[op_info.n_mit_mot :]): - i = idx + op_info.n_mit_mot - if not (isinstance(_val, int) and _val <= 0 and i not in required): - - if idx + op_info.n_mit_mot in required: - val = 1 - else: - val = _val - # If the memory for this output has been pre-allocated - # before going into the scan op (by an alloc node) - if idx < op_info.n_mit_sot + op_info.n_sit_sot: - # In case the input is still an alloc node, we - # actually have two options: - # a) the input is a set_subtensor, in that case we - # can replace the initial tensor by a slice, - # b) it is not, and we simply take a slice of it. - # TODO: commit change below with Razvan - if ( - nw_inputs[offset + idx].owner - and isinstance(nw_inputs[offset + idx].owner.op, IncSubtensor) - and isinstance( - nw_inputs[offset + idx].owner.op.idx_list[0], slice - ) - ): - - assert isinstance( - nw_inputs[offset + idx].owner.op, IncSubtensor - ) - _nw_input = nw_inputs[offset + idx].owner.inputs[1] - cval = at.as_tensor_variable(val) - initl = at.as_tensor_variable(init_l[i]) - tmp_idx = at.switch(cval < initl, cval + initl, cval - initl) - nw_input = expand_empty(_nw_input, tmp_idx) - else: - tmp = at.as_tensor_variable(val) - initl = at.as_tensor_variable(init_l[i]) - tmp = maximum(tmp, initl) - nw_input = nw_inputs[offset + idx][:tmp] - - nw_inputs[offset + idx] = nw_input - replaced_outs.append(op_info.n_mit_mot + idx) - odx = op_info.n_mit_mot + idx - old_outputs += [ - ( - odx, - [ - x[0].outputs[0] - for x in fgraph.clients[node.outputs[odx]] - ], - ) - ] - # If there is no memory pre-allocated for this output - elif idx < op_info.n_mit_sot + op_info.n_sit_sot + op_info.n_nit_sot: - - pos = ( - op_info.n_mit_mot - + idx - + op_info.n_seqs - + 1 - + op_info.n_shared_outs - ) - if nw_inputs[pos] == node.inputs[0]: - nw_inputs[pos] = val - odx = op_info.n_mit_mot + idx - replaced_outs.append(odx) - old_outputs += [ - ( - odx, - [ - x[0].outputs[0] - for x in fgraph.clients[node.outputs[odx]] - ], - ) - ] - # 3.4. Recompute inputs for everything else based on the new - # number of steps - if global_nsteps is not None: - for idx, val in enumerate(store_steps[op_info.n_mit_mot :]): - if val == 0: - # val == 0 means that we want to keep all intermediate - # results for that state, including the initial values. - if idx < op_info.n_mit_sot + op_info.n_sit_sot: - in_idx = offset + idx - # Number of steps in the initial state - initl = init_l[op_info.n_mit_mot + idx] - - # If the initial buffer has the form - # inc_subtensor(zeros(...)[...], _nw_input) - # we want to make the zeros tensor as small as - # possible (nw_steps + initl), and call - # inc_subtensor on that instead. - # Otherwise, simply take 0:(nw_steps+initl). - if ( - nw_inputs[in_idx].owner - and isinstance(nw_inputs[in_idx].owner.op, IncSubtensor) - and isinstance( - nw_inputs[in_idx].owner.op.idx_list[0], slice - ) - ): - _nw_input = nw_inputs[in_idx].owner.inputs[1] - nw_input = expand_empty(_nw_input, nw_steps) - nw_inputs[in_idx] = nw_input - else: - nw_input = nw_inputs[in_idx][: (initl + nw_steps)] - - elif ( - idx < op_info.n_mit_sot + op_info.n_sit_sot + op_info.n_nit_sot - ): - in_idx = offset + idx + op_info.n_shared_outs - if nw_inputs[in_idx] == node.inputs[0]: - nw_inputs[in_idx] = nw_steps - - # 3.5 Remove unwanted orphane outputs - (inps, outs, info, node_ins, compress_map) = compress_outs( - op, not_required, nw_inputs - ) - inv_compress_map = {} - for k, v in compress_map.items(): - inv_compress_map[v] = k - - # 3.6 Compose the new scan - # TODO: currently we don't support scan with 0 step. So - # don't create one. - if at.extract_constant(node_ins[0]) == 0: - return False - - # Do not call make_node for test_value - new_op = Scan( - inps, - outs, - info, - mode=op.mode, - profile=op.profile, - truncate_gradient=op.truncate_gradient, - # TODO: This seems questionable - name=op.name, - allow_gc=op.allow_gc, - ) - new_outs = new_op(*node_ins, return_list=True) - - old_new = [] - # 3.7 Get replace pairs for those outputs that do not change - # the number of intermediate steps stored - for idx, sl in enumerate(slices): - if global_nsteps and sl is not None and store_steps[idx] == 0: - for hdx, cl in enumerate(fgraph.clients[node.outputs[idx]]): - cnf_slice, old_slices = sl[hdx] - # Sanitize the nw_slice by converting ints back into - # constants :) I only need to do this for the first - # slice since that is the only slice - - if isinstance(cnf_slice[0], slice): - fslice = slice( - sanitize(cnf_slice[0].start), - sanitize(cnf_slice[0].stop), - sanitize(cnf_slice[0].step), - ) - else: - fslice = sanitize(cnf_slice[0]) - - nw_slice = (fslice,) + tuple(old_slices[1:]) - nw_pos = inv_compress_map[idx] - - subtens = Subtensor(nw_slice) - # slice inputs - sl_ins = get_slice_elements( - nw_slice, lambda entry: isinstance(entry, Variable) - ) - new_o = subtens(new_outs[nw_pos], *sl_ins) - if new_o.ndim > 0: - new_o = new_o[:: cnf_slice[1]] - replaced_outs.append(idx) - old_new += [(cl[0].outputs[0], new_o)] - # 3.8. Get replace pairs for those outputs that change - # the number of stored intermediate steps - for pos, old_outs in old_outputs: - if len(old_outs) > 0: - nw_pos = compress_map[pos] - for k, old in enumerate(old_outs): - # Get the correct slice - cnf_slice, old_slices = slices[pos][k] - if isinstance(cnf_slice[0], slice): - start = ( - cnf_slice[0].start - - nw_steps - - init_l[pos] - + store_steps[pos] - ) - if ( - cnf_slice[0].stop is not None - and cnf_slice[0].stop != maxsize - ): - stop = ( - cnf_slice[0].stop - - nw_steps - - init_l[pos] - + store_steps[pos] - ) - else: - stop = None - nw_slice = ( - slice( - sanitize(start), - sanitize(stop), - sanitize(cnf_slice[0].step), - ), - ) + tuple(old_slices[1:]) - - else: - position = ( - cnf_slice[0] - nw_steps - init_l[pos] + store_steps[pos] - ) - - nw_slice = (sanitize(position),) + tuple(old_slices[1:]) - subtens = Subtensor(nw_slice) - sl_ins = get_slice_elements( - nw_slice, lambda entry: isinstance(entry, Variable) - ) - new_o = subtens(new_outs[nw_pos], *sl_ins) - if new_o.ndim > 0: - new_o = new_o[:: cnf_slice[1]] - old_new += [(old, new_o)] - - # 3.9. Get replace pairs for all other nodes - if flag_store or global_nsteps is not None: - for idx, o in enumerate(node.outputs): - if not (idx in replaced_outs) and idx not in not_required: - nw_pos = compress_map[idx] - old_new += [(o, new_outs[nw_pos])] - # Check if the new outputs depend on the old scan node - old_scan_is_used = [ - is_in_ancestors(new.owner, node) for old, new in old_new - ] - if any(old_scan_is_used): - return False - - replacements = dict(old_new) - - # remove = [old.owner for (old, new) in old_new] - # As Fred suggested assert that also the old node is not in - # the Graph as that will make things suboptimal - # remove.append(node) - replacements["remove"] = [node] - - return replacements - - return False - - -class ScanMerge(GlobalOptimizer): - r"""Graph optimizer that merges different scan ops. - - This optimization attempts to fuse distinct `Scan` `Op`s into a single `Scan` `Op` - that performs all the computation. The main advantage of merging `Scan` `Op`\s - together comes from the possibility of both original `Op`\s having some - computation in common. In such a setting, this computation ends up being done - twice. The fused `Scan` `Op`, however, would only need to do it once and could - therefore be more computationally efficient. Also, since every `Scan` node - involves a certain overhead, at runtime, reducing the number of `Scan` nodes in - the graph can improve performance. - - """ - - def add_requirements(self, fgraph): - fgraph.attach_feature(ReplaceValidate()) - - def merge(self, nodes): - - if nodes[0].op.info.as_while: - as_while = True - condition = nodes[0].op.inner_outputs[-1] - else: - as_while = False - - # We keep the inner_ins and inner_outs of each original node separated. - # To be able to recombine them in the right order after the clone, - # we also need to split them by types (seq, mitmot, ...). - # On the other hand, outer_ins, outer_outs and info are held together. - inner_ins = [[] for nd in nodes] - outer_ins = [] - inner_outs = [[] for nd in nodes] - outer_outs = [] - - def rename(ls, suffix): - for k in ls: - if k.name: - k.name += str(suffix) - return ls - - for idx, nd in enumerate(nodes): - inner_ins[idx].append(rename(nd.op.inner_seqs(nd.op.inner_inputs), idx)) - outer_ins += rename(nd.op.outer_seqs(nd.inputs), idx) - - mit_mot_out_slices = () - - mit_mot_in_slices = () - for idx, nd in enumerate(nodes): - inner_ins[idx].append(rename(nd.op.inner_mitmot(nd.op.inner_inputs), idx)) - inner_outs[idx].append(nd.op.inner_mitmot_outs(nd.op.inner_outputs)) - mit_mot_in_slices += nd.op.info.mit_mot_in_slices - mit_mot_out_slices += nd.op.info.mit_mot_out_slices[: nd.op.info.n_mit_mot] - outer_ins += rename(nd.op.outer_mitmot(nd.inputs), idx) - outer_outs += nd.op.outer_mitmot_outs(nd.outputs) - - mit_sot_in_slices = () - for idx, nd in enumerate(nodes): - inner_ins[idx].append(rename(nd.op.inner_mitsot(nd.op.inner_inputs), idx)) - inner_outs[idx].append(nd.op.inner_mitsot_outs(nd.op.inner_outputs)) - mit_sot_in_slices += nd.op.info.mit_sot_in_slices - outer_ins += rename(nd.op.outer_mitsot(nd.inputs), idx) - outer_outs += nd.op.outer_mitsot_outs(nd.outputs) - - sit_sot_in_slices = () - for idx, nd in enumerate(nodes): - inner_ins[idx].append(rename(nd.op.inner_sitsot(nd.op.inner_inputs), idx)) - sit_sot_in_slices += tuple((-1,) for x in range(nd.op.info.n_sit_sot)) - inner_outs[idx].append(nd.op.inner_sitsot_outs(nd.op.inner_outputs)) - outer_ins += rename(nd.op.outer_sitsot(nd.inputs), idx) - outer_outs += nd.op.outer_sitsot_outs(nd.outputs) - - for idx, nd in enumerate(nodes): - # Shared - inner_ins[idx].append(rename(nd.op.inner_shared(nd.op.inner_inputs), idx)) - outer_ins += rename(nd.op.outer_shared(nd.inputs), idx) - - for idx, nd in enumerate(nodes): - # NitSot - inner_outs[idx].append(nd.op.inner_nitsot_outs(nd.op.inner_outputs)) - outer_ins += rename(nd.op.outer_nitsot(nd.inputs), idx) - outer_outs += nd.op.outer_nitsot_outs(nd.outputs) - - for idx, nd in enumerate(nodes): - # Shared - outer_outs += nd.op.outer_shared_outs(nd.outputs) - inner_outs[idx].append(nd.op.inner_shared_outs(nd.op.inner_outputs)) - - n_non_seqs = 0 - for idx, nd in enumerate(nodes): - # Non Seqs - node_inner_non_seqs = nd.op.inner_non_seqs(nd.op.inner_inputs) - n_non_seqs += len(node_inner_non_seqs) - inner_ins[idx].append(rename(node_inner_non_seqs, idx)) - outer_ins += rename(nd.op.outer_non_seqs(nd.inputs), idx) - - # Add back the number of steps - outer_ins = [nodes[0].inputs[0]] + outer_ins - - if as_while: - # add the condition, which was the one of nodes[0] - inner_outs[0].append([condition]) - - # Clone the inner graph of each node independently - for idx, nd in enumerate(nodes): - # concatenate all inner_ins and inner_outs of nd - flat_inner_ins = sum(inner_ins[idx], []) - flat_inner_outs = sum(inner_outs[idx], []) - # clone - flat_inner_ins, flat_inner_outs = reconstruct_graph( - flat_inner_ins, flat_inner_outs - ) - # split the new inner variables again in seq, mitmot, etc. - new_inner_ins = [] - count = 0 - for nl in inner_ins[idx]: - seq_len = len(nl) - new_inner_ins.append(flat_inner_ins[count : (count + seq_len)]) - count += seq_len - - new_inner_outs = [] - count = 0 - for nl in inner_outs[idx]: - seq_len = len(nl) - new_inner_outs.append(flat_inner_outs[count : (count + seq_len)]) - count += seq_len - - inner_ins[idx] = new_inner_ins - inner_outs[idx] = new_inner_outs - - # Flatten inner_ins and inner_outs so that all seqs are first, - # then mitmot, etc. - new_inner_ins = [] - new_inner_outs = [] - nb_ins_groups = len(inner_ins[0]) - nb_outs_groups = len(inner_outs[0]) - for idx, nd in enumerate(nodes): - # All inner_ins should have the same length - assert len(inner_ins[idx]) == nb_ins_groups - - # All inner_outs should have the same length, except if as_while, - # in which case the first one should have one more element - if as_while and idx > 0: - assert len(inner_outs[idx]) == nb_outs_groups - 1 - else: - assert len(inner_outs[idx]) == nb_outs_groups - - for gr_idx in range(nb_ins_groups): - for idx, nd in enumerate(nodes): - new_inner_ins += inner_ins[idx][gr_idx] - - for gr_idx in range(nb_outs_groups): - for idx, nd in enumerate(nodes): - if as_while and idx > 0 and gr_idx == (nb_outs_groups - 1): - # There is no condition on that node, skip it - pass - else: - new_inner_outs += inner_outs[idx][gr_idx] - - info = ScanInfo( - n_seqs=sum(nd.op.info.n_seqs for nd in nodes), - mit_mot_in_slices=mit_mot_in_slices, - mit_mot_out_slices=mit_mot_out_slices, - mit_sot_in_slices=mit_sot_in_slices, - sit_sot_in_slices=sit_sot_in_slices, - n_nit_sot=sum(nd.op.info.n_nit_sot for nd in nodes), - n_shared_outs=sum(nd.op.info.n_shared_outs for nd in nodes), - n_non_seqs=n_non_seqs, - as_while=as_while, - ) - - old_op = nodes[0].op - new_op = Scan( - new_inner_ins, - new_inner_outs, - info, - mode=old_op.mode, - profile=old_op.profile, - truncate_gradient=old_op.truncate_gradient, - allow_gc=old_op.allow_gc, - name="&".join([nd.op.name for nd in nodes]), - ) - new_outs = new_op(*outer_ins) - - if not isinstance(new_outs, (list, tuple)): - new_outs = [new_outs] - - return list(zip(outer_outs, new_outs)) - - def belongs_to_set(self, node, set_nodes): - """ - This function checks if node `node` belongs to `set_nodes`, in the - sense that it can be merged together with every other node in - `set_nodes`. In order for two nodes to be mergeable, they have to go - over the same number of steps, have the same condition (if any), - have the same value for truncate_gradient, and have the same mode. - Questionable, we should also consider profile ? - - """ - rep = set_nodes[0] - if ( - rep.op.info.as_while != node.op.info.as_while - or node.op.truncate_gradient != rep.op.truncate_gradient - or node.op.mode != rep.op.mode - ): - return False - - nsteps = node.inputs[0] - try: - nsteps = int(get_scalar_constant_value(nsteps)) - except NotScalarConstantError: - pass - - rep_nsteps = rep.inputs[0] - try: - rep_nsteps = int(get_scalar_constant_value(rep_nsteps)) - except NotScalarConstantError: - pass - - if nsteps != rep_nsteps: - return False - - # Check to see if it is an input of a different node - for nd in set_nodes: - if is_in_ancestors(node, nd) or is_in_ancestors(nd, node): - return False - - if not node.op.info.as_while: - return True - cond = node.op.inner_outputs[-1] - rep_cond = rep.op.inner_outputs[-1] - return equal_computations( - [cond], [rep_cond], node.op.inner_inputs, rep.op.inner_inputs - ) - - def apply(self, fgraph): - # Collect all scan nodes ordered according to toposort - scan_nodes = [nd for nd in fgraph.toposort() if isinstance(nd.op, Scan)] - - # All sets of possibly mergeable nodes - all_sets = [] - - for nd in scan_nodes: - belongs_to_set_idx = -1 - for pos, subset in enumerate(all_sets): - if self.belongs_to_set(nd, subset): - belongs_to_set_idx = pos - # It is possible that nd belongs to more than one subset. - # For instance, if we have 3 Scan nodes X, Y and Z, if Z - # depends on the output of X, then X and Z are incompatible - # and would create different subsets, but Y could be - # compatible with both X and Z. We choose the first one. - break - - if belongs_to_set_idx == -1: - all_sets.append([nd]) - else: - all_sets[belongs_to_set_idx].append(nd) - - for subset in all_sets: - if len(subset) > 1: - proposal = self.merge(subset) - fgraph.replace_all_validate_remove( - proposal, remove=subset, reason="scan_merge" - ) - - -def has_duplicates(l): - """ - Returns true if l has any duplicates (according to __eq__). - - """ - return len(set(l)) < len(l) - - -def make_equiv(lo, li): - """ - Builds a dictionary of equivalences between inner inputs based on - the equivalence of their corresponding outer inputs. - - """ - seeno = {} - left = [] - right = [] - for o, i in zip(lo, li): - if o in seeno: - left += [i] - right += [o] - else: - seeno[o] = i - return left, right - - -@local_optimizer([Scan]) -def scan_merge_inouts(fgraph, node): - """ - This optimization attempts to merge a `Scan` `Op`'s identical outer inputs as well - as merge its identical outer outputs (outputs that perform the same - computation on the same inputs). This can reduce the amount of computation as - well as result in a simpler graph for both the inner function and the outer - function. - """ - if not isinstance(node.op, Scan): - return False - - # Do a first pass to merge identical external inputs. - # Equivalent inputs will be stored in inp_equiv, then a new - # scan node created without duplicates. - a = ScanArgs( - node.inputs, - node.outputs, - node.op.inner_inputs, - node.op.inner_outputs, - node.op.info, - ) - - inp_equiv = {} - - if has_duplicates(a.outer_in_seqs): - new_outer_seqs = [] - new_inner_seqs = [] - for out_seq, in_seq in zip(a.outer_in_seqs, a.inner_in_seqs): - if out_seq in new_outer_seqs: - i = new_outer_seqs.index(out_seq) - inp_equiv[in_seq] = new_inner_seqs[i] - else: - new_outer_seqs.append(out_seq) - new_inner_seqs.append(in_seq) - a.outer_in_seqs = new_outer_seqs - a.inner_in_seqs = new_inner_seqs - - if has_duplicates(a.outer_in_non_seqs): - new_outer_nseqs = [] - new_inner_nseqs = [] - for out_nseq, in_nseq in zip(a.outer_in_non_seqs, a.inner_in_non_seqs): - if out_nseq in new_outer_nseqs: - i = new_outer_nseqs.index(out_nseq) - inp_equiv[in_nseq] = new_inner_nseqs[i] - else: - new_outer_nseqs.append(out_nseq) - new_inner_nseqs.append(in_nseq) - a.outer_in_non_seqs = new_outer_nseqs - a.inner_in_non_seqs = new_inner_nseqs - - if len(inp_equiv) > 0: - # do the replacement now. The rest will be left to ScanSaveMem - inner_inputs = a.inner_inputs - outer_inputs = a.outer_inputs - info = a.info - a_inner_outs = a.inner_outputs - inner_outputs = clone_replace(a_inner_outs, replace=inp_equiv) - - new_op = Scan( - inner_inputs, - inner_outputs, - info, - mode=node.op.mode, - profile=node.op.profile, - truncate_gradient=node.op.truncate_gradient, - # TODO: This seems questionable - name=node.op.name, - allow_gc=node.op.allow_gc, - ) - outputs = new_op(*outer_inputs) - - if not isinstance(outputs, (list, tuple)): - outputs = [outputs] - - na = ScanArgs( - outer_inputs, - outputs, - new_op.inner_inputs, - new_op.inner_outputs, - new_op.info, - ) - remove = [node] - else: - na = a - remove = [] - - # Now that the identical external inputs have been merged, we do a new - # loop in order to merge external outputs that compute the same things - # from the same inputs. - left = [] - right = [] - - if has_duplicates(na.outer_in_shared): - _left, _right = make_equiv(na.outer_in_shared, na.inner_in_shared) - left += _left - right += _right - if has_duplicates(na.outer_in_sit_sot): - _left, _right = make_equiv(na.outer_in_sit_sot, na.inner_in_sit_sot) - left += _left - right += _right - if has_duplicates(na.outer_in_mit_mot): - seen = {} - for omm, imm, _sl in zip( - na.outer_in_mit_mot, na.inner_in_mit_mot, na.mit_mot_in_slices - ): - sl = tuple(_sl) - if (omm, sl) in seen: - simm = seen[(omm, sl)] - left += imm - right += simm - else: - seen[(omm, sl)] = imm - - if has_duplicates(na.outer_in_mit_sot): - seen = {} - for oms, ims, _sl in zip( - na.outer_in_mit_sot, na.inner_in_mit_sot, na.mit_sot_in_slices - ): - sl = tuple(_sl) - if (oms, sl) in seen: - sims = seen[(oms, sl)] - left += ims - right += sims - else: - seen[(oms, sl)] = ims - - def map_out(outer_i, inner_o, outer_o, seen): - # Return the outer input corresponding to an - # (outer input, inner output) pair. If we see that pair for the first - # time, return the provided outer output. If an equivalent pair had - # already been seen, return that one instead. - # Note that we need to check that the outer input match as well, - # because they could have different sizes, and the corresponding - # outer outputs cannot be merged in that case. - for s_outer_i, s_inner_o, s_outer_o in seen: - if ( - equal_computations([inner_o], [s_inner_o], left, right) - and outer_i == s_outer_i - ): - return s_outer_o - seen.append((outer_i, inner_o, outer_o)) - return outer_o - - seen = [] - - assert len(na.outer_in_nit_sot) == len(na.inner_out_nit_sot) - assert len(na.inner_out_nit_sot) == len(na.outer_out_nit_sot) - na.outer_out_nit_sot = [ - map_out(outer_i, inner_o, outer_o, seen) - for outer_i, inner_o, outer_o in zip( - na.outer_in_nit_sot, na.inner_out_nit_sot, na.outer_out_nit_sot - ) - ] - - seen = [] - assert len(na.outer_in_sit_sot) == len(na.inner_out_sit_sot) - assert len(na.inner_out_sit_sot) == len(na.outer_out_sit_sot) - na.outer_out_sit_sot = [ - map_out(outer_i, inner_o, outer_o, seen) - for outer_i, inner_o, outer_o in zip( - na.outer_in_sit_sot, na.inner_out_sit_sot, na.outer_out_sit_sot - ) - ] - - seen = [] - assert len(na.outer_in_mit_sot) == len(na.inner_out_mit_sot) - assert len(na.inner_out_mit_sot) == len(na.outer_out_mit_sot) - na.outer_out_mit_sot = [ - map_out(outer_i, inner_o, outer_o, seen) - for outer_i, inner_o, outer_o in zip( - na.outer_in_mit_sot, na.inner_out_mit_sot, na.outer_out_mit_sot - ) - ] - - seen = [] - new_outer_out_mit_mot = [] - assert len(na.outer_in_mit_mot) == len(na.inner_out_mit_mot) - assert len(na.inner_out_mit_mot) == len(na.outer_out_mit_mot) - assert len(na.outer_out_mit_mot) == len(na.mit_mot_out_slices) - for outer_imm, inner_omm, outer_omm, osl in zip( - na.outer_in_mit_mot, - na.inner_out_mit_mot, - na.outer_out_mit_mot, - na.mit_mot_out_slices, - ): - for s_outer_imm, s_inner_omm, s_outer_omm, sosl in seen: - if ( - osl == sosl - and equal_computations(inner_omm, s_inner_omm, left, right) - and outer_imm == s_outer_imm - ): - - new_outer_out_mit_mot.append(s_outer_omm) - break - else: - seen.append((outer_imm, inner_omm, outer_omm, osl)) - new_outer_out_mit_mot.append(outer_omm) - na.outer_out_mit_mot = new_outer_out_mit_mot - if remove: - return dict([("remove", remove)] + list(zip(node.outputs, na.outer_outputs))) - return na.outer_outputs - - -@local_optimizer([Scan]) -def push_out_dot1_scan(fgraph, node): - r""" - This is another optimization that attempts to detect certain patterns of - computation in a `Scan` `Op`'s inner function and move this computation to the - outer graph. - """ - if not isinstance(node.op, Scan): - return False - - # Replace pattern of the form - # x[t] = x[t-1] + dot(seq[t], value) - # with Sequence.reshape((-1, seq.shape[2])) \dot Value - # When seq[t] is a vector/matrix and `value` is a matrix - # Note that this works when only you need X[-1] in the end - # and assumes dimshuffle are applied to vectors before calling dot - op = node.op - sitsot_ins = op.inner_sitsot(op.inner_inputs) - sitsot_outs = op.inner_sitsot_outs(op.inner_outputs) - outer_sitsot = op.outer_sitsot_outs(node.outputs) - seqs = op.inner_seqs(op.inner_inputs) - for inp, out, outer_out in zip(sitsot_ins, sitsot_outs, outer_sitsot): - - if ( - out.owner - and isinstance(out.owner.op, Elemwise) - and isinstance(out.owner.op.scalar_op, aes.Add) - and inp in out.owner.inputs - and len(fgraph.clients[outer_out]) == 1 - and not isinstance(fgraph.clients[outer_out][0][0], str) - and isinstance(fgraph.clients[outer_out][0][0].op, Subtensor) - and fgraph.clients[outer_out][0][0].op.idx_list == (-1,) - ): - - x = out.owner.inputs[0] - if x == inp: - x = out.owner.inputs[1] - # We need to check if x is the result of an outer product - if ( - x.owner - and isinstance(x.owner.op, Dot) - and x.owner.inputs[0].ndim == 2 - and x.owner.inputs[1].ndim == 2 - ): - - # We need to check if any of the inputs are a sequence - inp1 = x.owner.inputs[0] - inp2 = x.owner.inputs[1] - - if inp1 in seqs or inp2 in seqs: - new_scan_out = inp1 - - if inp1 in seqs: - new_scan_out = inp2 - idx = sitsot_outs.index(out) - # We've found our pattern and need to construct a new - # scan node to replace this one. For this we need to - # replace the sit_sot output with a nit_sot output - - # First let us split all arguments according to their - # corresponding categories - - inner_seqs = op.inner_seqs(op.inner_inputs) - outer_seqs = op.outer_seqs(node.inputs) - inner_mitmot = op.inner_mitmot(op.inner_inputs) - outer_mitmot = op.outer_mitmot(node.inputs) - inner_mitmot_outs = op.inner_mitmot_outs(op.inner_outputs) - inner_mitsot = op.inner_mitsot(op.inner_inputs) - outer_mitsot = op.outer_mitsot(node.inputs) - inner_mitsot_outs = op.inner_mitsot_outs(op.inner_outputs) - inner_sitsot = op.inner_sitsot(op.inner_inputs) - outer_sitsot = op.outer_sitsot(node.inputs) - inner_sitsot_outs = op.inner_sitsot_outs(op.inner_outputs) - outer_nitsot = op.outer_nitsot(node.inputs) - inner_nitsot_outs = op.inner_nitsot_outs(op.inner_outputs) - inner_shared = op.inner_shared(op.inner_inputs) - outer_shared = op.outer_shared(node.inputs) - inner_shared_outs = op.inner_shared_outs(op.inner_outputs) - inner_non_seqs = op.inner_non_seqs(op.inner_inputs) - outer_non_seqs = op.outer_non_seqs(node.inputs) - - new_info = dataclasses.replace( - op.info, - sit_sot_in_slices=op.info.sit_sot_in_slices[:idx] - + op.info.sit_sot_in_slices[idx + 1 :], - n_nit_sot=op.info.n_nit_sot + 1, - ) - inner_sitsot = inner_sitsot[:idx] + inner_sitsot[idx + 1 :] - outer_sitsot = outer_sitsot[:idx] + outer_sitsot[idx + 1 :] - inner_sitsot_outs = ( - inner_sitsot_outs[:idx] + inner_sitsot_outs[idx + 1 :] - ) - # add n_steps as the length - inner_nitsot_outs.append(new_scan_out) - - _new_inner_inps = ( - inner_seqs - + inner_mitmot - + inner_mitsot - + inner_sitsot - + inner_shared - + inner_non_seqs - ) - _new_inner_outs = ( - inner_mitmot_outs - + inner_mitsot_outs - + inner_sitsot_outs - + inner_nitsot_outs - + inner_shared_outs - ) - new_inner_inps, new_inner_outs = reconstruct_graph( - _new_inner_inps, _new_inner_outs - ) - new_op = Scan( - new_inner_inps, - new_inner_outs, - new_info, - mode=op.mode, - profile=op.profile, - truncate_gradient=op.truncate_gradient, - # TODO: This seems questionable - name=op.name, - allow_gc=op.allow_gc, - ) - _scan_inputs = ( - [node.inputs[0]] - + outer_seqs - + outer_mitmot - + outer_mitsot - + outer_sitsot - + outer_shared - + outer_nitsot - + [node.inputs[0]] - + outer_non_seqs - ) - - new_outs = new_op(*_scan_inputs) - if not isinstance(new_outs, (list, tuple)): - new_outs = [new_outs] - - # We need now to pair correctly the new outputs - # with the old ones - - outer_nitsot_outs = new_op.outer_nitsot_outs(new_outs) - - _val = outer_nitsot_outs[-1] - outer_nitsot_outs = outer_nitsot_outs[:-1] - if inp1 in seqs: - _out_seq = op.outer_seqs(node.inputs)[seqs.index(inp1)] - # We need to clip the seq to the number of steps - _out_seq = _out_seq[: node.inputs[0]] - sh0 = _out_seq.shape[0] - sh1 = _out_seq.shape[1] - sh2 = _out_seq.shape[2] - out_seq = _out_seq.dimshuffle(1, 0, 2) - out_seq = out_seq.reshape((sh1, sh0 * sh2)) - sh0 = _val.shape[0] - sh1 = _val.shape[1] - sh2 = _val.shape[2] - - val = _val.reshape((sh0 * sh1, sh2)) - new_out = dot(out_seq, val) - else: - _out_seq = op.outer_seqs(node.inputs)[seqs.index(inp2)] - out_seq = _out_seq.reshape( - ( - _out_seq.shape[0] * _out_seq.shape[1], - _out_seq.shape[2], - ) - ) - - val = _val.dimshuffle(1, 0, 2).reshape( - (_val.shape[1], _val.shape[0] * _val.shape[2]) - ) - new_out = dot(val, out_seq) - - pos = node.outputs.index(outer_out) - old_new = list(zip(node.outputs[:pos], new_outs[:pos])) - old = fgraph.clients[node.outputs[pos]][0][0].outputs[0] - old_new.append((old, new_out)) - old_new += list(zip(node.outputs[pos + 1 :], new_outs[pos:])) - replacements = dict(old_new) - replacements["remove"] = [node] - return replacements - - return False - - -# I've added an equilibrium because later scan optimization in the sequence -# can make it such that earlier optimizations should apply. However, in -# general I do not expect the sequence to run more then once -scan_eqopt1 = EquilibriumDB() -scan_seqopt1 = SequenceDB() -scan_eqopt2 = EquilibriumDB() - -# scan_eqopt1 before ShapeOpt at 0.1 -# This is needed to don't have ShapeFeature trac old Scan that we -# don't want to reintroduce. -optdb.register("scan_eqopt1", scan_eqopt1, "fast_run", "scan", position=0.05) -# We run before blas opt at 1.7 and specialize 2.0 -# but after stabilize at 1.5. Should we put it before stabilize? -optdb.register("scan_eqopt2", scan_eqopt2, "fast_run", "scan", position=1.6) -# ScanSaveMem should execute only once per node. -optdb.register( - "scan_save_mem", - in2out(save_mem_new_scan, ignore_newtrees=True), - "fast_run", - "scan", - position=1.61, -) -optdb.register( - "scan_make_inplace", - ScanInplaceOptimizer(), - "fast_run", - "inplace", - "scan", - position=75, -) - -scan_eqopt1.register("all_pushout_opt", scan_seqopt1, "fast_run", "scan", position=1) - - -scan_seqopt1.register( - "scan_remove_constants_and_unused_inputs0", - in2out(remove_constants_and_unused_inputs_scan, ignore_newtrees=True), - "remove_constants_and_unused_inputs_scan", - "fast_run", - "scan", - position=1, -) - - -scan_seqopt1.register( - "scan_pushout_nonseqs_ops", - in2out(push_out_non_seq_scan, ignore_newtrees=True), - "fast_run", - "scan", - "scan_pushout", - position=2, -) - - -scan_seqopt1.register( - "scan_pushout_seqs_ops", - in2out(push_out_seq_scan, ignore_newtrees=True), - "fast_run", - "scan", - "scan_pushout", - position=3, -) - - -scan_seqopt1.register( - "scan_pushout_dot1", - in2out(push_out_dot1_scan, ignore_newtrees=True), - "fast_run", - "more_mem", - "scan", - "scan_pushout", - position=4, -) - - -scan_seqopt1.register( - "scan_pushout_add", - # TODO: Perhaps this should be an `EquilibriumOptimizer`? - in2out(push_out_add_scan, ignore_newtrees=False), - "fast_run", - "more_mem", - "scan", - "scan_pushout", - position=5, -) - - -scan_eqopt2.register( - "constant_folding_for_scan2", - in2out(basic_opt.constant_folding, ignore_newtrees=True), - "fast_run", - "scan", - position=1, -) - - -scan_eqopt2.register( - "scan_remove_constants_and_unused_inputs1", - in2out(remove_constants_and_unused_inputs_scan, ignore_newtrees=True), - "remove_constants_and_unused_inputs_scan", - "fast_run", - "scan", - position=2, -) - - -# after const merge but before stabilize so that we can have identity -# for equivalent nodes but we still have the chance to hoist stuff out -# of the scan later. -scan_eqopt2.register("scan_merge", ScanMerge(), "fast_run", "scan", position=4) - -# After Merge optimization -scan_eqopt2.register( - "scan_remove_constants_and_unused_inputs2", - in2out(remove_constants_and_unused_inputs_scan, ignore_newtrees=True), - "remove_constants_and_unused_inputs_scan", - "fast_run", - "scan", - position=5, -) - -scan_eqopt2.register( - "scan_merge_inouts", - in2out(scan_merge_inouts, ignore_newtrees=True), - "fast_run", - "scan", - position=6, -) - -# After everything else -scan_eqopt2.register( - "scan_remove_constants_and_unused_inputs3", - in2out(remove_constants_and_unused_inputs_scan, ignore_newtrees=True), - "remove_constants_and_unused_inputs_scan", - "fast_run", - "scan", - position=8, -) +from aesara.scan.rewriting import * # noqa: F401 E402 F403 diff --git a/aesara/scan/rewriting.py b/aesara/scan/rewriting.py new file mode 100644 index 0000000000..f63db7b74c --- /dev/null +++ b/aesara/scan/rewriting.py @@ -0,0 +1,2479 @@ +"""This module provides optimizations for the `Scan` `Op`.""" + +import copy +import dataclasses +from itertools import chain +from sys import maxsize +from typing import Dict, List, Optional, Tuple, cast + +import numpy as np + +import aesara +from aesara import scalar as aes +from aesara import tensor as at +from aesara.compile import optdb +from aesara.compile.function.types import deep_copy_op +from aesara.configdefaults import config +from aesara.graph.basic import ( + Apply, + Constant, + Variable, + clone_replace, + equal_computations, + graph_inputs, + io_toposort, + is_in_ancestors, +) +from aesara.graph.destroyhandler import DestroyHandler +from aesara.graph.features import ReplaceValidate +from aesara.graph.fg import FunctionGraph +from aesara.graph.op import compute_test_value +from aesara.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter +from aesara.graph.rewriting.db import EquilibriumDB, SequenceDB +from aesara.graph.type import HasShape +from aesara.graph.utils import InconsistencyError +from aesara.scan.op import Scan, ScanInfo +from aesara.scan.utils import ( + ScanArgs, + compress_outs, + expand_empty, + reconstruct_graph, + safe_new, + scan_can_remove_outs, +) +from aesara.tensor.basic import Alloc, AllocEmpty, get_scalar_constant_value +from aesara.tensor.elemwise import DimShuffle, Elemwise +from aesara.tensor.exceptions import NotScalarConstantError +from aesara.tensor.math import Dot, dot, maximum, minimum +from aesara.tensor.rewriting.basic import constant_folding, local_useless_switch +from aesara.tensor.rewriting.elemwise import local_upcast_elemwise_constant_inputs +from aesara.tensor.rewriting.math import local_abs_merge, local_mul_switch_sink +from aesara.tensor.shape import shape +from aesara.tensor.subtensor import ( + IncSubtensor, + Subtensor, + get_canonical_form_slice, + get_idx_list, + get_slice_elements, + set_subtensor, +) +from aesara.tensor.var import TensorConstant, get_unique_value + + +list_opt_slice = [ + local_abs_merge, + local_mul_switch_sink, + local_upcast_elemwise_constant_inputs, + local_useless_switch, + constant_folding, +] + + +@node_rewriter([Scan]) +def remove_constants_and_unused_inputs_scan(fgraph, node): + """Move constants into the inner graph, and remove unused inputs. + + Constants that are in the outer graph are represented by a free symbolic + variable in the inner graph. If we move them into the inner graph, + constant-folding can happen in the inner graph. + This is applied only on sequences and non-sequences, + not on initial states. + + """ + if not isinstance(node.op, Scan): + return False + op = node.op + op_info = op.info + # We only need to take care of sequences and other arguments + st = op_info.n_seqs + st += int( + sum(len(x) for x in chain(op_info.mit_mot_in_slices, op_info.mit_sot_in_slices)) + ) + st += op_info.n_sit_sot + st += op_info.n_shared_outs + + op_ins = op.inner_inputs + op_outs = op.inner_outputs + + # Corresponds to the initial states, which should stay untouched. + # We put those variables aside, and put them back at the end. + out_stuff_inner = op_ins[op_info.n_seqs : st] + + non_seqs = op_ins[st:] + st = ( + op_info.n_seqs + + op_info.n_mit_mot + + op_info.n_mit_sot + + op_info.n_sit_sot + + op_info.n_nit_sot + + op_info.n_shared_outs + + 1 + ) + outer_non_seqs = node.inputs[st:] + out_stuff_outer = node.inputs[1 + op_info.n_seqs : st] + + # To replace constants in the outer graph by clones in the inner graph + givens = {} + # All the inputs of the inner graph of the new scan + nw_inner = [] + # Same for the outer graph, initialized w/ number of steps + nw_outer = [node.inputs[0]] + + all_ins = list(graph_inputs(op_outs)) + for idx in range(op_info.n_seqs): + node_inp = node.inputs[idx + 1] + if ( + isinstance(node_inp, TensorConstant) + and get_unique_value(node_inp) is not None + ): + try: + # This works if input is a constant that has all entries + # equal + givens[op_ins[idx]] = node_inp[0] + except TypeError: + pass + elif op_ins[idx] in all_ins: + # Check for identical other sequence + identical_seqs = [ + x for x in nw_outer if equal_computations([x], [node_inp]) + ] + if identical_seqs: + index = node.inputs.index(identical_seqs[0]) - 1 + givens[op_ins[idx]] = op_ins[index] + else: + nw_inner.append(op_ins[idx]) + nw_outer.append(node_inp) + + nw_n_seqs = len(nw_inner) + # Add outputs stuff + nw_inner += out_stuff_inner + nw_outer += out_stuff_outer + + # Look through non sequences + nw_inner_nonseq = [] + nw_outer_nonseq = [] + for idx, (nw_in, nw_out) in enumerate(zip(non_seqs, outer_non_seqs)): + if isinstance(nw_out, Constant): + givens[nw_in] = nw_out + elif nw_in in all_ins: + # Indices of elements of nw_outer_nonseq that are equivalent + # to nw_out. + identical_nonseq_idx = [ + i + for (i, x) in enumerate(nw_outer_nonseq) + if equal_computations([x], [nw_out]) + ] + if identical_nonseq_idx: + givens[nw_in] = nw_inner_nonseq[identical_nonseq_idx[0]] + else: + nw_inner_nonseq.append(nw_in) + nw_outer_nonseq.append(nw_out) + + nw_inner.extend(nw_inner_nonseq) + nw_outer.extend(nw_outer_nonseq) + + if len(nw_inner) != len(op_ins): + op_outs = clone_replace(op_outs, replace=givens) + nw_info = dataclasses.replace( + op_info, n_seqs=nw_n_seqs, n_non_seqs=len(nw_inner_nonseq) + ) + nwScan = Scan( + nw_inner, + op_outs, + nw_info, + mode=op.mode, + profile=op.profile, + truncate_gradient=op.truncate_gradient, + # TODO: This seems questionable + name=op.name, + allow_gc=op.allow_gc, + ) + nw_outs = nwScan(*nw_outer, return_list=True) + return dict([("remove", [node])] + list(zip(node.outputs, nw_outs))) + else: + return False + + +@node_rewriter([Scan]) +def push_out_non_seq_scan(fgraph, node): + r"""Push out the variables inside the `Scan` that depend only on non-sequences. + + This optimizations pushes, out of `Scan`'s inner function and into the outer + function, computation that depends only on non-sequence inputs. Such + computation ends up being done every iteration on the same values so moving + it to the outer function to be executed only once, before the `Scan` `Op`, + reduces the amount of computation that needs to be performed. + """ + if not isinstance(node.op, Scan): + return False + + node_inputs, node_outputs = node.op.inner_inputs, node.op.inner_outputs + + local_fgraph_topo = io_toposort(node_inputs, node_outputs) + local_fgraph_outs_set = set(node_outputs) + local_fgraph_outs_map = {v: k for k, v in enumerate(node_outputs)} + + to_remove_set = set() + to_replace_set = set() + to_replace_map = {} + + def add_to_replace(y): + to_replace_set.add(y) + to_replace_map[y] = add_to_replace.n + add_to_replace.n += 1 + + add_to_replace.n = 0 + + # The variables that will replace the variables pushed-out of the + # inner-graph + replace_with_in = [] + # The variables that have been pushed-out of the graph + replace_with_out = [] + + op = node.op + # Construct the list of non_sequences to simplify a few things + inner_non_seqs = op.inner_non_seqs(node_inputs) + inner_non_seqs_set = set(inner_non_seqs) + inner_non_seqs_map = {v: k for k, v in enumerate(inner_non_seqs)} + + outer_non_seqs = op.outer_non_seqs(node.inputs) + + inner_seqs = op.inner_seqs(node_inputs) + outer_seqs = op.outer_seqs(node.inputs) + + assert len(inner_non_seqs) == len(outer_non_seqs) + assert len(inner_seqs) == len(outer_seqs) + + for nd in local_fgraph_topo: + if ( # we haven't already looked at this node + nd not in to_remove_set + and all( + ( + (x in inner_non_seqs_set) + or (x.owner in to_remove_set) + or isinstance(x, Constant) + ) + for x in nd.inputs + ) + # We can (supposedly) do this because the assumption is that a + # `ViewOp` or `DeepCopyOp` will be just at the end of the + # function and not somewhere in the middle + and not isinstance(nd.op, aesara.compile.ViewOp) + and not isinstance(nd.op, aesara.compile.DeepCopyOp) + ): + # We have a candidate node to remove from the inner-graph + + # Step 1. Reconstruct the node using the relevant outer-inputs. + # + # More specifically, the node's current inputs are either + # a) inner-graph input place-holders for non-sequences, + # b) the outputs of other nodes being pushed out of the inner-graph, + # c) or constants. + to_remove_set.add(nd) + new_inputs = [] + for old_input in nd.inputs: + if old_input in inner_non_seqs_set: + # This is case a), so we want to use the corresponding + # outer-graph input as the input to our new pushed-out node + _idx = inner_non_seqs_map[old_input] + new_input = outer_non_seqs[_idx] + elif old_input in to_replace_set: + # This is case b), so we want to use the new pushed-out node + # as the input to this new pushed-out node + new_input = replace_with_out[to_replace_map[old_input]] + else: + assert isinstance(old_input, Constant) + new_input = old_input + + new_input = old_input.type.filter_variable(new_input) + new_inputs.append(new_input) + + pushed_out_node = nd.op.make_node(*new_inputs) + + if config.compute_test_value != "off": + compute_test_value(pushed_out_node) + + # Step 2. Create variables to replace the old outputs of the node + # that we're pushing out of the inner-graph + for idx, y in enumerate(nd.outputs): + y_place_holder = y.clone() + # y_place_holder = safe_new(y, "_replace") + add_to_replace(y) + replace_with_in.append(y_place_holder) + assert isinstance(y, type(pushed_out_node.outputs[idx])) + replace_with_out.append(pushed_out_node.outputs[idx]) + + # We need to check all candidate replacements and choose those that + # make sense for us + # Step 1. which elements of `to_replace` are used by remaining + # components of the inner function + clean_to_replace = [] + clean_replace_with_in = [] + clean_replace_with_out = [] + existent_nodes = [nd for nd in local_fgraph_topo if nd not in to_remove_set] + existent_nodes_set = set(existent_nodes) + + to_keep_set = set() + for nd in existent_nodes: + to_keep_set.update(nd.inputs) + + for out, idx in to_replace_map.items(): + if ( # If types are different, conversion Op will be inserted, + # and it may trigger an infinite loop. + out.type.is_super(replace_with_in[idx].type) + and out in to_keep_set + and out.owner not in existent_nodes_set + ): + clean_to_replace.append(out) + clean_replace_with_in.append(replace_with_in[idx]) + clean_replace_with_out.append(replace_with_out[idx]) + + if len(clean_to_replace) > 0: + # We can finally put an end to all this madness + givens = {} + nw_outer = [] + nw_inner = [] + for to_repl, repl_in, repl_out in zip( + clean_to_replace, clean_replace_with_in, clean_replace_with_out + ): + if isinstance(repl_out, Constant): + repl_in = repl_out + else: + nw_inner.append(repl_in) + nw_outer.append(repl_out) + givens[to_repl] = repl_in + + op_outs = clone_replace(node_outputs, replace=givens) + op_ins = node_inputs + nw_inner + + new_info = dataclasses.replace( + op.info, n_non_seqs=op.info.n_non_seqs + len(nw_outer) + ) + + # Reconstruct node + nwScan = Scan( + op_ins, + op_outs, + new_info, + mode=op.mode, + profile=op.profile, + truncate_gradient=op.truncate_gradient, + # TODO: This seems questionable + name=op.name, + allow_gc=op.allow_gc, + ) + + # Do not call make_node for test_value + nw_node = nwScan(*(node.inputs + nw_outer), return_list=True)[0].owner + + replacements = dict(zip(node.outputs, nw_node.outputs)) + replacements["remove"] = [node] + return replacements + elif not to_keep_set: + # Nothing in the inner graph should be kept + replace_with = {} + for out, idx in to_replace_map.items(): + if out in local_fgraph_outs_set: + x = node.outputs[local_fgraph_outs_map[out]] + y = replace_with_out[idx] + y_shape = [shp for shp in y.shape] + replace_with[x] = at.alloc(y, node.inputs[0], *y_shape) + + # We need to add one extra dimension to the outputs + # because the scan op expects for a tensor3, to which an + # subtensor is applied that takes only the last element + if replace_with: + if len(node.outputs) == len(replace_with): + # Every output of the node has a replacement, the Scan + # node can be removed from the graph + replace_with["remove"] = [node] + return replace_with + else: + # The node has some outputs for which no replacement has + # been established. This can occur for outputs that are + # not produced by apply nodes (since the optimizations + # only visits apply nodes) such as constants or inputs + # passed directly as outputs. The replacements can be + # performed but the Scan node can't be removed at this + # point. + return replace_with + + else: + return False + + +@node_rewriter([Scan]) +def push_out_seq_scan(fgraph, node): + r"""Push out the variables inside the `Scan` that depend only on constants and sequences. + + This optimization resembles `push_out_non_seq_scan` but it tries to push--out of + the inner function--the computation that only relies on sequence and + non-sequence inputs. The idea behind this optimization is that, when it is + possible to do so, it is generally more computationally efficient to perform + a single operation on a large tensor rather then perform that same operation + many times on many smaller tensors. In many cases, this optimization can + increase memory usage but, in some specific cases, it can also decrease it. + """ + if not isinstance(node.op, Scan): + return False + + node_inputs, node_outputs = node.op.inner_inputs, node.op.inner_outputs + + local_fgraph_topo = io_toposort(node_inputs, node_outputs) + local_fgraph_outs_set = set(node_outputs) + local_fgraph_outs_map = {v: k for k, v in enumerate(node_outputs)} + + to_remove_set = set() + to_replace_set = set() + to_replace_map = {} + + def add_to_replace(y): + to_replace_set.add(y) + to_replace_map[y] = add_to_replace.n + add_to_replace.n += 1 + + add_to_replace.n = 0 + + replace_with_in = [] + replace_with_out = [] + + op = node.op + # Construct the list of non_sequences to simplify a few things + inner_non_seqs = op.inner_non_seqs(node_inputs) + inner_non_seqs_set = set(inner_non_seqs) + inner_non_seqs_map = {v: k for k, v in enumerate(inner_non_seqs)} + + outer_non_seqs = op.outer_non_seqs(node.inputs) + inner_seqs = op.inner_seqs(node_inputs) + inner_seqs_set = set(inner_seqs) + inner_seqs_map = {v: k for k, v in enumerate(inner_seqs)} + + outer_seqs = op.outer_seqs(node.inputs) + assert len(inner_non_seqs) == len(outer_non_seqs) + assert len(inner_seqs) == len(outer_seqs) + + for nd in local_fgraph_topo: + if ( + nd not in to_remove_set + and all( + (x in inner_non_seqs_set) + or (x.owner in to_remove_set) + or isinstance(x, Constant) + or (x in inner_seqs_set) + for x in nd.inputs + ) + and isinstance(nd.op, Elemwise) + ): + + outside_ins = [] + depends_on_seqs = False + + for x in nd.inputs: + if x in inner_non_seqs_set: + _idx = inner_non_seqs_map[x] + new_input = outer_non_seqs[_idx] + elif x in inner_seqs_set: + new_input = outer_seqs[inner_seqs_map[x]] + depends_on_seqs = True + elif x in to_replace_set: + new_input = replace_with_out[to_replace_map[x]] + depends_on_seqs = True + else: + assert isinstance(x, Constant) + new_input = x + + outside_ins.append(new_input) + + if not depends_on_seqs: + # Removing this node from the inner graph of scan + # should be handled by the PushOutNonSeqScan + # optimization. The current optimization only tries + # to pull sequence-dependant computation out of + # scan. + continue + + to_remove_set.add(nd) + + # Do not call make_node for test_value + nw_outer_node = nd.op.make_node(*outside_ins) + + if config.compute_test_value != "off": + compute_test_value(nw_outer_node) + + # Step 2. Create variables for replacements + for idx, y in enumerate(nd.outputs): + y_place_holder = safe_new(y, "_replace") + add_to_replace(y) + replace_with_in.append(y_place_holder) + replace_with_out.append(nw_outer_node.outputs[idx]) + + elif ( + nd not in to_remove_set + and isinstance(nd.op, DimShuffle) + and (nd.inputs[0] in inner_seqs_set or nd.inputs[0].owner in to_remove_set) + ): + + to_remove_set.add(nd) + x = nd.inputs[0] + if x in inner_seqs_set: + outside_ins = outer_seqs[inner_seqs_map[x]] + elif x in to_replace_set: + outside_ins = replace_with_out[to_replace_map[x]] + new_ord = (0,) + for old_ord in nd.op.new_order: + if old_ord == "x": + new_ord += (old_ord,) + else: + new_ord += (old_ord + 1,) + new_outer = outside_ins.dimshuffle(new_ord) + y = nd.outputs[0] + y_place_holder = safe_new(y, "_replace") + add_to_replace(y) + replace_with_in.append(y_place_holder) + replace_with_out.append(new_outer) + + if hasattr(new_outer.tag, "test_value"): + new_sh = new_outer.tag.test_value.shape + ref_sh = (outside_ins.tag.test_value.shape[0],) + ref_sh += nd.outputs[0].tag.test_value.shape + assert new_sh == ref_sh + + # We need to check all candidate replacements and choose those that + # make sense for us + # Step 1. which elements of `to_replace` are used by remaining + # components of the inner function + clean_to_replace = [] + clean_replace_with_in = [] + clean_replace_with_out = [] + + existent_nodes = [nd for nd in local_fgraph_topo if nd not in to_remove_set] + existent_nodes_set = set(existent_nodes) + + to_keep_set = set() + for nd in existent_nodes: + to_keep_set.update(nd.inputs) + + for out, idx in to_replace_map.items(): + if ( + out in to_keep_set + and out.owner not in existent_nodes_set + and + # If types are different, conversion Op will be inserted, + # and it may trigger an infinite loop. + out.type.is_super(replace_with_in[idx].type) + ): + + clean_to_replace.append(out) + clean_replace_with_in.append(replace_with_in[idx]) + clean_replace_with_out.append(replace_with_out[idx]) + + if len(clean_to_replace) > 0: + # We can finally put an end to all this madness + givens = {} + nw_outer = [] + nw_inner = [] + for to_repl, repl_in, repl_out in zip( + clean_to_replace, clean_replace_with_in, clean_replace_with_out + ): + if isinstance(repl_out, Constant): + repl_in = repl_out + else: + nw_inner.append(repl_in) + nw_outer.append(repl_out) + + givens[to_repl] = repl_in + + op_outs = clone_replace(node_outputs, replace=givens) + op_ins = nw_inner + node_inputs + + # Reconstruct node + nw_info = dataclasses.replace(op.info, n_seqs=op.info.n_seqs + len(nw_inner)) + nwScan = Scan( + op_ins, + op_outs, + nw_info, + mode=op.mode, + profile=op.profile, + truncate_gradient=op.truncate_gradient, + # TODO: This seems questionable + name=op.name, + allow_gc=op.allow_gc, + ) + # Do not call make_node for test_value + nw_node = nwScan( + *(node.inputs[:1] + nw_outer + node.inputs[1:]), + return_list=True, + )[0].owner + + replacements = dict(zip(node.outputs, nw_node.outputs)) + replacements["remove"] = [node] + return replacements + + elif not to_keep_set and not op.info.as_while and not op.outer_mitmot(node.inputs): + # Nothing in the inner graph should be kept + replace_with = {} + for out, idx in to_replace_map.items(): + if out in local_fgraph_outs_set: + x = node.outputs[local_fgraph_outs_map[out]] + _y = replace_with_out[idx] + ls = node_outputs + if out in op.inner_mitsot_outs(ls): + odx = op.inner_mitsot_outs(ls).index(out) + inp = op.outer_mitsot(node.inputs)[odx] + st = abs(np.min(op.info.mit_sot_in_slices)) + y = set_subtensor(inp[st:], _y) + elif out in op.inner_sitsot_outs(ls): + odx = op.inner_sitsot_outs(ls).index(out) + inp = op.outer_sitsot(node.inputs)[odx] + y = set_subtensor(inp[1:], _y) + elif out in op.inner_nitsot_outs(ls): + y = _y + else: + y = _y[-1] + replace_with[x] = y + + # We need to add one extra dimension to the outputs + if replace_with and len(replace_with) == len(node.outputs): + replacements = dict(replace_with.items()) + replacements["remove"] = [node] + return replacements + else: + return False + + +def inner_sitsot_only_last_step_used( + fgraph: FunctionGraph, var: Variable, scan_args: ScanArgs +) -> bool: + """ + Given a inner nit-sot output of `Scan`, return ``True`` iff the outer + nit-sot output has only one client and that client is a `Subtensor` + instance that takes only the last step (last element along the first + axis). + """ + idx = scan_args.inner_out_sit_sot.index(var) + outer_var = scan_args.outer_out_sit_sot[idx] + + if len(fgraph.clients[outer_var]) == 1: + client = fgraph.clients[outer_var][0][0] + if isinstance(client, Apply) and isinstance(client.op, Subtensor): + lst = get_idx_list(client.inputs, client.op.idx_list) + if len(lst) == 1 and at.extract_constant(lst[0]) == -1: + return True + + return False + + +def get_outer_ndim(var: Variable, scan_args: ScanArgs) -> int: + """Determine the number of dimension a variable would have if it was pushed out of a `Scan`.""" + assert isinstance(var.type, HasShape) + + if var in scan_args.inner_in_non_seqs or isinstance(var, Constant): + outer_ndim = var.type.ndim + else: + outer_ndim = var.type.ndim + 1 + + return outer_ndim + + +def push_out_inner_vars( + fgraph: FunctionGraph, + inner_vars: List[Variable], + old_scan_node: Apply, + old_scan_args: ScanArgs, +) -> Tuple[List[Variable], ScanArgs, Dict[Variable, Variable]]: + + tmp_outer_vars: List[Optional[Variable]] = [] + new_scan_node = old_scan_node + new_scan_args = old_scan_args + replacements: Dict[Variable, Variable] = {} + + # For the inner_vars that already exist in the outer graph, + # simply obtain a reference to them + for idx in range(len(inner_vars)): + + var = inner_vars[idx] + + new_outer_var: Optional[Variable] = None + + if var in old_scan_args.inner_in_seqs: + idx_seq = old_scan_args.inner_in_seqs.index(var) + new_outer_var = old_scan_args.outer_in_seqs[idx_seq] + + elif var in old_scan_args.inner_in_non_seqs: + idx_non_seq = old_scan_args.inner_in_non_seqs.index(var) + new_outer_var = old_scan_args.outer_in_non_seqs[idx_non_seq] + + elif isinstance(var, Constant): + new_outer_var = var + + elif var in old_scan_args.inner_out_nit_sot: + idx_nitsot = old_scan_args.inner_out_nit_sot.index(var) + new_outer_var = old_scan_args.outer_out_nit_sot[idx_nitsot] + + tmp_outer_vars.append(new_outer_var) + + # For the inner_vars that don't already exist in the outer graph, add + # them as new nitsot outputs to the scan node. + idx_add_as_nitsots = [i for i, v in enumerate(tmp_outer_vars) if v is None] + add_as_nitsots = [inner_vars[idx] for idx in idx_add_as_nitsots] + + new_outs: List[Variable] = [] + + if len(add_as_nitsots) > 0: + + new_scan_node, replacements = add_nitsot_outputs( + fgraph, old_scan_node, old_scan_args, add_as_nitsots + ) + + assert isinstance(new_scan_node.op, Scan) + + new_scan_args = ScanArgs( + new_scan_node.inputs, + new_scan_node.outputs, + new_scan_node.op.inner_inputs, + new_scan_node.op.inner_outputs, + new_scan_node.op.info, + ) + + new_outs = new_scan_args.outer_out_nit_sot[-len(add_as_nitsots) :] + + outer_vars: List[Variable] = [] + + for i, v in enumerate(tmp_outer_vars): + if i in idx_add_as_nitsots: + outer_vars.append(new_outs.pop(0)) + else: + assert v is not None + outer_vars.append(v) + + return outer_vars, new_scan_args, replacements + + +def add_nitsot_outputs( + fgraph: FunctionGraph, + old_scan_node: Apply, + old_scan_args: ScanArgs, + new_outputs_inner, +) -> Tuple[Apply, Dict[Variable, Variable]]: + + assert isinstance(old_scan_node.op, Scan) + + nb_new_outs = len(new_outputs_inner) + + # Create the initial values for the new nitsot outputs + # (the initial value is the nb of steps to store. For a nistot, + # it should be the number of steps performed by scan) + new_nitsots_initial_value = [old_scan_node.inputs[0] for i in range(nb_new_outs)] + + # Create the `ScanArgs` corresponding to the new `Scan` `Op` to create + new_scan_args = copy.copy(old_scan_args) + new_scan_args.inner_out_nit_sot.extend(new_outputs_inner) + new_scan_args.outer_in_nit_sot.extend(new_nitsots_initial_value) + + assert isinstance(old_scan_node.op, Scan) + + # Create the `Scan` `Op` from the `ScanArgs` + new_scan_op = Scan( + new_scan_args.inner_inputs, + new_scan_args.inner_outputs, + new_scan_args.info, + mode=old_scan_node.op.mode, + profile=old_scan_node.op.profile, + truncate_gradient=old_scan_node.op.truncate_gradient, + # TODO: This seems questionable + name=old_scan_node.op.name, + allow_gc=old_scan_node.op.allow_gc, + ) + + # Create the Apply node for the scan op + new_scan_outs = new_scan_op(*new_scan_args.outer_inputs, return_list=True) + assert isinstance(new_scan_outs, list) + new_scan_node = new_scan_outs[0].owner + assert new_scan_node is not None + + # Modify the outer graph to make sure the outputs of the new scan are + # used instead of the outputs of the old scan + new_node_new_outputs_idx = len(old_scan_args.outer_outputs) - len( + old_scan_args.outer_out_shared + ) + + new_node_old_outputs = ( + new_scan_node.outputs[:new_node_new_outputs_idx] + + new_scan_node.outputs[new_node_new_outputs_idx + nb_new_outs :] + ) + + # TODO FIXME: + # replacements = dict(zip(old_scan_node.outputs, new_node_old_outputs)) + # replacements["remove"] = [old_scan_node] + # return new_scan_node, replacements + fgraph.replace_all_validate_remove( # type: ignore + list(zip(old_scan_node.outputs, new_node_old_outputs)), + remove=[old_scan_node], + reason="scan_pushout_add", + ) + return new_scan_node, {} + + +@node_rewriter([Scan]) +def push_out_add_scan(fgraph, node): + r"""Push `Add` operations performed at the end of the inner graph to the outside. + + Like `push_out_seq_scan`, this optimization aims to replace many operations + on small tensors by few operations on large tensors. It can also lead to + increased memory usage. + """ + # Don't perform the optimization on `as_while` `Scan`s. Because these + # `Scan`s don't run for a predetermined number of steps, handling them is + # more complicated and this optimization doesn't support it at the moment. + if not (isinstance(node.op, Scan) and not node.op.info.as_while): + return False + + op = node.op + + # Use `ScanArgs` to parse the inputs and outputs of scan for ease of + # use + args = ScanArgs( + node.inputs, node.outputs, op.inner_inputs, op.inner_outputs, op.info + ) + + clients = {} + local_fgraph_topo = io_toposort( + args.inner_inputs, args.inner_outputs, clients=clients + ) + + for nd in local_fgraph_topo: + if ( + isinstance(nd.op, Elemwise) + and isinstance(nd.op.scalar_op, aes.Add) + and nd.out in args.inner_out_sit_sot + and inner_sitsot_only_last_step_used(fgraph, nd.out, args) + ): + + # Ensure that one of the input to the add is the output of + # the add from a previous iteration of the inner function + sitsot_idx = args.inner_out_sit_sot.index(nd.out) + if args.inner_in_sit_sot[sitsot_idx] in nd.inputs: + + # Ensure that the other input to the add is a dot product + # between 2 matrices which will become a tensor3 and a + # matrix if pushed outside of the scan. Also make sure + # that the output of the Dot is ONLY used by the 'add' + # otherwise doing a Dot in the outer graph will only + # duplicate computation. + + sitsot_in_idx = nd.inputs.index(args.inner_in_sit_sot[sitsot_idx]) + + # 0 if sitsot_in_idx==1, 1 if sitsot_in_idx==0 + dot_in_idx = 1 - sitsot_in_idx + + dot_input = nd.inputs[dot_in_idx] + + if ( + dot_input.owner is not None + and isinstance(dot_input.owner.op, Dot) + and len(clients[dot_input]) == 1 + and dot_input.owner.inputs[0].ndim == 2 + and dot_input.owner.inputs[1].ndim == 2 + and get_outer_ndim(dot_input.owner.inputs[0], args) == 3 + and get_outer_ndim(dot_input.owner.inputs[1], args) == 3 + ): + + # The optimization can be be applied in this case. + + # Move out of scan the two inputs to the Dot and + # perform a dot outside of scan on these two inputs + inner_dot_inputs = nd.inputs[dot_in_idx].owner.inputs + ( + outer_dot_inputs, + new_scan_args, + replacements, + ) = push_out_inner_vars(fgraph, inner_dot_inputs, node, args) + + # Collapse some of the dimensions of the tensors + # so that they become matrices. This is because a + # dot is usually faster on two large matrices than + # a bunch of small ones + outer_dot_inputs[0] = at.flatten( + outer_dot_inputs[0].dimshuffle(1, 0, 2), ndim=2 + ) + + shape_input1 = shape(outer_dot_inputs[1]) + outer_dot_inputs[1] = outer_dot_inputs[1].reshape( + (shape_input1[0] * shape_input1[1], shape_input1[2]) + ) + + # Perform the dot on the newly obtained matrices and + # add the initial value + outer_dot_output = dot(*outer_dot_inputs) + init_value = new_scan_args.outer_in_sit_sot[sitsot_idx][0] + replacement = outer_dot_output + init_value + + # Alter the outer graph to use the output of the + # external Dot instead of the output of scan + # Modify the outer graph to add the outer Dot + outer_sitsot = new_scan_args.outer_out_sit_sot[sitsot_idx] + subtensor_node = fgraph.clients[outer_sitsot][0][0] + outer_sitsot_last_step = subtensor_node.outputs[0] + + replacements[outer_sitsot_last_step] = replacement + return replacements + + return False + + +class ScanInplaceOptimizer(GraphRewriter): + """Make `Scan`s perform in-place. + + This optimization attempts to make `Scan` compute its recurrent outputs inplace + on the input tensors that contain their initial states. This optimization can + improve runtime performance as well as reduce memory usage. + + """ + + alloc_ops = (Alloc, AllocEmpty) + """ + Classes that represent operation that allocate new memory and that the + optimization should duplicate so it can operate inplace on them. + """ + + def add_requirements(self, fgraph): + fgraph.attach_feature(ReplaceValidate()) + fgraph.attach_feature(DestroyHandler()) + + def attempt_scan_inplace( + self, fgraph: FunctionGraph, node: Apply[Scan], output_indices: List[int] + ) -> Optional[Apply]: + """Attempt to replace a `Scan` node by one which computes the specified outputs inplace. + + Parameters + ---------- + fgraph + Function graph in which to attempt the replacement + node + Scan node to replace by an inplace version + output_indices + Indices of the outputs to attempt to compute inplace + """ + + op = node.op + + # inputs corresponding to sequences and n_steps + ls_begin = node.inputs[: 1 + op.info.n_seqs] + ls = op.outer_mitmot(node.inputs) + ls += op.outer_mitsot(node.inputs) + ls += op.outer_sitsot(node.inputs) + ls_end = op.outer_shared(node.inputs) + ls_end += op.outer_nitsot(node.inputs) + ls_end += op.outer_non_seqs(node.inputs) + + # In `ls`, duplicate any input which has more than one client and is + # the output of an eligible allocation op + for i in range(len(ls)): + inp = ls[i] + if ( + len(fgraph.clients[inp]) > 1 + and inp.owner + and isinstance(inp.owner.op, self.alloc_ops) + ): + new_lsi = inp.owner.op.make_node(*inp.owner.inputs) + + if config.compute_test_value != "off": + compute_test_value(new_lsi) + + new_lsi_out = new_lsi.outputs + + if len(new_lsi_out) == 1: + new_lsi_out = new_lsi_out[0] + + ls[i] = new_lsi_out + + n_outs = len(ls) + for idx in range(n_outs): + if ls[idx] in ls[:idx]: + ls[idx] = deep_copy_op(ls[idx]) + + inputs = ls_begin + ls + ls_end + + new_op = op.clone() + + destroy_map = op.destroy_map.copy() + for out_idx in output_indices: + destroy_map[out_idx] = [out_idx + 1 + op.info.n_seqs] + + new_op.destroy_map = destroy_map + + # Do not call make_node for test_value + new_outs = new_op(*inputs, return_list=True) + + assert isinstance(new_outs, list) + + try: + # TODO FIXME: We need to stop using this approach (i.e. attempt + # in-place replacements and wait for downstream failures to revert + # the changes). It prevents us from making smart, clear + # rewrites and it adds a lot of unnecessary overhead that + # involves dealing with inconsistent graphs. + # This whole rewrite should be a simple local rewrite, but, because + # of this awful approach, it can't be. + fgraph.replace_all_validate_remove( # type: ignore + list(zip(node.outputs, new_outs)), + remove=[node], + reason="scan_make_inplace", + ) + return cast(Apply[Scan], new_outs[0].owner) + except InconsistencyError: + # Failed moving output to be computed inplace + return None + + def apply(self, fgraph): + + for scan_idx, original_node in enumerate(reversed(fgraph.toposort())): + + if not isinstance(original_node.op, Scan): + continue + + # First attempt to make the Scan compute inplace every recurrent + # output that seems like it could be computed inplace. If that + # fails, go through these outputs individually, trying each of + # them. + op = original_node.op + n_outs = op.info.n_mit_mot + op.info.n_mit_sot + op.info.n_sit_sot + + # Generate a list of outputs on which the node could potentially + # operate inplace. + out_indices = [] + for out_idx in range(n_outs): + inp_idx = 1 + op.info.n_seqs + out_idx + inp = original_node.inputs[inp_idx] + + # If the input is from an eligible allocation node, attempt to + # be inplace on it, even if other nodes are modifying it + # inplace. + if inp.owner and isinstance(inp.owner.op, self.alloc_ops): + out_indices.append(out_idx) + continue + + # If the input is not from an eligible allocation node, only + # attempt to be inplace on it if nothing else is currently + # inplace on it. + input_used_inplace = False + for c in fgraph.clients[original_node.inputs[inp_idx]]: + client = c[0] + + # Get the indices of this client's inputs on which it + # operates inplace + if client.op.destroy_map: + # This flattens the content of destroy_map.values() + # which is a list of lists + inplace_inp_indices = sum(client.op.destroy_map.values(), []) + + inplace_inps = [client.inputs[i] for i in inplace_inp_indices] + if original_node.inputs[inp_idx] in inplace_inps: + input_used_inplace = True + break + + if not input_used_inplace: + out_indices.append(out_idx) + + if len(out_indices) == 0: + continue + + new_node = self.attempt_scan_inplace(fgraph, original_node, out_indices) + + if new_node is None: + # Making the scan compute all plausible recurrent outputs + # inplace has failed. Attempt all plausible recurrent outputs + # individually. + + new_node = original_node + for pos in out_indices: + new_node = ( + self.attempt_scan_inplace(fgraph, new_node, [pos]) or new_node + ) + + +def select_min(x, y): + if x is None: + return y + if y is None: + return x + return minimum(x, y) + + +def select_max(x, y): + if x is None: + return y + if y is None: + return x + return maximum(x, y) + + +def sanitize(x): + if x is None: + return None + else: + return at.as_tensor_variable(x) + + +@node_rewriter([Scan]) +def save_mem_new_scan(fgraph, node): + r"""Graph optimizer that reduces scan memory consumption. + + This optimizations attempts to determine if a `Scan` node, during its execution, + for any of its outputs, can get away with allocating a memory buffer that is + large enough to contain some of the computed timesteps of that output but not + all of them. + + By default, during the execution of a `Scan` node, memory buffers will be + allocated to store the values computed for every output at every iteration. + However, in some cases, there are outputs for which there is only really a + need to store the most recent ``N`` values, not all of them. + + For instance, if a `Scan` node has a SITSOT output (last computed value is + fed back as an input at the next iteration) and only the last timestep of + that output is ever used in the outer function, the `ScanSaveMem` optimization + could determine that there is no need to store all computed timesteps for + that SITSOT output. Only the most recently computed timestep ever needs to + be kept in memory. + + """ + if not isinstance(node.op, Scan): + return False + + if hasattr(fgraph, "shape_feature"): + shape_of = fgraph.shape_feature.shape_of + else: + # Each access to shape_of is in a try..except block in order to + # use a default version when the variable is not in the shape_of + # dictionary. + shape_of = {} + # 1. Initialization of variables + # Note 1) We do not actually care about outputs representing shared + # variables (those have no intermediate values) so it is safer to + # ignore them and not change them in any way. To simplify the + # optimizations I construct the variable ``c_outs`` ( that counts + # outputs up to those we care) and the list ``init_l`` which for any + # output we care says the length of its initial state. Note that + # defining ``init_l`` for mit_mot sequences is a bit trickier but + # it is safe to set it to 0 + op = node.op + op_info = op.info + c_outs = ( + op_info.n_mit_mot + op_info.n_mit_sot + op_info.n_sit_sot + op_info.n_nit_sot + ) + + init_l = [0 for x in range(op_info.n_mit_mot)] + init_l += [ + abs(min(v)) for v in chain(op_info.mit_sot_in_slices, op_info.sit_sot_in_slices) + ] + init_l += [0 for x in range(op_info.n_nit_sot)] + # 2. Check the clients of each output and see for how many steps + # does scan need to run + + # This comparison checks if there is any uncounted output, which + # can only be an output corresponding to a shared variable + + # 2.1 Initialize + # global_nsteps is a dictionary having two fields ( 'real' deals + # with int values, 'sym' with symbolic ones) or None + # given that a scan op has k outputs o_1, .. o_k and each + # output has n_j clients c_1^1, c_1^2, .. c_1^{n_1}, c_2^1, .., + # global_nsteps is None if any of the clients is different + # from a subtensor or its real and sym field equal to + # max(c_i_j.idx_list[0].stop), meaning store up to which maximal + # index(step) for any output scan actually needs to compute + # In other words n_steps should be equal to this maximal ! + # Note: if we have a shared variable that gets updated at every step + # of the loop, reducing the number of steps will affect the the + # value of the shared variable after the loop so we need not to + # change the number of steps in that case. To do this we set + # global_nsteps to None which is seen as a flag that nothing needs + # to be done + assert len(node.outputs) >= c_outs + if len(node.outputs) == c_outs: + global_nsteps = {"real": -1, "sym": []} + else: + global_nsteps = None + + # Keeps track of the original slices that each client represent + slices = [None for o in node.outputs] + + # A list for each output indicating how many intermediate values + # should be stored. If negative it means none of the intermediate + # values (i.e. the output can be removed since it is not used + # afterwards in the computations), if 0 it means that all + # intermediate values are required, otherwise is up to that number + # of intermediate values + # Note that for mit_mot outputs and shared outputs we can not change + # the number of intermediate steps stored without affecting the + # result of the op + store_steps = [0 for o in range(op_info.n_mit_mot)] + store_steps += [-1 for o in node.outputs[op_info.n_mit_mot : c_outs]] + # Flag that says if an input has changed and we need to do something + # or not + flag_store = False + + # 2.2 Loop over the clients + for i, out in enumerate(node.outputs[:c_outs]): + # look at all its clients + slices[i] = [] + for cl, _ in fgraph.clients[out]: + + # 2.1 outputs of the function + # => output needs all its intermediate values + if isinstance(cl, str): + # if the node is actually an output, then + # we need to store the entire thing + global_nsteps = None + slices[i] = None + break + # 2.2 non-subtensor nodes + # => output needs all its intermediate values + elif not isinstance(cl.op, Subtensor): + global_nsteps = None + slices[i] = None + break + # 2.3 subtensor nodes + # => output might need to store just a subset of its values + else: + # 2.3.1 extract idx list of subtensor + this_slice = get_idx_list(cl.inputs, cl.op.idx_list) + if this_slice is None: + # if unable to extract idx_list + # => outputs needs all its intermediate values + global_nsteps = None + slices[i] = None + break + + # 2.3.2 extract the begin/end of the first dimension + if i >= op_info.n_mit_mot: + try: + length = shape_of[out][0] + except KeyError: + length = node.inputs[0] + init_l[i] + else: + try: + length = shape_of[out][0] + except KeyError: + length = out.shape[0] + cf_slice = get_canonical_form_slice(this_slice[0], length) + slices[i] += [(cf_slice, this_slice)] + + if isinstance(this_slice[0], slice) and this_slice[0].stop is None: + global_nsteps = None + if isinstance(cf_slice[0], slice): + stop = at.extract_constant(cf_slice[0].stop) + else: + stop = at.extract_constant(cf_slice[0]) + 1 + if stop == maxsize or stop == length: + stop = None + else: + # there is a **gotcha** here ! Namely, scan returns an + # array that contains the initial state of the output + # as well. Which means that if have a initial state of + # length 3, and you look for 5 steps you get an output + # y of length 8. If you only use y[:5], this does not + # mean that you only need to loop for 5 steps but + # actually only for 2 steps ( the first 3 are the + # initial state) + stop = stop - init_l[i] + + # 2.3.3 we might get away with less number of steps + if stop is not None and global_nsteps is not None: + # yes if it is a tensor + if isinstance(stop, Variable): + global_nsteps["sym"] += [stop] + # not if it is maxsize + elif isinstance(stop, int) and stop == maxsize: + global_nsteps = None + # yes if it is a int k, 0 < k < maxsize + elif isinstance(stop, int) and global_nsteps["real"] < stop: + global_nsteps["real"] = stop + # yes if it is a int k, 0 < k < maxsize + elif isinstance(stop, int) and stop > 0: + pass + # not otherwise + else: + global_nsteps = None + + # 2.3. Analyze global_nsteps to figure out for how many steps scan + # needs to iterate + if global_nsteps is not None: + nw_steps = node.inputs[0] + + # there are some symbolic tensors that limit the number of + # steps + if len(global_nsteps["sym"]) == 0: + sym_steps = None + else: + sym_steps = global_nsteps["sym"][0] + for c in global_nsteps["sym"][1:]: + sym_steps = maximum(sym_steps, c) + + if global_nsteps["real"] >= 0: + real_steps = global_nsteps["real"] + else: + real_steps = None + nw_steps = select_min(select_max(sym_steps, real_steps), node.inputs[0]) + + # Make sure the ScanSaveMem optimization never makes the new + # number of steps to be 0 (this could happen, for instance, if + # the optimization detects that the outputs of the Scan go through + # subtensor nodes that end up taking no elements) because Scan with + # 0 iterations are not supported. Make sure the new number of steps + # is at least 1. + nw_steps = select_max(nw_steps, 1) + else: + nw_steps = node.inputs[0] + global_nsteps = None + + # 2.4 Loop over the clients again now looking just to see how many + # intermediate steps to store + for i, out in enumerate(node.outputs[:c_outs]): + # look at all its clients + for cl, _ in fgraph.clients[out]: + if isinstance(cl, str): + store_steps[i] = 0 + break + elif not isinstance(cl.op, Subtensor): + store_steps[i] = 0 + break + else: + this_slice = get_idx_list(cl.inputs, cl.op.idx_list) + if this_slice is None: + store_steps[i] = 0 + break + + if isinstance(this_slice[0], slice) and this_slice[0].start is None: + store_steps[i] = 0 + break + + if i > op_info.n_mit_mot: + length = node.inputs[0] + init_l[i] + else: + try: + length = shape_of[out][0] + except KeyError: + length = out.shape[0] + cf_slice = get_canonical_form_slice(this_slice[0], length) + + if isinstance(cf_slice[0], slice): + start = at.extract_constant(cf_slice[0].start) + else: + start = at.extract_constant(cf_slice[0]) + if start == 0 or store_steps[i] == 0: + store_steps[i] = 0 + else: + # The "+ 1" is because of the memory pre-allocation + # mechanism used to in the Scan op to reduce overhead. + # To prevent aliasing between the inputs and outputs + # of recurrent states, it requires that the buffer be + # large enough to that, the new state and the oldest + # tap needed don't occupy the sample place in the + # circular buffer. For now, this only needs to be done + # for mitsots and sitsots (because mitmots are not + # currently supported by the mechanism) and only if + # the pre-allocation mechanism is activated. + prealloc_outs = config.scan__allow_output_prealloc + + first_mitsot_idx = op_info.n_mit_mot + last_sitsot_idx = ( + op_info.n_mit_mot + op_info.n_mit_sot + op_info.n_sit_sot - 1 + ) + preallocable_output = first_mitsot_idx <= i <= last_sitsot_idx + + if prealloc_outs and preallocable_output: + pval = select_max(nw_steps - start + init_l[i], init_l[i] + 1) + else: + pval = select_max(nw_steps - start + init_l[i], init_l[i]) + + if store_steps[i] != -1: + pval = select_max(pval, store_steps[i]) + + store_steps[i] = pval + flag_store = True + + orphane_outs = [ + i for i, x in enumerate(store_steps) if isinstance(x, int) and (x < 0) + ] + flag_store = flag_store or (len(orphane_outs) > 0) + # 3. is there anything to change ? + if flag_store or global_nsteps is not None: + # 3.1 initialize inputs for the new scan + old_outputs = [] + nw_inputs = list(node.inputs) + nw_inputs[0] = nw_steps + + # 3.2 check orphane outputs to see if we can eliminate any + required, not_required = scan_can_remove_outs(node.op, orphane_outs) + # 3.3. compose replace pairs for those nodes that need not + # to store everything in memory ( or ar orphane and required + # by the inner function .. ) + replaced_outs = [] + offset = 1 + op_info.n_seqs + op_info.n_mit_mot + for idx, _val in enumerate(store_steps[op_info.n_mit_mot :]): + i = idx + op_info.n_mit_mot + if not (isinstance(_val, int) and _val <= 0 and i not in required): + + if idx + op_info.n_mit_mot in required: + val = 1 + else: + val = _val + # If the memory for this output has been pre-allocated + # before going into the scan op (by an alloc node) + if idx < op_info.n_mit_sot + op_info.n_sit_sot: + # In case the input is still an alloc node, we + # actually have two options: + # a) the input is a set_subtensor, in that case we + # can replace the initial tensor by a slice, + # b) it is not, and we simply take a slice of it. + # TODO: commit change below with Razvan + if ( + nw_inputs[offset + idx].owner + and isinstance(nw_inputs[offset + idx].owner.op, IncSubtensor) + and isinstance( + nw_inputs[offset + idx].owner.op.idx_list[0], slice + ) + ): + + assert isinstance( + nw_inputs[offset + idx].owner.op, IncSubtensor + ) + _nw_input = nw_inputs[offset + idx].owner.inputs[1] + cval = at.as_tensor_variable(val) + initl = at.as_tensor_variable(init_l[i]) + tmp_idx = at.switch(cval < initl, cval + initl, cval - initl) + nw_input = expand_empty(_nw_input, tmp_idx) + else: + tmp = at.as_tensor_variable(val) + initl = at.as_tensor_variable(init_l[i]) + tmp = maximum(tmp, initl) + nw_input = nw_inputs[offset + idx][:tmp] + + nw_inputs[offset + idx] = nw_input + replaced_outs.append(op_info.n_mit_mot + idx) + odx = op_info.n_mit_mot + idx + old_outputs += [ + ( + odx, + [ + x[0].outputs[0] + for x in fgraph.clients[node.outputs[odx]] + ], + ) + ] + # If there is no memory pre-allocated for this output + elif idx < op_info.n_mit_sot + op_info.n_sit_sot + op_info.n_nit_sot: + + pos = ( + op_info.n_mit_mot + + idx + + op_info.n_seqs + + 1 + + op_info.n_shared_outs + ) + if nw_inputs[pos] == node.inputs[0]: + nw_inputs[pos] = val + odx = op_info.n_mit_mot + idx + replaced_outs.append(odx) + old_outputs += [ + ( + odx, + [ + x[0].outputs[0] + for x in fgraph.clients[node.outputs[odx]] + ], + ) + ] + # 3.4. Recompute inputs for everything else based on the new + # number of steps + if global_nsteps is not None: + for idx, val in enumerate(store_steps[op_info.n_mit_mot :]): + if val == 0: + # val == 0 means that we want to keep all intermediate + # results for that state, including the initial values. + if idx < op_info.n_mit_sot + op_info.n_sit_sot: + in_idx = offset + idx + # Number of steps in the initial state + initl = init_l[op_info.n_mit_mot + idx] + + # If the initial buffer has the form + # inc_subtensor(zeros(...)[...], _nw_input) + # we want to make the zeros tensor as small as + # possible (nw_steps + initl), and call + # inc_subtensor on that instead. + # Otherwise, simply take 0:(nw_steps+initl). + if ( + nw_inputs[in_idx].owner + and isinstance(nw_inputs[in_idx].owner.op, IncSubtensor) + and isinstance( + nw_inputs[in_idx].owner.op.idx_list[0], slice + ) + ): + _nw_input = nw_inputs[in_idx].owner.inputs[1] + nw_input = expand_empty(_nw_input, nw_steps) + nw_inputs[in_idx] = nw_input + else: + nw_input = nw_inputs[in_idx][: (initl + nw_steps)] + + elif ( + idx < op_info.n_mit_sot + op_info.n_sit_sot + op_info.n_nit_sot + ): + in_idx = offset + idx + op_info.n_shared_outs + if nw_inputs[in_idx] == node.inputs[0]: + nw_inputs[in_idx] = nw_steps + + # 3.5 Remove unwanted orphane outputs + (inps, outs, info, node_ins, compress_map) = compress_outs( + op, not_required, nw_inputs + ) + inv_compress_map = {} + for k, v in compress_map.items(): + inv_compress_map[v] = k + + # 3.6 Compose the new scan + # TODO: currently we don't support scan with 0 step. So + # don't create one. + if at.extract_constant(node_ins[0]) == 0: + return False + + # Do not call make_node for test_value + new_op = Scan( + inps, + outs, + info, + mode=op.mode, + profile=op.profile, + truncate_gradient=op.truncate_gradient, + # TODO: This seems questionable + name=op.name, + allow_gc=op.allow_gc, + ) + new_outs = new_op(*node_ins, return_list=True) + + old_new = [] + # 3.7 Get replace pairs for those outputs that do not change + # the number of intermediate steps stored + for idx, sl in enumerate(slices): + if global_nsteps and sl is not None and store_steps[idx] == 0: + for hdx, cl in enumerate(fgraph.clients[node.outputs[idx]]): + cnf_slice, old_slices = sl[hdx] + # Sanitize the nw_slice by converting ints back into + # constants :) I only need to do this for the first + # slice since that is the only slice + + if isinstance(cnf_slice[0], slice): + fslice = slice( + sanitize(cnf_slice[0].start), + sanitize(cnf_slice[0].stop), + sanitize(cnf_slice[0].step), + ) + else: + fslice = sanitize(cnf_slice[0]) + + nw_slice = (fslice,) + tuple(old_slices[1:]) + nw_pos = inv_compress_map[idx] + + subtens = Subtensor(nw_slice) + # slice inputs + sl_ins = get_slice_elements( + nw_slice, lambda entry: isinstance(entry, Variable) + ) + new_o = subtens(new_outs[nw_pos], *sl_ins) + if new_o.ndim > 0: + new_o = new_o[:: cnf_slice[1]] + replaced_outs.append(idx) + old_new += [(cl[0].outputs[0], new_o)] + # 3.8. Get replace pairs for those outputs that change + # the number of stored intermediate steps + for pos, old_outs in old_outputs: + if len(old_outs) > 0: + nw_pos = compress_map[pos] + for k, old in enumerate(old_outs): + # Get the correct slice + cnf_slice, old_slices = slices[pos][k] + if isinstance(cnf_slice[0], slice): + start = ( + cnf_slice[0].start + - nw_steps + - init_l[pos] + + store_steps[pos] + ) + if ( + cnf_slice[0].stop is not None + and cnf_slice[0].stop != maxsize + ): + stop = ( + cnf_slice[0].stop + - nw_steps + - init_l[pos] + + store_steps[pos] + ) + else: + stop = None + nw_slice = ( + slice( + sanitize(start), + sanitize(stop), + sanitize(cnf_slice[0].step), + ), + ) + tuple(old_slices[1:]) + + else: + position = ( + cnf_slice[0] - nw_steps - init_l[pos] + store_steps[pos] + ) + + nw_slice = (sanitize(position),) + tuple(old_slices[1:]) + subtens = Subtensor(nw_slice) + sl_ins = get_slice_elements( + nw_slice, lambda entry: isinstance(entry, Variable) + ) + new_o = subtens(new_outs[nw_pos], *sl_ins) + if new_o.ndim > 0: + new_o = new_o[:: cnf_slice[1]] + old_new += [(old, new_o)] + + # 3.9. Get replace pairs for all other nodes + if flag_store or global_nsteps is not None: + for idx, o in enumerate(node.outputs): + if not (idx in replaced_outs) and idx not in not_required: + nw_pos = compress_map[idx] + old_new += [(o, new_outs[nw_pos])] + # Check if the new outputs depend on the old scan node + old_scan_is_used = [ + is_in_ancestors(new.owner, node) for old, new in old_new + ] + if any(old_scan_is_used): + return False + + replacements = dict(old_new) + + # remove = [old.owner for (old, new) in old_new] + # As Fred suggested assert that also the old node is not in + # the Graph as that will make things suboptimal + # remove.append(node) + replacements["remove"] = [node] + + return replacements + + return False + + +class ScanMerge(GraphRewriter): + r"""Graph optimizer that merges different scan ops. + + This optimization attempts to fuse distinct `Scan` `Op`s into a single `Scan` `Op` + that performs all the computation. The main advantage of merging `Scan` `Op`\s + together comes from the possibility of both original `Op`\s having some + computation in common. In such a setting, this computation ends up being done + twice. The fused `Scan` `Op`, however, would only need to do it once and could + therefore be more computationally efficient. Also, since every `Scan` node + involves a certain overhead, at runtime, reducing the number of `Scan` nodes in + the graph can improve performance. + + """ + + def add_requirements(self, fgraph): + fgraph.attach_feature(ReplaceValidate()) + + def merge(self, nodes): + + if nodes[0].op.info.as_while: + as_while = True + condition = nodes[0].op.inner_outputs[-1] + else: + as_while = False + + # We keep the inner_ins and inner_outs of each original node separated. + # To be able to recombine them in the right order after the clone, + # we also need to split them by types (seq, mitmot, ...). + # On the other hand, outer_ins, outer_outs and info are held together. + inner_ins = [[] for nd in nodes] + outer_ins = [] + inner_outs = [[] for nd in nodes] + outer_outs = [] + + def rename(ls, suffix): + for k in ls: + if k.name: + k.name += str(suffix) + return ls + + for idx, nd in enumerate(nodes): + inner_ins[idx].append(rename(nd.op.inner_seqs(nd.op.inner_inputs), idx)) + outer_ins += rename(nd.op.outer_seqs(nd.inputs), idx) + + mit_mot_out_slices = () + + mit_mot_in_slices = () + for idx, nd in enumerate(nodes): + inner_ins[idx].append(rename(nd.op.inner_mitmot(nd.op.inner_inputs), idx)) + inner_outs[idx].append(nd.op.inner_mitmot_outs(nd.op.inner_outputs)) + mit_mot_in_slices += nd.op.info.mit_mot_in_slices + mit_mot_out_slices += nd.op.info.mit_mot_out_slices[: nd.op.info.n_mit_mot] + outer_ins += rename(nd.op.outer_mitmot(nd.inputs), idx) + outer_outs += nd.op.outer_mitmot_outs(nd.outputs) + + mit_sot_in_slices = () + for idx, nd in enumerate(nodes): + inner_ins[idx].append(rename(nd.op.inner_mitsot(nd.op.inner_inputs), idx)) + inner_outs[idx].append(nd.op.inner_mitsot_outs(nd.op.inner_outputs)) + mit_sot_in_slices += nd.op.info.mit_sot_in_slices + outer_ins += rename(nd.op.outer_mitsot(nd.inputs), idx) + outer_outs += nd.op.outer_mitsot_outs(nd.outputs) + + sit_sot_in_slices = () + for idx, nd in enumerate(nodes): + inner_ins[idx].append(rename(nd.op.inner_sitsot(nd.op.inner_inputs), idx)) + sit_sot_in_slices += tuple((-1,) for x in range(nd.op.info.n_sit_sot)) + inner_outs[idx].append(nd.op.inner_sitsot_outs(nd.op.inner_outputs)) + outer_ins += rename(nd.op.outer_sitsot(nd.inputs), idx) + outer_outs += nd.op.outer_sitsot_outs(nd.outputs) + + for idx, nd in enumerate(nodes): + # Shared + inner_ins[idx].append(rename(nd.op.inner_shared(nd.op.inner_inputs), idx)) + outer_ins += rename(nd.op.outer_shared(nd.inputs), idx) + + for idx, nd in enumerate(nodes): + # NitSot + inner_outs[idx].append(nd.op.inner_nitsot_outs(nd.op.inner_outputs)) + outer_ins += rename(nd.op.outer_nitsot(nd.inputs), idx) + outer_outs += nd.op.outer_nitsot_outs(nd.outputs) + + for idx, nd in enumerate(nodes): + # Shared + outer_outs += nd.op.outer_shared_outs(nd.outputs) + inner_outs[idx].append(nd.op.inner_shared_outs(nd.op.inner_outputs)) + + n_non_seqs = 0 + for idx, nd in enumerate(nodes): + # Non Seqs + node_inner_non_seqs = nd.op.inner_non_seqs(nd.op.inner_inputs) + n_non_seqs += len(node_inner_non_seqs) + inner_ins[idx].append(rename(node_inner_non_seqs, idx)) + outer_ins += rename(nd.op.outer_non_seqs(nd.inputs), idx) + + # Add back the number of steps + outer_ins = [nodes[0].inputs[0]] + outer_ins + + if as_while: + # add the condition, which was the one of nodes[0] + inner_outs[0].append([condition]) + + # Clone the inner graph of each node independently + for idx, nd in enumerate(nodes): + # concatenate all inner_ins and inner_outs of nd + flat_inner_ins = sum(inner_ins[idx], []) + flat_inner_outs = sum(inner_outs[idx], []) + # clone + flat_inner_ins, flat_inner_outs = reconstruct_graph( + flat_inner_ins, flat_inner_outs + ) + # split the new inner variables again in seq, mitmot, etc. + new_inner_ins = [] + count = 0 + for nl in inner_ins[idx]: + seq_len = len(nl) + new_inner_ins.append(flat_inner_ins[count : (count + seq_len)]) + count += seq_len + + new_inner_outs = [] + count = 0 + for nl in inner_outs[idx]: + seq_len = len(nl) + new_inner_outs.append(flat_inner_outs[count : (count + seq_len)]) + count += seq_len + + inner_ins[idx] = new_inner_ins + inner_outs[idx] = new_inner_outs + + # Flatten inner_ins and inner_outs so that all seqs are first, + # then mitmot, etc. + new_inner_ins = [] + new_inner_outs = [] + nb_ins_groups = len(inner_ins[0]) + nb_outs_groups = len(inner_outs[0]) + for idx, nd in enumerate(nodes): + # All inner_ins should have the same length + assert len(inner_ins[idx]) == nb_ins_groups + + # All inner_outs should have the same length, except if as_while, + # in which case the first one should have one more element + if as_while and idx > 0: + assert len(inner_outs[idx]) == nb_outs_groups - 1 + else: + assert len(inner_outs[idx]) == nb_outs_groups + + for gr_idx in range(nb_ins_groups): + for idx, nd in enumerate(nodes): + new_inner_ins += inner_ins[idx][gr_idx] + + for gr_idx in range(nb_outs_groups): + for idx, nd in enumerate(nodes): + if as_while and idx > 0 and gr_idx == (nb_outs_groups - 1): + # There is no condition on that node, skip it + pass + else: + new_inner_outs += inner_outs[idx][gr_idx] + + info = ScanInfo( + n_seqs=sum(nd.op.info.n_seqs for nd in nodes), + mit_mot_in_slices=mit_mot_in_slices, + mit_mot_out_slices=mit_mot_out_slices, + mit_sot_in_slices=mit_sot_in_slices, + sit_sot_in_slices=sit_sot_in_slices, + n_nit_sot=sum(nd.op.info.n_nit_sot for nd in nodes), + n_shared_outs=sum(nd.op.info.n_shared_outs for nd in nodes), + n_non_seqs=n_non_seqs, + as_while=as_while, + ) + + old_op = nodes[0].op + new_op = Scan( + new_inner_ins, + new_inner_outs, + info, + mode=old_op.mode, + profile=old_op.profile, + truncate_gradient=old_op.truncate_gradient, + allow_gc=old_op.allow_gc, + name="&".join([nd.op.name for nd in nodes]), + ) + new_outs = new_op(*outer_ins) + + if not isinstance(new_outs, (list, tuple)): + new_outs = [new_outs] + + return list(zip(outer_outs, new_outs)) + + def belongs_to_set(self, node, set_nodes): + """ + This function checks if node `node` belongs to `set_nodes`, in the + sense that it can be merged together with every other node in + `set_nodes`. In order for two nodes to be mergeable, they have to go + over the same number of steps, have the same condition (if any), + have the same value for truncate_gradient, and have the same mode. + Questionable, we should also consider profile ? + + """ + rep = set_nodes[0] + if ( + rep.op.info.as_while != node.op.info.as_while + or node.op.truncate_gradient != rep.op.truncate_gradient + or node.op.mode != rep.op.mode + ): + return False + + nsteps = node.inputs[0] + try: + nsteps = int(get_scalar_constant_value(nsteps)) + except NotScalarConstantError: + pass + + rep_nsteps = rep.inputs[0] + try: + rep_nsteps = int(get_scalar_constant_value(rep_nsteps)) + except NotScalarConstantError: + pass + + if nsteps != rep_nsteps: + return False + + # Check to see if it is an input of a different node + for nd in set_nodes: + if is_in_ancestors(node, nd) or is_in_ancestors(nd, node): + return False + + if not node.op.info.as_while: + return True + cond = node.op.inner_outputs[-1] + rep_cond = rep.op.inner_outputs[-1] + return equal_computations( + [cond], [rep_cond], node.op.inner_inputs, rep.op.inner_inputs + ) + + def apply(self, fgraph): + # Collect all scan nodes ordered according to toposort + scan_nodes = [nd for nd in fgraph.toposort() if isinstance(nd.op, Scan)] + + # All sets of possibly mergeable nodes + all_sets = [] + + for nd in scan_nodes: + belongs_to_set_idx = -1 + for pos, subset in enumerate(all_sets): + if self.belongs_to_set(nd, subset): + belongs_to_set_idx = pos + # It is possible that nd belongs to more than one subset. + # For instance, if we have 3 Scan nodes X, Y and Z, if Z + # depends on the output of X, then X and Z are incompatible + # and would create different subsets, but Y could be + # compatible with both X and Z. We choose the first one. + break + + if belongs_to_set_idx == -1: + all_sets.append([nd]) + else: + all_sets[belongs_to_set_idx].append(nd) + + for subset in all_sets: + if len(subset) > 1: + proposal = self.merge(subset) + fgraph.replace_all_validate_remove( + proposal, remove=subset, reason="scan_merge" + ) + + +def has_duplicates(l): + """ + Returns true if l has any duplicates (according to __eq__). + + """ + return len(set(l)) < len(l) + + +def make_equiv(lo, li): + """ + Builds a dictionary of equivalences between inner inputs based on + the equivalence of their corresponding outer inputs. + + """ + seeno = {} + left = [] + right = [] + for o, i in zip(lo, li): + if o in seeno: + left += [i] + right += [o] + else: + seeno[o] = i + return left, right + + +@node_rewriter([Scan]) +def scan_merge_inouts(fgraph, node): + """ + This optimization attempts to merge a `Scan` `Op`'s identical outer inputs as well + as merge its identical outer outputs (outputs that perform the same + computation on the same inputs). This can reduce the amount of computation as + well as result in a simpler graph for both the inner function and the outer + function. + """ + if not isinstance(node.op, Scan): + return False + + # Do a first pass to merge identical external inputs. + # Equivalent inputs will be stored in inp_equiv, then a new + # scan node created without duplicates. + a = ScanArgs( + node.inputs, + node.outputs, + node.op.inner_inputs, + node.op.inner_outputs, + node.op.info, + ) + + inp_equiv = {} + + if has_duplicates(a.outer_in_seqs): + new_outer_seqs = [] + new_inner_seqs = [] + for out_seq, in_seq in zip(a.outer_in_seqs, a.inner_in_seqs): + if out_seq in new_outer_seqs: + i = new_outer_seqs.index(out_seq) + inp_equiv[in_seq] = new_inner_seqs[i] + else: + new_outer_seqs.append(out_seq) + new_inner_seqs.append(in_seq) + a.outer_in_seqs = new_outer_seqs + a.inner_in_seqs = new_inner_seqs + + if has_duplicates(a.outer_in_non_seqs): + new_outer_nseqs = [] + new_inner_nseqs = [] + for out_nseq, in_nseq in zip(a.outer_in_non_seqs, a.inner_in_non_seqs): + if out_nseq in new_outer_nseqs: + i = new_outer_nseqs.index(out_nseq) + inp_equiv[in_nseq] = new_inner_nseqs[i] + else: + new_outer_nseqs.append(out_nseq) + new_inner_nseqs.append(in_nseq) + a.outer_in_non_seqs = new_outer_nseqs + a.inner_in_non_seqs = new_inner_nseqs + + if len(inp_equiv) > 0: + # do the replacement now. The rest will be left to ScanSaveMem + inner_inputs = a.inner_inputs + outer_inputs = a.outer_inputs + info = a.info + a_inner_outs = a.inner_outputs + inner_outputs = clone_replace(a_inner_outs, replace=inp_equiv) + + new_op = Scan( + inner_inputs, + inner_outputs, + info, + mode=node.op.mode, + profile=node.op.profile, + truncate_gradient=node.op.truncate_gradient, + # TODO: This seems questionable + name=node.op.name, + allow_gc=node.op.allow_gc, + ) + outputs = new_op(*outer_inputs) + + if not isinstance(outputs, (list, tuple)): + outputs = [outputs] + + na = ScanArgs( + outer_inputs, + outputs, + new_op.inner_inputs, + new_op.inner_outputs, + new_op.info, + ) + remove = [node] + else: + na = a + remove = [] + + # Now that the identical external inputs have been merged, we do a new + # loop in order to merge external outputs that compute the same things + # from the same inputs. + left = [] + right = [] + + if has_duplicates(na.outer_in_shared): + _left, _right = make_equiv(na.outer_in_shared, na.inner_in_shared) + left += _left + right += _right + if has_duplicates(na.outer_in_sit_sot): + _left, _right = make_equiv(na.outer_in_sit_sot, na.inner_in_sit_sot) + left += _left + right += _right + if has_duplicates(na.outer_in_mit_mot): + seen = {} + for omm, imm, _sl in zip( + na.outer_in_mit_mot, na.inner_in_mit_mot, na.mit_mot_in_slices + ): + sl = tuple(_sl) + if (omm, sl) in seen: + simm = seen[(omm, sl)] + left += imm + right += simm + else: + seen[(omm, sl)] = imm + + if has_duplicates(na.outer_in_mit_sot): + seen = {} + for oms, ims, _sl in zip( + na.outer_in_mit_sot, na.inner_in_mit_sot, na.mit_sot_in_slices + ): + sl = tuple(_sl) + if (oms, sl) in seen: + sims = seen[(oms, sl)] + left += ims + right += sims + else: + seen[(oms, sl)] = ims + + def map_out(outer_i, inner_o, outer_o, seen): + # Return the outer input corresponding to an + # (outer input, inner output) pair. If we see that pair for the first + # time, return the provided outer output. If an equivalent pair had + # already been seen, return that one instead. + # Note that we need to check that the outer input match as well, + # because they could have different sizes, and the corresponding + # outer outputs cannot be merged in that case. + for s_outer_i, s_inner_o, s_outer_o in seen: + if ( + equal_computations([inner_o], [s_inner_o], left, right) + and outer_i == s_outer_i + ): + return s_outer_o + seen.append((outer_i, inner_o, outer_o)) + return outer_o + + seen = [] + + assert len(na.outer_in_nit_sot) == len(na.inner_out_nit_sot) + assert len(na.inner_out_nit_sot) == len(na.outer_out_nit_sot) + na.outer_out_nit_sot = [ + map_out(outer_i, inner_o, outer_o, seen) + for outer_i, inner_o, outer_o in zip( + na.outer_in_nit_sot, na.inner_out_nit_sot, na.outer_out_nit_sot + ) + ] + + seen = [] + assert len(na.outer_in_sit_sot) == len(na.inner_out_sit_sot) + assert len(na.inner_out_sit_sot) == len(na.outer_out_sit_sot) + na.outer_out_sit_sot = [ + map_out(outer_i, inner_o, outer_o, seen) + for outer_i, inner_o, outer_o in zip( + na.outer_in_sit_sot, na.inner_out_sit_sot, na.outer_out_sit_sot + ) + ] + + seen = [] + assert len(na.outer_in_mit_sot) == len(na.inner_out_mit_sot) + assert len(na.inner_out_mit_sot) == len(na.outer_out_mit_sot) + na.outer_out_mit_sot = [ + map_out(outer_i, inner_o, outer_o, seen) + for outer_i, inner_o, outer_o in zip( + na.outer_in_mit_sot, na.inner_out_mit_sot, na.outer_out_mit_sot + ) + ] + + seen = [] + new_outer_out_mit_mot = [] + assert len(na.outer_in_mit_mot) == len(na.inner_out_mit_mot) + assert len(na.inner_out_mit_mot) == len(na.outer_out_mit_mot) + assert len(na.outer_out_mit_mot) == len(na.mit_mot_out_slices) + for outer_imm, inner_omm, outer_omm, osl in zip( + na.outer_in_mit_mot, + na.inner_out_mit_mot, + na.outer_out_mit_mot, + na.mit_mot_out_slices, + ): + for s_outer_imm, s_inner_omm, s_outer_omm, sosl in seen: + if ( + osl == sosl + and equal_computations(inner_omm, s_inner_omm, left, right) + and outer_imm == s_outer_imm + ): + + new_outer_out_mit_mot.append(s_outer_omm) + break + else: + seen.append((outer_imm, inner_omm, outer_omm, osl)) + new_outer_out_mit_mot.append(outer_omm) + na.outer_out_mit_mot = new_outer_out_mit_mot + if remove: + return dict([("remove", remove)] + list(zip(node.outputs, na.outer_outputs))) + return na.outer_outputs + + +@node_rewriter([Scan]) +def push_out_dot1_scan(fgraph, node): + r""" + This is another optimization that attempts to detect certain patterns of + computation in a `Scan` `Op`'s inner function and move this computation to the + outer graph. + """ + if not isinstance(node.op, Scan): + return False + + # Replace pattern of the form + # x[t] = x[t-1] + dot(seq[t], value) + # with Sequence.reshape((-1, seq.shape[2])) \dot Value + # When seq[t] is a vector/matrix and `value` is a matrix + # Note that this works when only you need X[-1] in the end + # and assumes dimshuffle are applied to vectors before calling dot + op = node.op + sitsot_ins = op.inner_sitsot(op.inner_inputs) + sitsot_outs = op.inner_sitsot_outs(op.inner_outputs) + outer_sitsot = op.outer_sitsot_outs(node.outputs) + seqs = op.inner_seqs(op.inner_inputs) + for inp, out, outer_out in zip(sitsot_ins, sitsot_outs, outer_sitsot): + + if ( + out.owner + and isinstance(out.owner.op, Elemwise) + and isinstance(out.owner.op.scalar_op, aes.Add) + and inp in out.owner.inputs + and len(fgraph.clients[outer_out]) == 1 + and not isinstance(fgraph.clients[outer_out][0][0], str) + and isinstance(fgraph.clients[outer_out][0][0].op, Subtensor) + and fgraph.clients[outer_out][0][0].op.idx_list == (-1,) + ): + + x = out.owner.inputs[0] + if x == inp: + x = out.owner.inputs[1] + # We need to check if x is the result of an outer product + if ( + x.owner + and isinstance(x.owner.op, Dot) + and x.owner.inputs[0].ndim == 2 + and x.owner.inputs[1].ndim == 2 + ): + + # We need to check if any of the inputs are a sequence + inp1 = x.owner.inputs[0] + inp2 = x.owner.inputs[1] + + if inp1 in seqs or inp2 in seqs: + new_scan_out = inp1 + + if inp1 in seqs: + new_scan_out = inp2 + idx = sitsot_outs.index(out) + # We've found our pattern and need to construct a new + # scan node to replace this one. For this we need to + # replace the sit_sot output with a nit_sot output + + # First let us split all arguments according to their + # corresponding categories + + inner_seqs = op.inner_seqs(op.inner_inputs) + outer_seqs = op.outer_seqs(node.inputs) + inner_mitmot = op.inner_mitmot(op.inner_inputs) + outer_mitmot = op.outer_mitmot(node.inputs) + inner_mitmot_outs = op.inner_mitmot_outs(op.inner_outputs) + inner_mitsot = op.inner_mitsot(op.inner_inputs) + outer_mitsot = op.outer_mitsot(node.inputs) + inner_mitsot_outs = op.inner_mitsot_outs(op.inner_outputs) + inner_sitsot = op.inner_sitsot(op.inner_inputs) + outer_sitsot = op.outer_sitsot(node.inputs) + inner_sitsot_outs = op.inner_sitsot_outs(op.inner_outputs) + outer_nitsot = op.outer_nitsot(node.inputs) + inner_nitsot_outs = op.inner_nitsot_outs(op.inner_outputs) + inner_shared = op.inner_shared(op.inner_inputs) + outer_shared = op.outer_shared(node.inputs) + inner_shared_outs = op.inner_shared_outs(op.inner_outputs) + inner_non_seqs = op.inner_non_seqs(op.inner_inputs) + outer_non_seqs = op.outer_non_seqs(node.inputs) + + new_info = dataclasses.replace( + op.info, + sit_sot_in_slices=op.info.sit_sot_in_slices[:idx] + + op.info.sit_sot_in_slices[idx + 1 :], + n_nit_sot=op.info.n_nit_sot + 1, + ) + inner_sitsot = inner_sitsot[:idx] + inner_sitsot[idx + 1 :] + outer_sitsot = outer_sitsot[:idx] + outer_sitsot[idx + 1 :] + inner_sitsot_outs = ( + inner_sitsot_outs[:idx] + inner_sitsot_outs[idx + 1 :] + ) + # add n_steps as the length + inner_nitsot_outs.append(new_scan_out) + + _new_inner_inps = ( + inner_seqs + + inner_mitmot + + inner_mitsot + + inner_sitsot + + inner_shared + + inner_non_seqs + ) + _new_inner_outs = ( + inner_mitmot_outs + + inner_mitsot_outs + + inner_sitsot_outs + + inner_nitsot_outs + + inner_shared_outs + ) + new_inner_inps, new_inner_outs = reconstruct_graph( + _new_inner_inps, _new_inner_outs + ) + new_op = Scan( + new_inner_inps, + new_inner_outs, + new_info, + mode=op.mode, + profile=op.profile, + truncate_gradient=op.truncate_gradient, + # TODO: This seems questionable + name=op.name, + allow_gc=op.allow_gc, + ) + _scan_inputs = ( + [node.inputs[0]] + + outer_seqs + + outer_mitmot + + outer_mitsot + + outer_sitsot + + outer_shared + + outer_nitsot + + [node.inputs[0]] + + outer_non_seqs + ) + + new_outs = new_op(*_scan_inputs) + if not isinstance(new_outs, (list, tuple)): + new_outs = [new_outs] + + # We need now to pair correctly the new outputs + # with the old ones + + outer_nitsot_outs = new_op.outer_nitsot_outs(new_outs) + + _val = outer_nitsot_outs[-1] + outer_nitsot_outs = outer_nitsot_outs[:-1] + if inp1 in seqs: + _out_seq = op.outer_seqs(node.inputs)[seqs.index(inp1)] + # We need to clip the seq to the number of steps + _out_seq = _out_seq[: node.inputs[0]] + sh0 = _out_seq.shape[0] + sh1 = _out_seq.shape[1] + sh2 = _out_seq.shape[2] + out_seq = _out_seq.dimshuffle(1, 0, 2) + out_seq = out_seq.reshape((sh1, sh0 * sh2)) + sh0 = _val.shape[0] + sh1 = _val.shape[1] + sh2 = _val.shape[2] + + val = _val.reshape((sh0 * sh1, sh2)) + new_out = dot(out_seq, val) + else: + _out_seq = op.outer_seqs(node.inputs)[seqs.index(inp2)] + out_seq = _out_seq.reshape( + ( + _out_seq.shape[0] * _out_seq.shape[1], + _out_seq.shape[2], + ) + ) + + val = _val.dimshuffle(1, 0, 2).reshape( + (_val.shape[1], _val.shape[0] * _val.shape[2]) + ) + new_out = dot(val, out_seq) + + pos = node.outputs.index(outer_out) + old_new = list(zip(node.outputs[:pos], new_outs[:pos])) + old = fgraph.clients[node.outputs[pos]][0][0].outputs[0] + old_new.append((old, new_out)) + old_new += list(zip(node.outputs[pos + 1 :], new_outs[pos:])) + replacements = dict(old_new) + replacements["remove"] = [node] + return replacements + + return False + + +# I've added an equilibrium because later scan optimization in the sequence +# can make it such that earlier optimizations should apply. However, in +# general I do not expect the sequence to run more then once +scan_eqopt1 = EquilibriumDB() +scan_seqopt1 = SequenceDB() +scan_eqopt2 = EquilibriumDB() + +# scan_eqopt1 before ShapeOpt at 0.1 +# This is needed to don't have ShapeFeature trac old Scan that we +# don't want to reintroduce. +optdb.register("scan_eqopt1", scan_eqopt1, "fast_run", "scan", position=0.05) +# We run before blas opt at 1.7 and specialize 2.0 +# but after stabilize at 1.5. Should we put it before stabilize? +optdb.register("scan_eqopt2", scan_eqopt2, "fast_run", "scan", position=1.6) +# ScanSaveMem should execute only once per node. +optdb.register( + "scan_save_mem", + in2out(save_mem_new_scan, ignore_newtrees=True), + "fast_run", + "scan", + position=1.61, +) +optdb.register( + "scan_make_inplace", + ScanInplaceOptimizer(), + "fast_run", + "inplace", + "scan", + position=75, +) + +scan_eqopt1.register("all_pushout_opt", scan_seqopt1, "fast_run", "scan") + + +scan_seqopt1.register( + "scan_remove_constants_and_unused_inputs0", + in2out(remove_constants_and_unused_inputs_scan, ignore_newtrees=True), + "remove_constants_and_unused_inputs_scan", + "fast_run", + "scan", + position=1, +) + + +scan_seqopt1.register( + "scan_pushout_nonseqs_ops", + in2out(push_out_non_seq_scan, ignore_newtrees=True), + "fast_run", + "scan", + "scan_pushout", + position=2, +) + + +scan_seqopt1.register( + "scan_pushout_seqs_ops", + in2out(push_out_seq_scan, ignore_newtrees=True), + "fast_run", + "scan", + "scan_pushout", + position=3, +) + + +scan_seqopt1.register( + "scan_pushout_dot1", + in2out(push_out_dot1_scan, ignore_newtrees=True), + "fast_run", + "more_mem", + "scan", + "scan_pushout", + position=4, +) + + +scan_seqopt1.register( + "scan_pushout_add", + # TODO: Perhaps this should be an `EquilibriumGraphRewriter`? + in2out(push_out_add_scan, ignore_newtrees=False), + "fast_run", + "more_mem", + "scan", + "scan_pushout", + position=5, +) + + +scan_eqopt2.register( + "constant_folding_for_scan2", + in2out(constant_folding, ignore_newtrees=True), + "fast_run", + "scan", +) + + +scan_eqopt2.register( + "scan_remove_constants_and_unused_inputs1", + in2out(remove_constants_and_unused_inputs_scan, ignore_newtrees=True), + "remove_constants_and_unused_inputs_scan", + "fast_run", + "scan", +) + + +# after const merge but before stabilize so that we can have identity +# for equivalent nodes but we still have the chance to hoist stuff out +# of the scan later. +scan_eqopt2.register("scan_merge", ScanMerge(), "fast_run", "scan") + +# After Merge optimization +scan_eqopt2.register( + "scan_remove_constants_and_unused_inputs2", + in2out(remove_constants_and_unused_inputs_scan, ignore_newtrees=True), + "remove_constants_and_unused_inputs_scan", + "fast_run", + "scan", +) + +scan_eqopt2.register( + "scan_merge_inouts", + in2out(scan_merge_inouts, ignore_newtrees=True), + "fast_run", + "scan", +) + +# After everything else +scan_eqopt2.register( + "scan_remove_constants_and_unused_inputs3", + in2out(remove_constants_and_unused_inputs_scan, ignore_newtrees=True), + "remove_constants_and_unused_inputs_scan", + "fast_run", + "scan", +) diff --git a/aesara/sparse/__init__.py b/aesara/sparse/__init__.py index 376da5dcce..77316ff638 100644 --- a/aesara/sparse/__init__.py +++ b/aesara/sparse/__init__.py @@ -1,47 +1,34 @@ -from warnings import warn - - -try: - import scipy - - enable_sparse = True -except ImportError: - enable_sparse = False - warn("SciPy can't be imported. Sparse matrix support is disabled.") - +from aesara.sparse import rewriting, sharedvar +from aesara.sparse.basic import * +from aesara.sparse.sharedvar import sparse_constructor as shared from aesara.sparse.type import SparseTensorType, _is_sparse -if enable_sparse: - from aesara.sparse import opt, sharedvar - from aesara.sparse.basic import * - from aesara.sparse.sharedvar import sparse_constructor as shared - - def sparse_grad(var): - """This function return a new variable whose gradient will be - stored in a sparse format instead of dense. +def sparse_grad(var): + """This function return a new variable whose gradient will be + stored in a sparse format instead of dense. - Currently only variable created by AdvancedSubtensor1 is supported. - i.e. a_tensor_var[an_int_vector]. + Currently only variable created by AdvancedSubtensor1 is supported. + i.e. a_tensor_var[an_int_vector]. - .. versionadded:: 0.6rc4 - """ - from aesara.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1 + .. versionadded:: 0.6rc4 + """ + from aesara.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1 - if var.owner is None or not isinstance( - var.owner.op, (AdvancedSubtensor, AdvancedSubtensor1) - ): - raise TypeError( - "Sparse gradient is only implemented for AdvancedSubtensor and AdvancedSubtensor1" - ) + if var.owner is None or not isinstance( + var.owner.op, (AdvancedSubtensor, AdvancedSubtensor1) + ): + raise TypeError( + "Sparse gradient is only implemented for AdvancedSubtensor and AdvancedSubtensor1" + ) - x = var.owner.inputs[0] - indices = var.owner.inputs[1:] + x = var.owner.inputs[0] + indices = var.owner.inputs[1:] - if len(indices) > 1: - raise TypeError( - "Sparse gradient is only implemented for single advanced indexing" - ) + if len(indices) > 1: + raise TypeError( + "Sparse gradient is only implemented for single advanced indexing" + ) - ret = AdvancedSubtensor1(sparse_grad=True)(x, indices[0]) - return ret + ret = AdvancedSubtensor1(sparse_grad=True)(x, indices[0]) + return ret diff --git a/aesara/sparse/basic.py b/aesara/sparse/basic.py index eba5fbe353..6f5bb22b0a 100644 --- a/aesara/sparse/basic.py +++ b/aesara/sparse/basic.py @@ -384,8 +384,6 @@ def __gt__(self, other): def __ge__(self, other): return ge(self, other) - # extra pseudo-operator symbols - def __dot__(left, right): return structured_dot(left, right) @@ -397,25 +395,16 @@ def sum(self, axis=None, sparse_grad=False): dot = __dot__ - # N.B. THIS IS COMMENTED OUT ON PURPOSE!!! - # Discussion with Fred & James (at least, and maybe others before) - # we decided that casting from a sparse to dense should be explicit - # because it's usually something you just want to be pretty careful - # about, and not to do by accident. - # def _as_TensorVariable(self): - # return dense_from_sparse(self) - def toarray(self): return dense_from_sparse(self) @property def shape(self): + # TODO: The plan is that the ShapeFeature in at.opt will do shape + # propagation and remove the dense_from_sparse from the graph. This + # will *NOT* actually expand your sparse matrix just to get the shape. return shape(dense_from_sparse(self)) - # don't worry! - # the plan is that the ShapeFeature in at.opt will do shape propagation - # and remove the dense_from_sparse from the graph. This will *NOT* - # actually expand your sparse matrix just to get the shape. ndim = property(lambda self: self.type.ndim) dtype = property(lambda self: self.type.dtype) @@ -532,7 +521,6 @@ def bsr_matrix(name=None, dtype=None): return matrix("bsr", name, dtype) -# for more dtypes, call SparseTensorType(format, dtype) csc_dmatrix = SparseTensorType(format="csc", dtype="float64") csr_dmatrix = SparseTensorType(format="csr", dtype="float64") bsr_dmatrix = SparseTensorType(format="bsr", dtype="float64") @@ -551,38 +539,57 @@ def bsr_matrix(name=None, dtype=None): discrete_dtypes = int_dtypes + uint_dtypes -# CONSTRUCTION class CSMProperties(Op): - # See doc in instance of this Op or function after this class definition. - # NOTE - # We won't implement infer_shape for this op now. This will - # ask that we implement an GetNNZ op, and this op will keep - # the dependence on the input of this op. So this won't help - # to remove computations in the graph. To remove computation, - # we will need to make an infer_sparse_pattern feature to - # remove computations. Doing this is trickier then the - # infer_shape feature. For example, how do we handle the case - # when some op create some 0 values? So there is dependence - # on the values themselves. We could write an infer_shape for - # the last output that is the shape, but I dough this will - # get used. - - # we don't return a view of the shape, we create a new ndarray from the - # shape tuple. - __props__ = () - view_map = {0: [0], 1: [0], 2: [0]} + """Create arrays containing all the properties of a given sparse matrix. - """ - Indexing to specified what part of the data parameter - should be use to construct the sparse matrix. + More specifically, this `Op` extracts the ``.data``, ``.indices``, + ``.indptr`` and ``.shape`` fields. + + For specific field, `csm_data`, `csm_indices`, `csm_indptr` + and `csm_shape` are provided. + Notes + ----- + The grad implemented is regular, i.e. not structured. + `infer_shape` method is not available for this `Op`. + + We won't implement infer_shape for this op now. This will + ask that we implement an GetNNZ op, and this op will keep + the dependence on the input of this op. So this won't help + to remove computations in the graph. To remove computation, + we will need to make an infer_sparse_pattern feature to + remove computations. Doing this is trickier then the + infer_shape feature. For example, how do we handle the case + when some op create some 0 values? So there is dependence + on the values themselves. We could write an infer_shape for + the last output that is the shape, but I dough this will + get used. + + We don't return a view of the shape, we create a new ndarray from the shape + tuple. """ + __props__ = () + view_map = {0: [0], 1: [0], 2: [0]} + def __init__(self, kmap=None): if kmap is not None: raise Exception("Do not use kmap, it is removed") def make_node(self, csm): + """ + + The output vectors correspond to the tuple + ``(data, indices, indptr, shape)``, i.e. the properties of a `csm` + array. + + Parameters + ---------- + csm + Sparse matrix in `CSR` or `CSC` format. + + """ + csm = as_sparse_variable(csm) assert csm.format in ("csr", "csc") data = TensorType(dtype=csm.type.dtype, shape=(False,))() @@ -618,26 +625,6 @@ def grad(self, inputs, g): # don't make this a function or it breaks some optimizations below csm_properties = CSMProperties() -""" -Extract all of .data, .indices, .indptr and .shape field. - -For specific field, `csm_data`, `csm_indices`, `csm_indptr` -and `csm_shape` are provided. - -Parameters ----------- -csm - Sparse matrix in CSR or CSC format. - -Returns - (data, indices, indptr, shape), the properties of `csm`. - -Notes ------ -The grad implemented is regular, i.e. not structured. -`infer_shape` method is not available for this op. - -""" def csm_data(csm): @@ -673,18 +660,16 @@ def csm_shape(csm): class CSM(Op): - # See doc in instance of this Op or function after this class definition. - """ - Indexing to specified what part of the data parameter - should be used to construct the sparse matrix. + """Construct a CSM matrix from constituent parts. - """ - __props__ = ("format",) - """ - Pre-computed hash value, defined by __init__. + Notes + ----- + The grad method returns a dense vector, so it provides a regular grad. """ + __props__ = ("format",) + def __init__(self, format, kmap=None): if format not in ("csr", "csc"): raise ValueError("format must be one of: 'csr', 'csc'", format) @@ -696,6 +681,24 @@ def __init__(self, format, kmap=None): self.view_map = {0: [0]} def make_node(self, data, indices, indptr, shape): + """ + + Parameters + ---------- + data + One dimensional tensor representing the data of the sparse matrix to + construct. + indices + One dimensional tensor of integers representing the indices of the sparse + matrix to construct. + indptr + One dimensional tensor of integers representing the indice pointer for + the sparse matrix to construct. + shape + One dimensional tensor of integers representing the shape of the sparse + matrix to construct. + + """ data = at.as_tensor_variable(data) if not isinstance(indices, Variable): @@ -784,80 +787,29 @@ def infer_shape(self, fgraph, node, shapes): CSC = CSM("csc") -""" -Construct a CSC matrix from the internal representation. - -Parameters ----------- -data - One dimensional tensor representing the data of the sparse matrix to - construct. -indices - One dimensional tensor of integers representing the indices of the sparse - matrix to construct. -indptr - One dimensional tensor of integers representing the indice pointer for - the sparse matrix to construct. -shape - One dimensional tensor of integers representing the shape of the sparse - matrix to construct. - -Returns -------- -sparse matrix - A sparse matrix having the properties specified by the inputs. - -Notes ------ -The grad method returns a dense vector, so it provides a regular grad. - -""" CSR = CSM("csr") -""" -Construct a CSR matrix from the internal representation. - -Parameters ----------- -data - One dimensional tensor representing the data of the sparse matrix to - construct. -indices - One dimensional tensor of integers representing the indices of the sparse - matrix to construct. -indptr - One dimensional tensor of integers representing the indice pointer for - the sparse matrix to construct. -shape - One dimensional tensor of integers representing the shape of the sparse - matrix to construct. - -Returns -------- -sparse matrix - A sparse matrix having the properties specified by the inputs. - -Notes ------ -The grad method returns a dense vector, so it provides a regular grad. - -""" class CSMGrad(Op): - # Note - # This Op computes the gradient of the CSM Op. CSM creates a matrix from - # data, indices, and indptr vectors; it's gradient is the gradient of - # the data vector only. There are two complexities to calculate this - # gradient: - # 1. The gradient may be sparser than the input matrix defined by (data, - # indices, indptr). In this case, the data vector of the gradient will have - # less elements than the data vector of the input because sparse formats - # remove 0s. Since we are only returning the gradient of the data vector, - # the relevant 0s need to be added back. - # 2. The elements in the sparse dimension are not guaranteed to be sorted. - # Therefore, the input data vector may have a different order than the - # gradient data vector. + """Compute the gradient of a CSM. + + Note + ---- + CSM creates a matrix from data, indices, and indptr vectors; it's gradient + is the gradient of the data vector only. There are two complexities to + calculate this gradient: + + 1. The gradient may be sparser than the input matrix defined by (data, + indices, indptr). In this case, the data vector of the gradient will have + less elements than the data vector of the input because sparse formats + remove 0s. Since we are only returning the gradient of the data vector, + the relevant 0s need to be added back. + 2. The elements in the sparse dimension are not guaranteed to be sorted. + Therefore, the input data vector may have a different order than the + gradient data vector. + """ + __props__ = () def __init__(self, kmap=None): @@ -927,7 +879,6 @@ def infer_shape(self, fgraph, node, shapes): class Cast(Op): - # See doc in instance of this Op or function after this class definition. __props__ = ("out_type",) def __init__(self, out_type): @@ -1005,14 +956,18 @@ def cast(variable, dtype): return Cast(dtype)(variable) -# -# Conversion -# +class DenseFromSparse(Op): + """Convert a sparse matrix to a dense one. + Notes + ----- + The grad implementation can be controlled through the constructor via the + `structured` parameter. `True` will provide a structured grad while `False` + will provide a regular grad. By default, the grad is structured. -class DenseFromSparse(Op): - # See doc in instance of this Op or function after this class definition. - __props__ = () # We don't put sparse_grad in the props. + """ + + __props__ = () def __init__(self, structured=True): self.sparse_grad = structured @@ -1027,6 +982,14 @@ def __call__(self, x): return super().__call__(x) def make_node(self, x): + """ + + Parameters + ---------- + x + A sparse matrix. + + """ x = as_sparse_variable(x) return Apply( self, @@ -1071,29 +1034,10 @@ def infer_shape(self, fgraph, node, shapes): dense_from_sparse = DenseFromSparse() -""" -Convert a sparse matrix to a dense one. - -Parameters ----------- -x - A sparse matrix. - -Returns -------- -aesara.tensor.matrix - A dense matrix, the same as `x`. - -Notes ------ -The grad implementation can be controlled through the constructor via the -`structured` parameter. `True` will provide a structured grad while `False` -will provide a regular grad. By default, the grad is structured. - -""" class SparseFromDense(Op): + """Convert a dense matrix to a sparse matrix.""" __props__ = () @@ -1110,6 +1054,14 @@ def __call__(self, x): return super().__call__(x) def make_node(self, x): + """ + + Parameters + ---------- + x + A dense matrix. + + """ x = at.as_tensor_variable(x) if x.ndim > 2: raise TypeError( @@ -1146,40 +1098,12 @@ def infer_shape(self, fgraph, node, shapes): csr_from_dense = SparseFromDense("csr") -""" -Convert a dense matrix to a sparse csr matrix. - -Parameters ----------- -x - A dense matrix. - -Returns -------- -sparse matrix - The same as `x` in a sparse csr matrix format. - -""" csc_from_dense = SparseFromDense("csc") -""" -Convert a dense matrix to a sparse csc matrix. - -Parameters ----------- -x - A dense matrix. -Returns -------- -sparse matrix - The same as `x` in a sparse csc matrix format. -""" - - -# Indexing class GetItemList(Op): + """Select row of sparse matrix, returning them as a new sparse matrix.""" __props__ = () @@ -1187,6 +1111,16 @@ def infer_shape(self, fgraph, node, shapes): return [(shapes[1][0], shapes[0][1])] def make_node(self, x, index): + """ + + Parameters + ---------- + x + Sparse matrix. + index + List of rows. + + """ x = as_sparse_variable(x) assert x.format in ("csr", "csc") @@ -1213,22 +1147,6 @@ def grad(self, inputs, g_outputs): get_item_list = GetItemList() -""" -Select row of sparse matrix, returning them as a new sparse matrix. - -Parameters ----------- -x - Sparse matrix. -index - List of rows. - -Returns -------- -sparse matrix - The corresponding rows in `x`. - -""" class GetItemListGrad(Op): @@ -1276,10 +1194,22 @@ def perform(self, node, inp, outputs): class GetItem2Lists(Op): + """Select elements of sparse matrix, returning them in a vector.""" __props__ = () def make_node(self, x, ind1, ind2): + """ + + Parameters + ---------- + x + Sparse matrix. + index + List of two lists, first list indicating the row of each element and second + list indicating its column. + + """ x = as_sparse_variable(x) assert x.format in ("csr", "csc") ind1 = at.as_tensor_variable(ind1) @@ -1294,13 +1224,9 @@ def perform(self, node, inp, outputs): x = inp[0] ind1 = inp[1] ind2 = inp[2] + # SciPy returns the corresponding elements as a `matrix`-type instance, + # which isn't what we want, so we convert it into an `ndarray` out[0] = np.asarray(x[ind1, ind2]).flatten() - """ - Here scipy returns the corresponding elements in a matrix which isn't - what we are aiming for. Using asarray and flatten, out[0] becomes an - array. - - """ def grad(self, inputs, g_outputs): x, ind1, ind2 = inputs @@ -1313,23 +1239,6 @@ def grad(self, inputs, g_outputs): get_item_2lists = GetItem2Lists() -""" -Select elements of sparse matrix, returning them in a vector. - -Parameters ----------- -x - Sparse matrix. -index - List of two lists, first list indicating the row of each element and second - list indicating its column. - -Returns -------- -aesara.tensor.vector - The corresponding elements in `x`. - -""" class GetItem2ListsGrad(Op): @@ -1375,15 +1284,40 @@ def perform(self, node, inp, outputs): class GetItem2d(Op): - # See doc in instance of this Op or function after this class definition. + """Implement a subtensor of sparse variable, returning a sparse matrix. + + If you want to take only one element of a sparse matrix see + `GetItemScalar` that returns a tensor scalar. + + Notes + ----- + Subtensor selection always returns a matrix, so indexing with [a:b, c:d] + is forced. If one index is a scalar, for instance, x[a:b, c] or x[a, b:c], + an error will be raised. Use instead x[a:b, c:c+1] or x[a:a+1, b:c]. + + The above indexing methods are not supported because the return value + would be a sparse matrix rather than a sparse vector, which is a + deviation from numpy indexing rule. This decision is made largely + to preserve consistency between numpy and aesara. This may be revised + when sparse vectors are supported. + + The grad is not implemented for this op. + + """ __props__ = () - # Fred:Too complicated for now. If you need it, look at - # the Subtensor.infer_shape. - # def infer_shape(self, fgraph, node, i0_shapes): - # return i0_shapes def make_node(self, x, index): + """ + + Parameters + ---------- + x + Sparse matrix. + index + Tuple of slice object. + + """ scipy_ver = [int(n) for n in scipy.__version__.split(".")[:2]] x = as_sparse_variable(x) assert x.format in ("csr", "csc") @@ -1477,50 +1411,36 @@ def perform(self, node, inputs, outputs): get_item_2d = GetItem2d() -""" -Implement a subtensor of sparse variable, returning a sparse matrix. - -If you want to take only one element of a sparse matrix see -`GetItemScalar` that returns a tensor scalar. - -Parameters ----------- -x - Sparse matrix. -index - Tuple of slice object. -Returns -------- -sparse matrix - The corresponding slice in `x`. +class GetItemScalar(Op): + """Subtensor of a sparse variable that takes two scalars as index and returns a scalar. -Notes ------ -Subtensor selection always returns a matrix, so indexing with [a:b, c:d] -is forced. If one index is a scalar, for instance, x[a:b, c] or x[a, b:c], -an error will be raised. Use instead x[a:b, c:c+1] or x[a:a+1, b:c]. - -The above indexing methods are not supported because the return value -would be a sparse matrix rather than a sparse vector, which is a -deviation from numpy indexing rule. This decision is made largely -to preserve consistency between numpy and aesara. This may be revised -when sparse vectors are supported. - -The grad is not implemented for this op. + If you want to take a slice of a sparse matrix see `GetItem2d` that returns a + sparse matrix. -""" + Notes + ----- + The grad is not implemented for this op. + """ -class GetItemScalar(Op): - # See doc in instance of this Op or function after this class definition. __props__ = () def infer_shape(self, fgraph, node, shapes): return [()] def make_node(self, x, index): + """ + + Parameters + ---------- + x + Sparse matrix. + index + Tuple of scalars. + + """ x = as_sparse_variable(x) assert x.format in ("csr", "csc") assert len(index) == 2 @@ -1553,35 +1473,20 @@ def perform(self, node, inputs, outputs): get_item_scalar = GetItemScalar() -""" -Implement a subtensor of a sparse variable that takes two scalars as index and -returns a scalar. - -If you want to take a slice of a sparse matrix see `GetItem2d` that returns a -sparse matrix. -Parameters ----------- -x - Sparse matrix. -index - Tuple of scalars. -Returns -------- -AesaraVariable - The corresponding item in `x`. +class Transpose(Op): + """Transpose of a sparse matrix. -Notes ------ -The grad is not implemented for this op. + Notes + ----- + The returned matrix will not be in the same format. `csc` matrix will be changed + in `csr` matrix and `csr` matrix in `csc` matrix. -""" + The grad is regular, i.e. not structured. + """ -# Linear Algebra -class Transpose(Op): - # See doc in instance of this Op or function after this class definition. view_map = {0: [0]} format_map = {"csr": "csc", "csc": "csr"} @@ -1591,6 +1496,14 @@ def __str__(self): return "Sparse" + self.__class__.__name__ def make_node(self, x): + """ + + Parameters + ---------- + x + Sparse matrix. + + """ x = as_sparse_variable(x) assert x.format in ("csr", "csc") return Apply( @@ -1620,31 +1533,16 @@ def infer_shape(self, fgraph, node, shapes): transpose = Transpose() -""" -Return the transpose of the sparse matrix. - -Parameters ----------- -x - Sparse matrix. -Returns -------- -sparse matrix - `x` transposed. -Notes ------ -The returned matrix will not be in the same format. `csc` matrix will be changed -in `csr` matrix and `csr` matrix in `csc` matrix. - -The grad is regular, i.e. not structured. - -""" +class Neg(Op): + """Negative of the sparse matrix (i.e. multiply by ``-1``). + Notes + ----- + The grad is regular, i.e. not structured. -class Neg(Op): - # See doc in instance of this Op or function after this class definition. + """ __props__ = () @@ -1652,6 +1550,14 @@ def __str__(self): return "Sparse" + self.__class__.__name__ def make_node(self, x): + """ + + Parameters + ---------- + x + Sparse matrix. + + """ x = as_sparse_variable(x) assert x.format in ("csr", "csc") return Apply(self, [x], [x.type()]) @@ -1673,24 +1579,6 @@ def infer_shape(self, fgraph, node, shapes): neg = Neg() -""" -Return the negation of the sparse matrix. - -Parameters ----------- -x - Sparse matrix. - -Returns -------- -sparse matrix - -`x`. - -Notes ------ -The grad is regular, i.e. not structured. - -""" class ColScaleCSC(Op): @@ -1844,14 +1732,16 @@ def row_scale(x, s): class SpSum(Op): - # See doc in instance of this Op or function after this class definition. + """ + + WARNING: judgement call... + We are not using the structured in the comparison or hashing + because it doesn't change the perform method therefore, we + *do* want Sums with different structured values to be merged + by the merge optimization and this requires them to compare equal. + """ __props__ = ("axis",) - # WARNING: judgement call... - # We are not using the structured in the comparison or hashing - # because it doesn't change the perform method therefore, we - # *do* want Sums with different structured values to be merged - # by the merge optimization and this requires them to compare equal. def __init__(self, axis=None, sparse_grad=True): super().__init__() @@ -1960,10 +1850,26 @@ def sp_sum(x, axis=None, sparse_grad=False): class Diag(Op): - # See doc in instance of this Op or function after this class definition. + """Extract the diagonal of a square sparse matrix as a dense vector. + + Notes + ----- + The grad implemented is regular, i.e. not structured, since the output is a + dense vector. + + """ + __props__ = () def make_node(self, x): + """ + + Parameters + ---------- + x + A square sparse matrix in csc format. + + """ x = as_sparse_variable(x) assert x.format in ("csr", "csc") return Apply(self, [x], [tensor(shape=(False,), dtype=x.dtype)]) @@ -1986,33 +1892,28 @@ def infer_shape(self, fgraph, nodes, shapes): diag = Diag() -""" -Extract the diagonal of a square sparse matrix as a dense vector. -Parameters ----------- -x - A square sparse matrix in csc format. -Returns -------- -TensorVariable - A dense vector representing the diagonal elements. - -Notes ------ -The grad implemented is regular, i.e. not structured, since the output is a -dense vector. - -""" +class SquareDiagonal(Op): + """Produce a square sparse (csc) matrix with a diagonal given by a dense vector. + Notes + ----- + The grad implemented is regular, i.e. not structured. -class SquareDiagonal(Op): - # See doc in instance of this Op or function after this class definition. + """ __props__ = () def make_node(self, diag): + """ + + Parameters + ---------- + x + Dense vector for the diagonal. + + """ diag = at.as_tensor_variable(diag) if diag.type.ndim != 1: raise TypeError("data argument must be a vector", diag.type) @@ -2040,29 +1941,22 @@ def infer_shape(self, fgraph, nodes, shapes): square_diagonal = SquareDiagonal() -""" -Return a square sparse (csc) matrix whose diagonal is given by the dense vector -argument. -Parameters ----------- -x - Dense vector for the diagonal. -Returns -------- -sparse matrix - A sparse matrix having `x` as diagonal. +class EnsureSortedIndices(Op): + """Re-sort indices of a sparse matrix. -Notes ------ -The grad implemented is regular, i.e. not structured. + CSR column indices are not necessarily sorted. Likewise + for CSC row indices. Use `ensure_sorted_indices` when sorted + indices are required (e.g. when passing data to other + libraries). -""" + Notes + ----- + The grad implemented is regular, i.e. not structured. + """ -class EnsureSortedIndices(Op): - # See doc in instance of this Op or function after this class definition. __props__ = ("inplace",) def __init__(self, inplace): @@ -2071,6 +1965,13 @@ def __init__(self, inplace): self.view_map = {0: [0]} def make_node(self, x): + """ + Parameters + ---------- + x + A sparse matrix. + + """ x = as_sparse_variable(x) assert x.format in ("csr", "csc") return Apply(self, [x], [x.type()]) @@ -2097,29 +1998,6 @@ def __str__(self): ensure_sorted_indices = EnsureSortedIndices(inplace=False) -""" -Re-sort indices of a sparse matrix. - -CSR column indices are not necessarily sorted. Likewise -for CSC row indices. Use `ensure_sorted_indices` when sorted -indices are required (e.g. when passing data to other -libraries). - -Parameters ----------- -x - A sparse matrix. - -Returns -------- -sparse matrix - The same as `x` with indices sorted. - -Notes ------ -The grad implemented is regular, i.e. not structured. - -""" def clean(x): @@ -2186,10 +2064,31 @@ def infer_shape(self, fgraph, node, shapes): class AddSSData(Op): - # See doc in instance of this Op or function after this class definition. + """Add two sparse matrices assuming they have the same sparsity pattern. + + Notes + ----- + The grad implemented is structured. + + """ + __props__ = () def make_node(self, x, y): + """ + + Parameters + ---------- + x + Sparse matrix. + y + Sparse matrix. + + Notes + ----- + `x` and `y` are assumed to have the same sparsity pattern. + + """ x, y = map(as_sparse_variable, [x, y]) assert x.format in ("csr", "csc") assert y.format in ("csr", "csc") @@ -2221,28 +2120,6 @@ def infer_shape(self, fgraph, node, ins_shapes): add_s_s_data = AddSSData() -""" -Add two sparse matrices assuming they have the same sparsity pattern. - -Parameters ----------- -x - Sparse matrix. -y - Sparse matrix. - -Returns -------- -A sparse matrix - The sum of the two sparse matrices element wise. - -Notes ------ -`x` and `y` are assumed to have the same sparsity pattern. - -The grad implemented is structured. - -""" class AddSD(Op): @@ -2288,10 +2165,30 @@ def infer_shape(self, fgraph, node, shapes): class StructuredAddSV(Op): + """Structured addition of a sparse matrix and a dense vector. + + The elements of the vector are only added to the corresponding + non-zero elements of the sparse matrix. Therefore, this operation + outputs another sparse matrix. + + Notes + ----- + The grad implemented is structured since the op is structured. + + """ __props__ = () def make_node(self, x, y): + """ + Parameters + ---------- + x + Sparse matrix. + y + Tensor type vector. + + """ x = as_sparse_variable(x) assert x.format in ("csr", "csc") y = at.as_tensor_variable(y) @@ -2323,30 +2220,6 @@ def infer_shape(self, fgraph, node, ins_shapes): structured_add_s_v = StructuredAddSV() -""" -Structured addition of a sparse matrix and a dense vector. -The elements of the vector are only added to the corresponding -non-zero elements of the sparse matrix. Therefore, this operation -outputs another sparse matrix. - -Parameters ----------- -x - Sparse matrix. -y - Tensor type vector. - -Returns -------- -A sparse matrix - A sparse matrix containing the addition of the vector to - the data of the sparse matrix. - -Notes ------ -The grad implemented is structured since the op is structured. - -""" def add(x, y): @@ -2558,10 +2431,26 @@ def infer_shape(self, fgraph, node, shapes): class MulSV(Op): + """Element-wise multiplication of sparse matrix by a broadcasted dense vector element wise. + + Notes + ----- + The grad implemented is regular, i.e. not structured. + + """ __props__ = () def make_node(self, x, y): + """ + Parameters + ---------- + x + Sparse matrix to multiply. + y + Tensor broadcastable vector. + + """ x = as_sparse_variable(x) assert x.format in ("csr", "csc") y = at.as_tensor_variable(y) @@ -2605,26 +2494,6 @@ def infer_shape(self, fgraph, node, ins_shapes): mul_s_v = MulSV() -""" -Multiplication of sparse matrix by a broadcasted dense vector element wise. - -Parameters ----------- -x - Sparse matrix to multiply. -y - Tensor broadcastable vector. - -Returns -------- -A sparse matrix - The product x * y element wise. - -Notes ------ -The grad implemented is regular, i.e. not structured. - -""" def mul(x, y): @@ -2922,131 +2791,23 @@ def comparison(self, x, y): eq = __ComparisonSwitch(equal_s_s, equal_s_d, equal_s_d) -""" -Parameters ----------- -x - A matrix variable. -y - A matrix variable. - -Returns -------- -matrix variable - `x` == `y` - -Notes ------ -At least one of `x` and `y` must be a sparse matrix. - -""" neq = __ComparisonSwitch(not_equal_s_s, not_equal_s_d, not_equal_s_d) -""" -Parameters ----------- -x - A matrix variable. -y - A matrix variable. - -Returns -------- -matrix variable - `x` != `y` - -Notes ------ -At least one of `x` and `y` must be a sparse matrix. - -""" lt = __ComparisonSwitch(less_than_s_s, less_than_s_d, greater_than_s_d) -""" -Parameters ----------- -x - A matrix variable. -y - A matrix variable. - -Returns -------- -matrix variable - `x` < `y` - -Notes ------ -At least one of `x` and `y` must be a sparse matrix. - -""" gt = __ComparisonSwitch(greater_than_s_s, greater_than_s_d, less_than_s_d) -""" -Parameters ----------- -x - A matrix variable. -y - A matrix variable. - -Returns -------- -matrix variable - `x` > `y` - -Notes ------ -At least one of `x` and `y` must be a sparse matrix. -""" le = __ComparisonSwitch(less_equal_s_s, less_equal_s_d, greater_equal_s_d) -""" -Parameters ----------- -x - A matrix variable. -y - A matrix variable. - -Returns -------- -matrix variable - `x` <= `y` - -Notes ------ -At least one of `x` and `y` must be a sparse matrix. - -""" ge = __ComparisonSwitch(greater_equal_s_s, greater_equal_s_d, less_equal_s_d) -""" -Parameters ----------- -x - A matrix variable. -y - A matrix variable. - -Returns -------- -matrix variable - `x` >= `y` - -Notes ------ -At least one of `x` and `y` must be a sparse matrix. - -""" class HStack(Op): - # See doc in instance of this Op or function after this class definition. __props__ = ("format", "dtype") def __init__(self, format=None, dtype=None): @@ -3150,7 +2911,6 @@ def hstack(blocks, format=None, dtype=None): class VStack(HStack): - # See doc in instance of this Op or function after this class definition. def perform(self, node, block, outputs): (out,) = outputs for b in block: @@ -3227,7 +2987,14 @@ def vstack(blocks, format=None, dtype=None): class Remove0(Op): - # See doc in instance of this Op or a function after the class definition. + """Remove explicit zeros from a sparse matrix. + + Notes + ----- + The grad implemented is regular, i.e. not structured. + + """ + __props__ = ("inplace",) def __init__(self, inplace=False): @@ -3242,6 +3009,14 @@ def __str__(self): return f"{self.__class__.__name__ }{{{', '.join(l)}}}" def make_node(self, x): + """ + + Parameters + ---------- + x + Sparse matrix. + + """ x = as_sparse_variable(x) assert x.format in ("csr", "csc") return Apply(self, [x], [x.type()]) @@ -3266,27 +3041,8 @@ def infer_shape(self, fgraph, node, i0_shapes): remove0 = Remove0() -""" -Remove explicit zeros from a sparse matrix. -Parameters ----------- -x - Sparse matrix. -Returns -------- -sparse matrix - Exactly `x` but with a data attribute exempt of zeros. - -Notes ------ -The grad implemented is regular, i.e. not structured. - -""" - - -# Structured monoid def structured_monoid(tensor_op): # Generic operation to perform many kinds of monoid element-wise # operations on the non-zeros of a sparse matrix. @@ -3319,7 +3075,6 @@ def structured_sigmoid(x): Structured elemwise sigmoid. """ - # see decorator for function body @structured_monoid(exp) @@ -3328,7 +3083,6 @@ def structured_exp(x): Structured elemwise exponential. """ - # see decorator for function body @structured_monoid(log) @@ -3337,7 +3091,6 @@ def structured_log(x): Structured elemwise logarithm. """ - # see decorator for function body @structured_monoid(at_pow) @@ -3346,7 +3099,6 @@ def structured_pow(x, y): Structured elemwise power of sparse matrix x by scalar y. """ - # see decorator for function body @structured_monoid(minimum) @@ -3355,7 +3107,6 @@ def structured_minimum(x, y): Structured elemwise minimum of sparse matrix x by scalar y. """ - # see decorator for function body @structured_monoid(maximum) @@ -3364,7 +3115,6 @@ def structured_maximum(x, y): Structured elemwise maximum of sparse matrix x by scalar y. """ - # see decorator for function body @structured_monoid(at_add) @@ -3373,17 +3123,14 @@ def structured_add(x): Structured addition of sparse matrix x and scalar y. """ - # see decorator for function body -# Sparse operation (map 0 to 0) @structured_monoid(sin) # type: ignore[no-redef] def sin(x): """ Elemwise sinus of `x`. """ - # see decorator for function body @structured_monoid(tan) # type: ignore[no-redef] @@ -3392,7 +3139,6 @@ def tan(x): Elemwise tan of `x`. """ - # see decorator for function body @structured_monoid(arcsin) # type: ignore[no-redef] @@ -3401,7 +3147,6 @@ def arcsin(x): Elemwise arcsinus of `x`. """ - # see decorator for function body @structured_monoid(arctan) # type: ignore[no-redef] @@ -3410,7 +3155,6 @@ def arctan(x): Elemwise arctan of `x`. """ - # see decorator for function body @structured_monoid(sinh) # type: ignore[no-redef] @@ -3419,7 +3163,6 @@ def sinh(x): Elemwise sinh of `x`. """ - # see decorator for function body @structured_monoid(arcsinh) # type: ignore[no-redef] @@ -3428,7 +3171,6 @@ def arcsinh(x): Elemwise arcsinh of `x`. """ - # see decorator for function body @structured_monoid(tanh) # type: ignore[no-redef] @@ -3437,7 +3179,6 @@ def tanh(x): Elemwise tanh of `x`. """ - # see decorator for function body @structured_monoid(arctanh) # type: ignore[no-redef] @@ -3446,7 +3187,6 @@ def arctanh(x): Elemwise arctanh of `x`. """ - # see decorator for function body @structured_monoid(round_half_to_even) @@ -3455,7 +3195,6 @@ def rint(x): Elemwise round half to even of `x`. """ - # see decorator for function body # Give it a simple name instead of the complex one that would automatically @@ -3469,7 +3208,6 @@ def sgn(x): Elemwise signe of `x`. """ - # see decorator for function body @structured_monoid(ceil) # type: ignore[no-redef] @@ -3478,7 +3216,6 @@ def ceil(x): Elemwise ceiling of `x`. """ - # see decorator for function body @structured_monoid(floor) # type: ignore[no-redef] @@ -3487,7 +3224,6 @@ def floor(x): Elemwise floor of `x`. """ - # see decorator for function body @structured_monoid(log1p) # type: ignore[no-redef] @@ -3496,7 +3232,6 @@ def log1p(x): Elemwise log(1 + `x`). """ - # see decorator for function body @structured_monoid(expm1) # type: ignore[no-redef] @@ -3505,7 +3240,6 @@ def expm1(x): Elemwise e^`x` - 1. """ - # see decorator for function body @structured_monoid(deg2rad) # type: ignore[no-redef] @@ -3514,7 +3248,6 @@ def deg2rad(x): Elemwise degree to radian. """ - # see decorator for function body @structured_monoid(rad2deg) # type: ignore[no-redef] @@ -3523,16 +3256,14 @@ def rad2deg(x): Elemwise radian to degree. """ - # see decorator for function body @structured_monoid(trunc) # type: ignore[no-redef] def trunc(x): """ - Elemwise truncature. + Elemwise truncation. """ - # see decorator for function body @structured_monoid(sqr) # type: ignore[no-redef] @@ -3541,7 +3272,6 @@ def sqr(x): Elemwise `x` * `x`. """ - # see decorator for function body @structured_monoid(sqrt) # type: ignore[no-redef] @@ -3550,7 +3280,6 @@ def sqrt(x): Elemwise square root of `x`. """ - # see decorator for function body @structured_monoid(_conj) # type: ignore[no-redef] @@ -3559,7 +3288,6 @@ def _conj(x): Elemwise complex conjugate of `x`. """ - # see decorator for function body def conjugate(x): @@ -3712,9 +3440,7 @@ def true_dot(x, y, grad_preserves_dense=True): return transpose(TrueDot(grad_preserves_dense)(y.T, x.T)) -# Dot class StructuredDot(Op): - # See doc in instance of this Op or function after this class definition. __props__ = () def make_node(self, a, b): @@ -4137,10 +3863,40 @@ def structured_dot_grad(sparse_A, dense_B, ga): class SamplingDot(Op): - # See doc in instance of this Op or function after this class definition. + """Compute the dot product ``dot(x, y.T) = z`` for only a subset of `z`. + + This is equivalent to ``p * (x . y.T)`` where ``*`` is the element-wise + product, ``x`` and ``y`` operands of the dot product and ``p`` is a matrix that + contains 1 when the corresponding element of ``z`` should be calculated + and ``0`` when it shouldn't. Note that `SamplingDot` has a different interface + than `dot` because it requires ``x`` to be a ``m x k`` matrix while + ``y`` is a ``n x k`` matrix instead of the usual ``k x n`` matrix. + + Notes + ----- + It will work if the pattern is not binary value, but if the + pattern doesn't have a high sparsity proportion it will be slower + then a more optimized dot followed by a normal elemwise + multiplication. + + The grad implemented is regular, i.e. not structured. + + """ + __props__ = () def make_node(self, x, y, p): + """ + Parameters + ---------- + x + Tensor matrix. + y + Tensor matrix. + p + Sparse matrix in csr format. + + """ x = at.as_tensor_variable(x) y = at.as_tensor_variable(y) p = as_sparse_variable(p) @@ -4180,46 +3936,9 @@ def infer_shape(self, fgraph, node, ins_shapes): sampling_dot = SamplingDot() -""" -Operand for calculating the dot product ``dot(x, y.T) = z`` when you -only want to calculate a subset of `z`. - -It is equivalent to ``p o (x . y.T)`` where ``o`` is the element-wise -product, `x` and `y` operands of the dot product and `p` is a matrix that -contains 1 when the corresponding element of `z` should be calculated -and 0 when it shouldn't. Note that SamplingDot has a different interface -than `dot` because SamplingDot requires `x` to be a ``m x k`` matrix while -`y` is a ``n x k`` matrix instead of the usual ``k x n`` matrix. - -Notes ------ -It will work if the pattern is not binary value, but if the -pattern doesn't have a high sparsity proportion it will be slower -then a more optimized dot followed by a normal elemwise -multiplication. - -The grad implemented is regular, i.e. not structured. - -Parameters ----------- -x - Tensor matrix. -y - Tensor matrix. -p - Sparse matrix in csr format. - -Returns -------- -sparse matrix - A dense matrix containing the dot product of `x` by ``y.T`` only - where `p` is 1. - -""" class Dot(Op): - # See doc in instance of this Op or function after this class definition. __props__ = () def __str__(self): @@ -4326,10 +4045,9 @@ def grad(self, inputs, gout): def dot(x, y): - """ - Operation for efficiently calculating the dot product when - one or all operands is sparse. Supported format are CSC and CSR. - The output of the operation is dense. + """Efficiently compute the dot product when one or all operands are sparse. + + Supported formats are CSC and CSR. The output of the operation is dense. Parameters ---------- @@ -4340,7 +4058,7 @@ def dot(x, y): Returns ------- - The dot product `x`.`y` in a dense format. + The dot product ``x @ y`` in a dense format. Notes ----- @@ -4348,9 +4066,9 @@ def dot(x, y): At least one of `x` or `y` must be a sparse matrix. - When the operation has the form dot(csr_matrix, dense) + When the operation has the form ``dot(csr_matrix, dense)`` the gradient of this operation can be performed inplace - by UsmmCscDense. This leads to significant speed-ups. + by `UsmmCscDense`. This leads to significant speed-ups. """ @@ -4369,15 +4087,34 @@ def dot(x, y): class Usmm(Op): - # See doc in instance of this Op or function after this class definition. - # We don't implement the infer_shape as it is - # inserted by optimization only. + """Computes the dense matrix resulting from ``alpha * x @ y + z``. + + Notes + ----- + At least one of `x` or `y` must be a sparse matrix. + + """ + __props__ = () def __str__(self): return "Usmm{no_inplace}" def make_node(self, alpha, x, y, z): + """ + + Parameters + ---------- + alpha + A scalar. + x + Matrix variable. + y + Matrix variable. + z + Dense matrix. + + """ if not _is_sparse_variable(x) and not _is_sparse_variable(y): # If x and y are tensor, we don't want to use this class # We should use Dot22 and Gemm in that case. @@ -4431,34 +4168,17 @@ def perform(self, node, inputs, outputs): usmm = Usmm() -""" -Performs the expression `alpha` * `x` `y` + `z`. - -Parameters ----------- -x - Matrix variable. -y - Matrix variable. -z - Dense matrix. -alpha - A tensor scalar. - -Returns -------- -The dense matrix resulting from `alpha` * `x` `y` + `z`. - -Notes ------ -The grad is not implemented for this op. -At least one of `x` or `y` must be a sparse matrix. - -""" class ConstructSparseFromList(Op): - # See doc in instance of this Op or function after this class definition. + """Constructs a sparse matrix out of a list of 2-D matrix rows. + + Notes + ----- + The grad implemented is regular, i.e. not structured. + + """ + __props__ = () def make_node(self, x, values, ilist): @@ -4549,11 +4269,3 @@ def grad(self, inputs, grads): construct_sparse_from_list = ConstructSparseFromList() -""" -Constructs a sparse matrix out of a list of 2-D matrix rows. - -Notes ------ -The grad implemented is regular, i.e. not structured. - -""" diff --git a/aesara/sparse/opt.py b/aesara/sparse/opt.py index 6dab75f6cb..7c275dd2e4 100644 --- a/aesara/sparse/opt.py +++ b/aesara/sparse/opt.py @@ -1,2083 +1,10 @@ -import scipy +import warnings -import aesara -import aesara.scalar as aes -from aesara.configdefaults import config -from aesara.graph.basic import Apply -from aesara.graph.opt import PatternSub, TopoOptimizer, local_optimizer -from aesara.link.c.op import COp, _NoPythonCOp -from aesara.misc.safe_asarray import _asarray -from aesara.sparse import basic as sparse -from aesara.sparse.basic import ( - CSC, - CSR, - csm_data, - csm_grad, - csm_indices, - csm_indptr, - csm_properties, - usmm, -) -from aesara.tensor import blas -from aesara.tensor.basic import as_tensor_variable, cast -from aesara.tensor.basic_opt import register_canonicalize, register_specialize -from aesara.tensor.math import mul, neg, sub -from aesara.tensor.shape import shape, specify_shape -from aesara.tensor.type import TensorType, tensor - - -_is_sparse_variable = sparse._is_sparse_variable -_is_dense = sparse._is_dense - -# This is tested in tests/test_opt.py:test_local_csm_properties_csm - - -@local_optimizer([csm_properties]) -def local_csm_properties_csm(fgraph, node): - """ - If we find csm_properties(CSM(*args)), then we can replace that with the - *args directly. - - """ - if node.op == csm_properties: - (csm,) = node.inputs - if csm.owner and (csm.owner.op == CSC or csm.owner.op == CSR): - return csm.owner.inputs - - return False - - -register_specialize(local_csm_properties_csm) - - -# This is tested in tests/test_basic.py:test_remove0 -@local_optimizer([sparse.Remove0]) -def local_inplace_remove0(fgraph, node): - """ - Optimization to insert inplace versions of Remove0. - - """ - # If inplace is not enabled, enable it and replace that op with a - # new op which has inplace enabled - if isinstance(node.op, sparse.Remove0) and not node.op.inplace: - new_op = node.op.__class__(inplace=True) - new_node = new_op(*node.inputs) - return [new_node] - return False - - -aesara.compile.optdb.register( - "local_inplace_remove0", - TopoOptimizer(local_inplace_remove0, failure_callback=TopoOptimizer.warn_inplace), - "fast_run", - "inplace", - position=60, -) - - -class AddSD_ccode(_NoPythonCOp): - """ - Add a sparse and a dense matrix. - - Parameters - ---------- - x - A sparse matrix. - y - A dense matrix - - Returns - ------- - matrix - `x`+`y` - - Notes - ----- - The grad implemented is structured on `x`. - - """ - - __props__ = ("format", "inplace") - - def __init__(self, format, inplace=False, *args, **kwargs): - super().__init__(*args, **kwargs) - # Should we do inplace addition or not ? - self.inplace = inplace - self.format = format - if self.inplace: - self.destroy_map = {0: [3]} - - def __str__(self): - inp = "" - if self.inplace: - inp = ",inplace" - return f"{self.__class__.__name__}{{{self.format}{inp}}}" - - def make_node(self, x, y): - x, y = sparse.as_sparse_variable(x), as_tensor_variable(y) - out_dtype = aes.upcast(x.type.dtype, y.type.dtype) - if self.inplace: - assert out_dtype == y.dtype - indices, indptr, data = csm_indices(x), csm_indptr(x), csm_data(x) - # We either use CSC or CSR depending on the format of input - assert self.format == x.type.format - # The magic number two here arises because L{scipy.sparse} - # objects must be matrices (have dimension 2) - assert y.type.ndim == 2 - out = TensorType(dtype=out_dtype, shape=y.type.broadcastable)() - return Apply(self, [data, indices, indptr, y], [out]) - - def c_code(self, node, name, inputs, outputs, sub): - (_data, _indices, _indptr, y) = inputs - (z,) = outputs - inplace = int(self.inplace) - format = {"csc": 0, "csr": 1}[self.format] - out_typenum = node.outputs[0].type.dtype_specs()[2] - code = """ - Py_XDECREF(%(z)s); - if (!%(inplace)s){ - if(PyArray_TYPE(%(y)s) != %(out_typenum)s){ - %(z)s = (PyArrayObject *) PyArray_FromArray(%(y)s, PyArray_DescrFromType(%(out_typenum)s), 0); - }else{ - %(z)s = (PyArrayObject *) PyArray_NewCopy(%(y)s, NPY_CORDER); - } - }else{ - %(z)s = %(y)s; - Py_XINCREF(%(z)s); - } - - npy_intp N = PyArray_DIMS(%(_indptr)s)[0]-1; - - const dtype_%(_indptr)s* __restrict__ indptr = (dtype_%(_indptr)s*)PyArray_DATA(%(_indptr)s); - const dtype_%(_indices)s* __restrict__ indices = (dtype_%(_indices)s*)PyArray_DATA(%(_indices)s); - const dtype_%(_data)s* __restrict__ data = (dtype_%(_data)s*)PyArray_DATA(%(_data)s); - - dtype_%(y)s* ydata = (dtype_%(y)s*)PyArray_DATA(%(y)s); - dtype_%(z)s* zdata = (dtype_%(z)s*)PyArray_DATA(%(z)s); - npy_intp Yi = PyArray_STRIDES(%(y)s)[0]/PyArray_DESCR(%(y)s)->elsize; - npy_intp Yj = PyArray_STRIDES(%(y)s)[1]/PyArray_DESCR(%(y)s)->elsize; - - npy_intp pos; - if (%(format)s == 0){ - for (npy_intp col = 0; col < N; ++col){ - for (dtype_%(_indptr)s ind = indptr[col]; ind < indptr[col+1]; ++ind){ - npy_intp row = indices[ind]; - pos = row * Yi + col * Yj; - zdata[pos] = ydata[pos] + data[ind]; - } - } - }else{ - for (npy_intp row = 0; row < N; ++row){ - for (dtype_%(_indptr)s ind = indptr[row]; ind < indptr[row+1]; ++ind){ - npy_intp col = indices[ind]; - pos = row * Yi + col * Yj; - zdata[pos] = ydata[pos] + data[ind]; - } - } - } - """ % dict( - locals(), **sub - ) - return code - - def infer_shape(self, fgraph, node, shapes): - return [shapes[3]] - - def c_code_cache_version(self): - return (2,) - - -@local_optimizer([sparse.AddSD]) -def local_inplace_addsd_ccode(fgraph, node): - """ - Optimization to insert inplace versions of AddSD. - - """ - if isinstance(node.op, sparse.AddSD) and config.cxx: - out_dtype = aes.upcast(*node.inputs) - if out_dtype != node.inputs[1].dtype: - return - new_node = AddSD_ccode(format=node.inputs[0].type.format, inplace=True)( - *node.inputs - ) - return [new_node] - return False - - -aesara.compile.optdb.register( - "local_inplace_addsd_ccode", - TopoOptimizer( - local_inplace_addsd_ccode, failure_callback=TopoOptimizer.warn_inplace - ), - "fast_run", - "inplace", - position=60, -) - - -@register_canonicalize("fast_compile") -@register_specialize -@local_optimizer([sparse.DenseFromSparse]) -def local_dense_from_sparse_sparse_from_dense(fgraph, node): - if isinstance(node.op, sparse.DenseFromSparse): - inp = node.inputs[0] - if inp.owner and isinstance(inp.owner.op, sparse.SparseFromDense): - return inp.owner.inputs - - -@local_optimizer([sparse.AddSD]) -def local_addsd_ccode(fgraph, node): - """ - Convert AddSD to faster AddSD_ccode. - - """ - if isinstance(node.op, sparse.AddSD) and config.cxx: - new_node = AddSD_ccode(format=node.inputs[0].type.format)(*node.inputs) - return [new_node] - return False - - -aesara.compile.optdb.register( - "local_addsd_ccode", - TopoOptimizer(local_addsd_ccode), - # Must be after local_inplace_addsd_ccode at 60 - "fast_run", - position=61, +warnings.warn( + "The module `aesara.sparse.opt` is deprecated; use `aesara.sparse.rewriting` instead.", + DeprecationWarning, + stacklevel=2, ) - -class StructuredDotCSC(COp): - """ - Structured Dot CSC is like dot, except that only the gradient wrt non-zero - elements of the sparse matrix `a` are calculated and propagated. - - The output is presumed to be a dense matrix, and is represented by a - TensorType instance. - - Parameters - ---------- - a - A sparse matrix in csc format. - b - A sparse or dense matrix. - - Returns - ------- - The dot product of `a` and `b`. - - Notes - ----- - The grad implemented is structured. - This op is used as an optimization for StructuredDot. - - """ - - __props__ = () - - def make_node(self, a_val, a_ind, a_ptr, a_nrows, b): - dtype_out = aes.upcast(a_val.type.dtype, b.type.dtype) - r = Apply( - self, - [a_val, a_ind, a_ptr, a_nrows, b], - [tensor(dtype_out, (False, b.type.broadcastable[1]))], - ) - return r - - def perform(self, node, inputs, outputs): - (a_val, a_ind, a_ptr, a_nrows, b) = inputs - (out,) = outputs - a = scipy.sparse.csc_matrix( - (a_val, a_ind, a_ptr), (a_nrows, b.shape[0]), copy=False - ) - # out[0] = a.dot(b) - out[0] = _asarray(a * b, dtype=node.outputs[0].type.dtype) - assert _is_dense(out[0]) # scipy 0.7 automatically converts to dense - - def c_code(self, node, name, inputs, outputs, sub): - # C-implementation of the dot product of the sparse matrix A and matrix - # B. - # @param a_val: non-zero values of the sparse matrix - # @param a_ind: column indices of the non-null values (.indices of a - # scipy.csc_matrix) - # @param a_ptr: a_ptr indicates col indices for col. i are in the range - # a_ptr[i]:a_ptr[i+1] - # @param n_rows: number of rows of sparse matrix - # @param b: dense matrix to perform dot product with, as in dot(a, b) - # @param z: return value - # @param sub: TODO, not too sure, something to do with weave probably - - (a_val, a_ind, a_ptr, a_nrows, b) = inputs - (z,) = outputs - if node.inputs[0].type.dtype in ("complex64", "complex128"): - raise NotImplementedError("Complex types are not supported for a_val") - if node.inputs[4].type.dtype in ("complex64", "complex128"): - raise NotImplementedError("Complex types are not supported for b") - - typenum_z = node.outputs[0].type.dtype_specs()[2] # retrieve dtype number - typenum_a_val = node.inputs[0].type.dtype_specs()[2] # retrieve dtype number - typenum_b = node.inputs[4].type.dtype_specs()[2] # retrieve dtype number - - rval = """ - - if (PyArray_NDIM(%(a_val)s) != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(a_val) != 1"); %(fail)s;} - if (PyArray_NDIM(%(a_ind)s) != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(a_ind) != 1"); %(fail)s;} - if (PyArray_NDIM(%(a_ptr)s) != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(a_ptr) != 1"); %(fail)s;} - if (PyArray_NDIM(%(a_nrows)s) != 0) {PyErr_SetString(PyExc_NotImplementedError, "rank(nrows) != 0"); %(fail)s;} - if (PyArray_NDIM(%(b)s) != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 2"); %(fail)s;} - - if (PyArray_TYPE(%(a_val)s) != %(typenum_a_val)s) { - PyErr_SetString(PyExc_NotImplementedError, "Invalid type for a_val"); %(fail)s;} - - if (PyArray_TYPE(%(b)s) != %(typenum_b)s) { - PyErr_SetString(PyExc_NotImplementedError, "Invalid type for b"); %(fail)s;} - - if (PyArray_TYPE(%(a_ind)s) != NPY_INT32) { - PyErr_SetString(PyExc_NotImplementedError, "a_ind dtype not INT32"); %(fail)s;} - - if (PyArray_TYPE(%(a_ptr)s) != NPY_INT32) - {PyErr_SetString(PyExc_NotImplementedError, "a_ptr dtype not INT32"); %(fail)s;} - - if (PyArray_TYPE(%(a_nrows)s) != NPY_INT32) - {PyErr_SetString(PyExc_NotImplementedError, "a_nrows dtype not INT32"); %(fail)s;} - - if (PyArray_DIMS(%(a_val)s)[0] != PyArray_DIMS(%(a_ind)s)[0]) - {PyErr_SetString(PyExc_NotImplementedError, "a_val and a_ind have different lengths"); %(fail)s;} - - if (PyArray_DIMS(%(a_ptr)s)[0] != PyArray_DIMS(%(b)s)[0]+1) - {PyErr_SetString(PyExc_NotImplementedError, "a's number of columns doesn't match b's rows"); %(fail)s;} - - if ((!%(z)s) - || (PyArray_DIMS(%(z)s)[0] != ((npy_int32 *)PyArray_DATA(%(a_nrows)s))[0]) - || (PyArray_DIMS(%(z)s)[1] != PyArray_DIMS(%(b)s)[1]) - ) - { - {Py_XDECREF(%(z)s);} - npy_intp dims[] = {0, 0}; - dims[0] = ((npy_int32 *)PyArray_DATA(%(a_nrows)s))[0]; - dims[1] = PyArray_DIMS(%(b)s)[1]; - %(z)s = (PyArrayObject*) PyArray_SimpleNew(2, dims, %(typenum_z)s); - } - - { - // sparse array has size MxK, dense KxN, output MxN - npy_intp M = PyArray_DIMS(%(z)s)[0]; - npy_intp N = PyArray_DIMS(%(z)s)[1]; - npy_intp K = PyArray_DIMS(%(b)s)[0]; - if (N > 0x7fffffffL) - {PyErr_SetString(PyExc_NotImplementedError, "array too big (overflows int32 index)"); %(fail)s;} - - // strides tell you how many bytes to skip to go to next column/row entry - npy_intp Szm = PyArray_STRIDES(%(z)s)[0] / PyArray_DESCR(%(z)s)->elsize; - npy_intp Szn = PyArray_STRIDES(%(z)s)[1] / PyArray_DESCR(%(z)s)->elsize; - //npy_intp Sbm = PyArray_STRIDES(%(b)s)[0] / PyArray_DESCR(%(b)s)->elsize; - npy_intp Sbn = PyArray_STRIDES(%(b)s)[1] / PyArray_DESCR(%(b)s)->elsize; - npy_intp Sval = PyArray_STRIDES(%(a_val)s)[0] / PyArray_DESCR(%(a_val)s)->elsize; - npy_intp Sind = PyArray_STRIDES(%(a_ind)s)[0] / PyArray_DESCR(%(a_ind)s)->elsize; - npy_intp Sptr = PyArray_STRIDES(%(a_ptr)s)[0] / PyArray_DESCR(%(a_ptr)s)->elsize; - - // pointers to access actual data in the arrays passed as params. - dtype_%(z)s* __restrict__ Dz = (dtype_%(z)s*)PyArray_DATA(%(z)s); - const dtype_%(a_val)s* __restrict__ Dval = (dtype_%(a_val)s*)PyArray_DATA(%(a_val)s); - const npy_int32 * __restrict__ Dind = (npy_int32*)PyArray_DATA(%(a_ind)s); - const npy_int32 * __restrict__ Dptr = (npy_int32*)PyArray_DATA(%(a_ptr)s); - - //npy_intp nnz = PyArray_DIMS(%(a_ind)s)[0]; - - //clear the output array - memset(Dz, 0, M*N*sizeof(dtype_%(z)s)); - - //iterate over the sparse array, making the most of an entry wherever we find it. - // - // Normal matrix matrix multiply: A MxK, B KxN => Z = AB - // for m - // for n - // for k - // z[m, n] += a[m, k] * b[k, n] - // Here instead: Z = - // for k - // for m (sparse) - // for n - // z[m, n] += a[m, k] * b[k, n] - - // loop over inner dimension - for (npy_int32 k = 0; k < K; ++k) - { - // get pointer to k-th row of dense matrix - const dtype_%(b)s* __restrict__ bk = (dtype_%(b)s*)(PyArray_BYTES(%(b)s) + PyArray_STRIDES(%(b)s)[0] * k); - - // loop over sparse column indices through index pointer array - // (amounts to looping over rows M of sparse matrix) - - for (npy_int32 m_idx = Dptr[k * Sptr]; m_idx < Dptr[(k+1) * Sptr]; ++m_idx) - { - npy_int32 m = Dind[m_idx * Sind]; // row index of non-null value for column K - const dtype_%(a_val)s Amk = Dval[m_idx * Sval]; // actual value at that location - - // pointer to m-th row of the output matrix Z - dtype_%(z)s* __restrict__ zm = (dtype_%(z)s*)(PyArray_BYTES(%(z)s) + PyArray_STRIDES(%(z)s)[0] * m); - - //RESOLVE: a.shape[0] equals z.shape[0], why is this not an equality constraint? - if (m >= PyArray_DIMS(%(z)s)[0]) - {PyErr_SetString(PyExc_NotImplementedError, "illegal row index in a"); %(fail)s;} - - // loop over final dimension (cols of dense matrix) and perform dot product - if ((Szn == 1) && (Sbn == 1)) { - for(npy_int32 n = 0; n < N; ++n) - { - zm[n] += Amk * bk[n]; - } - } - else - { - for(npy_int32 n = 0; n < N; ++n) - { - zm[n*Szn] += Amk * bk[n*Sbn]; - } - } - } - } - } - """ % dict( - locals(), **sub - ) - - return rval - - def c_code_cache_version(self): - return (3,) - - -sd_csc = StructuredDotCSC() - - -class StructuredDotCSR(COp): - """ - Structured Dot CSR is like dot, except that only the - gradient wrt non-zero elements of the sparse matrix - `a` are calculated and propagated. - - The output is presumed to be a dense matrix, and is represented by a - TensorType instance. - - Parameters - ---------- - a - A sparse matrix in csr format. - b - A sparse or dense matrix. - - Returns - ------- - matrix - The dot product of `a` and `b`. - - Notes - ----- - The grad implemented is structured. - This op is used as an optimization for StructuredDot. - - """ - - __props__ = () - - def make_node(self, a_val, a_ind, a_ptr, b): - self.dtype_out = aes.upcast(a_val.type.dtype, b.type.dtype) - r = Apply( - self, - [a_val, a_ind, a_ptr, b], - [tensor(self.dtype_out, (False, b.type.broadcastable[1]))], - ) - return r - - def perform(self, node, inputs, outputs): - (a_val, a_ind, a_ptr, b) = inputs - (out,) = outputs - a = scipy.sparse.csr_matrix( - (a_val, a_ind, a_ptr), (len(a_ptr) - 1, b.shape[0]), copy=True - ) # use view_map before setting this to False - # out[0] = a.dot(b) - out[0] = a * b - # scipy 0.7 automatically converts to dense, but not .6 sometimes - assert _is_dense(out[0]) - - def c_code(self, node, name, inputs, outputs, sub): - """ - C-implementation of the dot product of the sparse matrix A and matrix B. - - Parameters - ---------- - a_val - Non-zero values of the sparse matrix. - a_ind - Column indices of the non-null values (.indices of a - scipy.csc_matrix). - a_ptr - Indicates col indices for col. i are in the range - a_ptr[i]:a_ptr[i+1]. - n_cols - Number of columns of sparse matrix. - b - Dense matrix to perform dot product with, as in dot(a, b). - z - Return value. - sub - TODO, not too sure, something to do with weave probably. - - """ - (a_val, a_ind, a_ptr, b) = inputs - (z,) = outputs - typenum_z = TensorType(self.dtype_out, []).dtype_specs()[2] - if node.inputs[0].type.dtype in ("complex64", "complex128"): - raise NotImplementedError("Complex types are not supported for a_val") - if node.inputs[3].type.dtype in ("complex64", "complex128"): - raise NotImplementedError("Complex types are not supported for b") - - return """ - if (PyArray_NDIM(%(a_val)s) != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(a_val) != 1"); %(fail)s;} - if (PyArray_NDIM(%(a_ind)s) != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(a_ind) != 1"); %(fail)s;} - if (PyArray_NDIM(%(a_ptr)s) != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(a_ptr) != 1"); %(fail)s;} - if (PyArray_NDIM(%(b)s) != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 2"); %(fail)s;} - - if (PyArray_TYPE(%(a_ind)s) != NPY_INT32) { - PyErr_SetString(PyExc_NotImplementedError, "a_ind dtype not INT32"); %(fail)s;} - - if (PyArray_TYPE(%(a_ptr)s) != NPY_INT32) - {PyErr_SetString(PyExc_NotImplementedError, "a_ptr dtype not INT32"); %(fail)s;} - - if (PyArray_DIMS(%(a_val)s)[0] != PyArray_DIMS(%(a_ind)s)[0]) - {PyErr_SetString(PyExc_NotImplementedError, "a_val and a_ind have different lengths"); %(fail)s;} - - if ((!%(z)s) - || (PyArray_DIMS(%(z)s)[0] != PyArray_DIMS(%(a_ptr)s)[0]-1) //a's rows - || (PyArray_DIMS(%(z)s)[1] != PyArray_DIMS(%(b)s)[1]) //b's columns - ) - { - {Py_XDECREF(%(z)s);} - npy_intp dims[] = {0, 0}; - dims[0] = PyArray_DIMS(%(a_ptr)s)[0]-1; - dims[1] = PyArray_DIMS(%(b)s)[1]; - %(z)s = (PyArrayObject*) PyArray_SimpleNew(2, dims, %(typenum_z)s); - } - - { - // sparse array has size MxK, dense KxN, output MxN - npy_intp M = PyArray_DIMS(%(z)s)[0]; - npy_intp N = PyArray_DIMS(%(z)s)[1]; - npy_intp K = PyArray_DIMS(%(b)s)[0]; - if (N > 0x7fffffffL) - {PyErr_SetString(PyExc_NotImplementedError, "array too big (overflows int32 index)"); %(fail)s;} - - // strides tell you how many bytes to skip to go to next column/row entry - npy_intp Szm = PyArray_STRIDES(%(z)s)[0] / PyArray_DESCR(%(z)s)->elsize; - npy_intp Szn = PyArray_STRIDES(%(z)s)[1] / PyArray_DESCR(%(z)s)->elsize; - npy_intp Sbm = PyArray_STRIDES(%(b)s)[0] / PyArray_DESCR(%(b)s)->elsize; - npy_intp Sbn = PyArray_STRIDES(%(b)s)[1] / PyArray_DESCR(%(b)s)->elsize; - npy_intp Sval = PyArray_STRIDES(%(a_val)s)[0] / PyArray_DESCR(%(a_val)s)->elsize; - npy_intp Sind = PyArray_STRIDES(%(a_ind)s)[0] / PyArray_DESCR(%(a_ind)s)->elsize; - npy_intp Sptr = PyArray_STRIDES(%(a_ptr)s)[0] / PyArray_DESCR(%(a_ptr)s)->elsize; - - // pointers to access actual data in the arrays passed as params. - dtype_%(z)s* __restrict__ Dz = (dtype_%(z)s*)PyArray_DATA(%(z)s); - const dtype_%(a_val)s* __restrict__ Dval = (dtype_%(a_val)s*)PyArray_DATA(%(a_val)s); - const npy_int32 * __restrict__ Dind = (npy_int32*)PyArray_DATA(%(a_ind)s); - const npy_int32 * __restrict__ Dptr = (npy_int32*)PyArray_DATA(%(a_ptr)s); - - //npy_intp nnz = PyArray_DIMS(%(a_ind)s)[0]; - - //clear the output array - memset(Dz, 0, M*N*sizeof(dtype_%(z)s)); - - //iterate over the sparse array, making the most of an entry wherever we find it. - // Normal matrix matrix multiply: - // for m - // for n - // for k - // z[m, n] += a[m, k] * b[k, n] - // Here instead: - // for m - // for k (sparse) - // for n - // z[m, n] += a[m, k] * b[k, n] - - // loop over inner dimension - for (npy_int64 m = 0; m < M; ++m) - { - // pointer to m-th row of the output matrix Z - dtype_%(z)s* __restrict__ zm = (dtype_%(z)s*)(PyArray_BYTES(%(z)s) + PyArray_STRIDES(%(z)s)[0] * m); - - // loop over sparse rows indices through index pointer array - // (amounts to looping over cols k of sparse matrix) - for (npy_int32 k_idx = Dptr[m * Sptr]; k_idx < Dptr[(m+1) * Sptr]; ++k_idx) - { - npy_int32 k = Dind[k_idx * Sind]; // col index of non-null value for row m - const dtype_%(a_val)s Amk = Dval[k_idx * Sval]; // actual value at that location - - // get pointer to k-th row of dense matrix - const dtype_%(b)s* __restrict__ bk = (dtype_%(b)s*)(PyArray_BYTES(%(b)s) + PyArray_STRIDES(%(b)s)[0] * k); - - // loop over final dimension (cols of dense matrix) and perform dot product - for(npy_int32 n = 0; n < N; ++n) - { - zm[n*Szn] += Amk * bk[n*Sbn]; - } - } - } - } - - """ % dict( - locals(), **sub - ) - - def c_code_cache_version(self): - return (2,) - - -sd_csr = StructuredDotCSR() - - -# register a specialization to replace StructuredDot -> StructuredDotCSx -# This is tested in tests/test_basic.py:792 -@local_optimizer([sparse._structured_dot]) -def local_structured_dot(fgraph, node): - if node.op == sparse._structured_dot: - a, b = node.inputs - if a.type.format == "csc": - a_val, a_ind, a_ptr, a_shape = csm_properties(a) - a_nsparse = a_shape[0] - return [sd_csc(a_val, a_ind, a_ptr, a_nsparse, b)] - if a.type.format == "csr": - a_val, a_ind, a_ptr, a_shape = csm_properties(a) - return [sd_csr(a_val, a_ind, a_ptr, b)] - return False - - -# Commented out because -# a) it is only slightly faster than scipy these days, and sometimes a little -# slower, and -# b) the resulting graphs make it very difficult for an op to do size checking -# on the matrices involved. dimension mismatches are hard to detect sensibly. -# register_specialize(local_structured_dot) - - -class UsmmCscDense(_NoPythonCOp): - """ - Performs the expression is `alpha` * `x` `y` + `z`. - - Parameters - ---------- - x - Matrix variable. - y - Matrix variable. - z - Dense matrix. - alpha - A tensor scalar. - - Returns - ------- - The dense matrix resulting from `alpha` * `x` `y` + `z`. - - Notes - ----- - The grad is not implemented for this op. - Optimized version os Usmm when `x` is in csc format and `y` is dense. - """ - - __props__ = ("inplace",) - - def __init__(self, inplace): - self.inplace = inplace - if inplace: - self.destroy_map = {0: [6]} - - def __str__(self): - if self.inplace: - return "UsmmCscDense{inplace}" - else: - return "UsmmCscDense{no_inplace}" - - def make_node(self, alpha, x_val, x_ind, x_ptr, x_nrows, y, z): - alpha = as_tensor_variable(alpha) - x_val = as_tensor_variable(x_val) - x_ind = as_tensor_variable(x_ind) - x_ptr = as_tensor_variable(x_ptr) - x_nrows = as_tensor_variable(x_nrows) - y = as_tensor_variable(y) - z = as_tensor_variable(z) - assert x_ind.dtype == "int32" - assert x_ptr.dtype == "int32" - assert x_nrows.dtype == "int32" - assert alpha.ndim == 2 and alpha.type.broadcastable == (True, True) - assert x_val.ndim == 1 - assert y.ndim == 2 - assert z.ndim == 2 - - dtype_out = aes.upcast( - alpha.type.dtype, x_val.type.dtype, y.type.dtype, z.type.dtype - ) - - if dtype_out not in ("float32", "float64"): - raise NotImplementedError("only float types are supported in " "operands") - - if self.inplace: - assert z.type.dtype == dtype_out - - # axpy work only with the same dtype, so we should upcast the input - if dtype_out != alpha.type.dtype: - alpha = cast(alpha, dtype_out) - if dtype_out != x_val.type.dtype: - x_val = cast(x_val, dtype_out) - if dtype_out != y.type.dtype: - y = cast(y, dtype_out) - if dtype_out != z.type.dtype: - z = cast(z, dtype_out) - - r = Apply( - self, - [alpha, x_val, x_ind, x_ptr, x_nrows, y, z], - [tensor(dtype_out, (False, y.type.broadcastable[1]))], - ) - return r - - def c_support_code(self, **kwargs): - return blas.blas_header_text() - - def c_libraries(self, **kwargs): - return blas.ldflags() - - def c_compile_args(self, **kwargs): - return blas.ldflags(libs=False, flags=True) - - def c_lib_dirs(self, **kwargs): - return blas.ldflags(libs=False, libs_dir=True) - - def c_header_dirs(self, **kwargs): - return blas.ldflags(libs=False, include_dir=True) - - def c_code(self, node, name, inputs, outputs, sub): - alpha, x_val, x_ind, x_ptr, x_nrows, y, z = inputs - zn = outputs[0] - if node.inputs[1].type.dtype in ("complex64", "complex128"): - raise NotImplementedError("Complex types are not supported for " "x_val") - if node.inputs[5].type.dtype in ("complex64", "complex128"): - raise NotImplementedError("Complex types are not supported for y") - if node.inputs[6].type.dtype != node.outputs[0].type.dtype: - raise NotImplementedError("z and output must have same type") - - if node.inputs[1].type.dtype == "float32": - conv_type = "float" - axpy = "saxpy_" - else: - conv_type = "double" - axpy = "daxpy_" - # retrieve dtype numbers - typenum_alpha = node.inputs[0].type.dtype_specs()[2] - typenum_x_val = node.inputs[1].type.dtype_specs()[2] - typenum_y = node.inputs[5].type.dtype_specs()[2] - typenum_z = node.inputs[6].type.dtype_specs()[2] - typenum_zn = node.outputs[0].type.dtype_specs()[2] - - inplace = int(self.inplace) - - rval = """ - - if (PyArray_NDIM(%(x_val)s) != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(x_val) != 1"); %(fail)s;} - if (PyArray_NDIM(%(x_ind)s) != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(x_ind) != 1"); %(fail)s;} - if (PyArray_NDIM(%(x_ptr)s) != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(x_ptr) != 1"); %(fail)s;} - if (PyArray_NDIM(%(x_nrows)s) != 0) {PyErr_SetString(PyExc_NotImplementedError, "rank(nrows) != 0"); %(fail)s;} - if (PyArray_NDIM(%(y)s) != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(y) != 2"); %(fail)s;} - - if (PyArray_TYPE(%(x_val)s) != %(typenum_x_val)s) { - PyErr_SetString(PyExc_NotImplementedError, "Invalid type for x_val"); %(fail)s;} - - if (PyArray_TYPE(%(y)s) != %(typenum_y)s) { - PyErr_SetString(PyExc_NotImplementedError, "Invalid type for y"); %(fail)s;} - - if (PyArray_TYPE(%(z)s) != %(typenum_z)s) { - PyErr_SetString(PyExc_NotImplementedError, "Invalid type for z"); %(fail)s;} - - if (PyArray_TYPE(%(alpha)s) != %(typenum_alpha)s) { - PyErr_SetString(PyExc_NotImplementedError, "Invalid type for alpha"); %(fail)s;} - - if (PyArray_TYPE(%(x_ind)s) != NPY_INT32) { - PyErr_SetString(PyExc_NotImplementedError, "x_ind dtype not INT32"); %(fail)s;} - - if (PyArray_TYPE(%(x_ptr)s) != NPY_INT32) - {PyErr_SetString(PyExc_NotImplementedError, "x_ptr dtype not INT32"); %(fail)s;} - - if (PyArray_TYPE(%(x_nrows)s) != NPY_INT32) - {PyErr_SetString(PyExc_NotImplementedError, "x_nrows dtype not INT32"); %(fail)s;} - - if (PyArray_DIMS(%(x_val)s)[0] != PyArray_DIMS(%(x_ind)s)[0]) - {PyErr_SetString(PyExc_NotImplementedError, "x_val and x_ind have different lengths"); %(fail)s;} - - if (PyArray_DIMS(%(x_ptr)s)[0] != PyArray_DIMS(%(y)s)[0]+1) - {PyErr_SetString(PyExc_NotImplementedError, "x's number of columns doesn't match y's rows"); %(fail)s;} - - if (PyArray_DIMS(%(z)s)[0] != ((npy_int32 *)PyArray_DATA(%(x_nrows)s))[0] || PyArray_DIMS(%(z)s)[1] != PyArray_DIMS(%(y)s)[1]) - {PyErr_SetString(PyExc_NotImplementedError, "The dimension of the allocated output doesn't match the correct output size."); %(fail)s;} - - if (PyArray_SIZE(%(alpha)s) != 1) - {PyErr_SetString(PyExc_NotImplementedError, "The number of element in alpha must be 1"); %(fail)s;} - - if (PyArray_NDIM(%(alpha)s) != 2) - {PyErr_SetString(PyExc_NotImplementedError, "The number dimension of alpha must be 2"); %(fail)s;} - - if (PyArray_NDIM(%(x_val)s) != 1) - {PyErr_SetString(PyExc_NotImplementedError, "The number dimension of x_val must be 1"); %(fail)s;} - - if (PyArray_NDIM(%(y)s) != 2) - {PyErr_SetString(PyExc_NotImplementedError, "The number dimension of y must be 2"); %(fail)s;} - - if (PyArray_NDIM(%(z)s) != 2) - {PyErr_SetString(PyExc_NotImplementedError, "The number dimension of z must be 2"); %(fail)s;} - - if (%(inplace)s) - { - if (%(typenum_zn)s != %(typenum_z)s) { - PyErr_SetString(PyExc_NotImplementedError, "When inplace the output dtype must be the same as the input"); %(fail)s;} - - Py_XDECREF(%(zn)s); - %(zn)s = %(z)s; - Py_INCREF(%(zn)s); - } - else if (!%(zn)s - || (PyArray_DIMS(%(zn)s)[0] != ((npy_int32 *)PyArray_DATA(%(x_nrows)s))[0]) - || (PyArray_DIMS(%(zn)s)[1] != PyArray_DIMS(%(y)s)[1]) - ) - { - {Py_XDECREF(%(zn)s);} - npy_intp dims[] = {0, 0}; - dims[0] = ((npy_int32 *)PyArray_DATA(%(x_nrows)s))[0]; - dims[1] = PyArray_DIMS(%(y)s)[1]; - %(zn)s = (PyArrayObject*) PyArray_SimpleNew(2, dims, %(typenum_zn)s); - } - - { - // sparse array has size MxK, dense KxN, output MxN - npy_intp M = PyArray_DIMS(%(zn)s)[0]; - npy_intp N = PyArray_DIMS(%(zn)s)[1]; - npy_intp K = PyArray_DIMS(%(y)s)[0]; - - // pointers to access actual data in the arrays passed as params. - const dtype_%(x_val)s* __restrict__ Dval = (dtype_%(x_val)s*)PyArray_DATA(%(x_val)s); - const npy_int32 * __restrict__ Dind = (npy_int32*)PyArray_DATA(%(x_ind)s); - const npy_int32 * __restrict__ Dptr = (npy_int32*)PyArray_DATA(%(x_ptr)s); - const dtype_%(alpha)s alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0]; - - npy_intp Sz = PyArray_STRIDES(%(z)s)[1] / PyArray_DESCR(%(z)s)->elsize; - npy_intp Szn = PyArray_STRIDES(%(zn)s)[1] / PyArray_DESCR(%(zn)s)->elsize; - npy_intp Sval = PyArray_STRIDES(%(x_val)s)[0] / PyArray_DESCR(%(x_val)s)->elsize; - npy_intp Sind = PyArray_STRIDES(%(x_ind)s)[0] / PyArray_DESCR(%(x_ind)s)->elsize; - npy_intp Sptr = PyArray_STRIDES(%(x_ptr)s)[0] / PyArray_DESCR(%(x_ptr)s)->elsize; - npy_intp Sy = PyArray_STRIDES(%(y)s)[1] / PyArray_DESCR(%(y)s)->elsize; - - // blas expects ints; convert here (rather than just making N etc ints) to avoid potential overflow in the negative-stride correction - if ((N > 0x7fffffffL)||(Sy > 0x7fffffffL)||(Szn > 0x7fffffffL)||(Sy < -0x7fffffffL)||(Szn < -0x7fffffffL)) - {PyErr_SetString(PyExc_NotImplementedError, "array too big for BLAS (overflows int32 index)"); %(fail)s;} - int N32 = N; - int Sy32 = Sy; - int Szn32 = Szn; - - if (!(%(inplace)s)) - { - if (PyArray_CopyInto(%(zn)s, %(z)s)) - { - Py_XDECREF(%(zn)s); - %(fail)s; - } - } - - for (npy_intp k = 0; k < K; ++k) - { - for (npy_int32 m_idx = Dptr[k * Sptr]; m_idx < Dptr[(k+1)*Sptr]; ++m_idx) - { - const npy_int32 m = Dind[m_idx * Sind]; // row index of non-null value for column K - - const dtype_%(x_val)s Amk = alpha * Dval[m_idx * Sval]; // actual value at that location - - dtype_%(y)s* y_row = (dtype_%(y)s*)(PyArray_BYTES(%(y)s) + PyArray_STRIDES(%(y)s)[0] * k); - // axpy expects pointer to the beginning of memory arrays, - // so when the stride is negative, we need to get the - // last element - if (Sy < 0) - y_row += (K - 1) * Sy; - - dtype_%(zn)s* z_row = (dtype_%(zn)s*)(PyArray_BYTES(%(zn)s) + PyArray_STRIDES(%(zn)s)[0] * m); - if (Szn < 0) - z_row += (N - 1) * Szn; - - %(axpy)s(&N32, (%(conv_type)s*)&Amk, (%(conv_type)s*)y_row, &Sy32, (%(conv_type)s*)z_row, &Szn32); - } - } - } - """ % dict( - locals(), **sub - ) - - return rval - - def c_code_cache_version(self): - return (3, blas.blas_header_version()) - - -usmm_csc_dense = UsmmCscDense(inplace=False) -usmm_csc_dense_inplace = UsmmCscDense(inplace=True) - - -# This is tested in tests/test_basic.py:UsmmTests -local_usmm = PatternSub( - ( - sub, - "z", - ( - mul, - { - "pattern": "alpha", - "constraint": lambda expr: ( - all(expr.type.broadcastable) and config.blas__ldflags - ), - }, - (sparse._dot, "x", "y"), - ), - ), - (usmm, (neg, "alpha"), "x", "y", "z"), -) -register_specialize(local_usmm, name="local_usmm") - - -# register a specialization to replace usmm_csc_dense -> usmm_csc_dense_inplace -# This is tested in tests/test_basic.py:UsmmTests -@local_optimizer([usmm_csc_dense]) -def local_usmm_csc_dense_inplace(fgraph, node): - if node.op == usmm_csc_dense: - return [usmm_csc_dense_inplace(*node.inputs)] - - -register_specialize(local_usmm_csc_dense_inplace, "cxx_only", "inplace") - - -# This is tested in tests/test_basic.py:UsmmTests -@local_optimizer([usmm]) -def local_usmm_csx(fgraph, node): - """ - usmm -> usmm_csc_dense - - """ - if node.op == usmm: - alpha, x, y, z = node.inputs - - x_is_sparse_variable = _is_sparse_variable(x) - y_is_sparse_variable = _is_sparse_variable(y) - - if x_is_sparse_variable and not y_is_sparse_variable: - if x.type.format == "csc": - x_val, x_ind, x_ptr, x_shape = csm_properties(x) - x_nsparse = x_shape[0] - dtype_out = aes.upcast( - alpha.type.dtype, x.type.dtype, y.type.dtype, z.type.dtype - ) - if dtype_out not in ("float32", "float64"): - return False - # Sparse cast is not implemented. - if y.type.dtype != dtype_out: - return False - - return [usmm_csc_dense(alpha, x_val, x_ind, x_ptr, x_nsparse, y, z)] - return False - - -register_specialize(local_usmm_csx, "cxx_only") - - -class CSMGradC(_NoPythonCOp): - - __props__ = () - - def make_node(self, a_val, a_ind, a_ptr, a_dim, b_val, b_ind, b_ptr, b_dim): - return Apply( - self, - [a_val, a_ind, a_ptr, a_dim, b_val, b_ind, b_ptr, b_dim], - [b_val.type()], - ) - - def c_code(self, node, name, inputs, outputs, sub): - # retrieve dtype number - (a_val, a_ind, a_ptr, a_dim, b_val, b_ind, b_ptr, b_dim) = inputs - (z,) = outputs - typenum_z = node.outputs[0].type.dtype_specs()[2] - if node.inputs[0].type.dtype in ("complex64", "complex128"): - raise NotImplementedError("Complex types are not supported for a_val") - if node.inputs[3].type.dtype in ("complex64", "complex128"): - raise NotImplementedError("Complex types are not supported for b_val") - - return """ - if (PyArray_NDIM(%(a_val)s) != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(a_val) != 1"); %(fail)s;} - if (PyArray_NDIM(%(a_ind)s) != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(a_ind) != 1"); %(fail)s;} - if (PyArray_NDIM(%(a_ptr)s) != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(a_ptr) != 1"); %(fail)s;} - if (PyArray_NDIM(%(b_val)s) != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(b_val) != 1"); %(fail)s;} - if (PyArray_NDIM(%(b_ind)s) != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(b_ind) != 1"); %(fail)s;} - if (PyArray_NDIM(%(b_ptr)s) != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(b_ptr) != 1"); %(fail)s;} - - if (PyArray_TYPE(%(a_ind)s) != NPY_INT32) { - PyErr_SetString(PyExc_NotImplementedError, "a_ind dtype not INT32"); %(fail)s;} - - if (PyArray_TYPE(%(a_ptr)s) != NPY_INT32) - {PyErr_SetString(PyExc_NotImplementedError, "a_ptr dtype not INT32"); %(fail)s;} - - if (PyArray_TYPE(%(b_ind)s) != NPY_INT32) { - PyErr_SetString(PyExc_NotImplementedError, "b_ind dtype not INT32"); %(fail)s;} - - if (PyArray_TYPE(%(b_ptr)s) != NPY_INT32) - {PyErr_SetString(PyExc_NotImplementedError, "b_ptr dtype not INT32"); %(fail)s;} - - if (PyArray_DIMS(%(a_val)s)[0] != PyArray_DIMS(%(a_ind)s)[0]) - {PyErr_SetString(PyExc_NotImplementedError, "a_val and a_ind have different lengths"); %(fail)s;} - - if (PyArray_DIMS(%(b_val)s)[0] != PyArray_DIMS(%(b_ind)s)[0]) - {PyErr_SetString(PyExc_NotImplementedError, "b_val and b_ind have different lengths"); %(fail)s;} - - if (PyArray_DIMS(%(a_ptr)s)[0] != PyArray_DIMS(%(b_ptr)s)[0]) - {PyErr_SetString(PyExc_NotImplementedError, "a_ptr and b_ptr have different lengths"); %(fail)s;} - - if ((!%(z)s) || (PyArray_DIMS(%(z)s)[0] != PyArray_DIMS(%(a_val)s)[0])) - { - {Py_XDECREF(%(z)s);} - npy_intp dims[] = {0}; - dims[0] = PyArray_DIMS(%(a_val)s)[0]; - %(z)s = (PyArrayObject*) PyArray_SimpleNew(1, dims, %(typenum_z)s); - } - - { - // sparse array has size MxK, dense KxN, output MxN - npy_intp M = PyArray_DIMS(%(a_ptr)s)[0] - 1; - npy_intp a_dim_0 = ((npy_int32 *)PyArray_DATA(%(a_dim)s))[0]; - npy_intp a_dim_1 = ((npy_int32 *)PyArray_DATA(%(a_dim)s))[1]; - - npy_intp sp_dim = (M == a_dim_0)?a_dim_1:a_dim_0; - - // strides tell you how many bytes to skip to go to next column/row entry - npy_intp Sz = PyArray_STRIDES(%(z)s)[0] / PyArray_DESCR(%(z)s)->elsize; - npy_intp Sa_val = PyArray_STRIDES(%(a_val)s)[0] / PyArray_DESCR(%(a_val)s)->elsize; - npy_intp Sa_ind = PyArray_STRIDES(%(a_ind)s)[0] / PyArray_DESCR(%(a_ind)s)->elsize; - npy_intp Sa_ptr = PyArray_STRIDES(%(a_ptr)s)[0] / PyArray_DESCR(%(a_ptr)s)->elsize; - npy_intp Sb_val = PyArray_STRIDES(%(b_val)s)[0] / PyArray_DESCR(%(b_val)s)->elsize; - npy_intp Sb_ind = PyArray_STRIDES(%(b_ind)s)[0] / PyArray_DESCR(%(b_ind)s)->elsize; - npy_intp Sb_ptr = PyArray_STRIDES(%(b_ptr)s)[0] / PyArray_DESCR(%(b_ptr)s)->elsize; - - // pointers to access actual data in the arrays passed as params. - dtype_%(z)s* __restrict__ Dz = (dtype_%(z)s*)PyArray_DATA(%(z)s); - const dtype_%(a_val)s* __restrict__ Da_val = (dtype_%(a_val)s*)PyArray_DATA(%(a_val)s); - const npy_int32 * __restrict__ Da_ind = (npy_int32*)PyArray_DATA(%(a_ind)s); - const npy_int32 * __restrict__ Da_ptr = (npy_int32*)PyArray_DATA(%(a_ptr)s); - const dtype_%(b_val)s* __restrict__ Db_val = (dtype_%(b_val)s*)PyArray_DATA(%(b_val)s); - const npy_int32 * __restrict__ Db_ind = (npy_int32*)PyArray_DATA(%(b_ind)s); - const npy_int32 * __restrict__ Db_ptr = (npy_int32*)PyArray_DATA(%(b_ptr)s); - - npy_intp nnz = PyArray_DIMS(%(a_ind)s)[0]; - - dtype_%(b_val)s b_row[sp_dim]; - - //clear the output array - for (npy_int64 i = 0; i < nnz; ++i) - { - Dz[i*Sz] = 0; - } - memset(b_row, 0, sp_dim*sizeof(dtype_%(b_val)s)); - - // loop over inner dimension - for (npy_int64 m = 0; m < M; ++m) - { - for (npy_int32 j_ptr = Db_ptr[m * Sb_ptr]; - j_ptr < Db_ptr[(m + 1) * Sb_ptr]; j_ptr++) { - b_row[Db_ind[j_ptr * Sb_ind]] += Db_val[j_ptr*Sb_val]; - } - - for (npy_int32 j_ptr = Da_ptr[m * Sa_ptr]; - j_ptr < Da_ptr[(m + 1) * Sa_ptr]; j_ptr++) { - Dz[j_ptr*Sz] = b_row[Da_ind[j_ptr * Sa_ind]]; - } - - for (npy_int32 j_ptr = Db_ptr[m * Sb_ptr]; - j_ptr < Db_ptr[(m + 1) * Sb_ptr]; j_ptr++) { - b_row[Db_ind[j_ptr * Sb_ind]] = 0; - } - } - } - - """ % dict( - locals(), **sub - ) - - def c_code_cache_version(self): - return (3,) - - -csm_grad_c = CSMGradC() - - -# register a specialization to replace csm_grad -> csm_grad_c -# This is tested in tests/test_opt.py:test_local_csm_grad_c -@local_optimizer([csm_grad(None)]) -def local_csm_grad_c(fgraph, node): - """ - csm_grad(None) -> csm_grad_c - - """ - if node.op == csm_grad(None): - return [csm_grad_c(*node.inputs)] - return False - - -# DISABLED AS IT IS BROKEN FOR UNSORTED INDICES! -# register_specialize(local_csm_grad_c, 'cxx_only') - - -class MulSDCSC(_NoPythonCOp): - """ - Multiplication of sparse matrix by a broadcasted dense vector - element wise. - - Parameters - ---------- - a_data - Sparse matrix data. - a_indices - Sparse matrix indices. - a_indptr - Sparse matrix indptr. - b - Tensor type matrix. - - Returns - ------- - The multiplication of the two matrices element-wise. - - Notes - ----- - `a_data`, `a_indices` and `a_indptr` must be the properties of a sparse - matrix in csc format. - - The dtype of `a_data`, i.e. the dtype of the sparse matrix, cannot be a - complex type. - - This op is used as an optimization of mul_s_d. - - """ - - __props__ = () - - def make_node(self, a_data, a_indices, a_indptr, b): - assert b.type.ndim == 2 - return Apply( - self, [a_data, a_indices, a_indptr, b], [tensor(b.dtype, (False,))] - ) - - def c_code_cache_version(self): - return (3,) - - def c_code(self, node, name, inputs, outputs, sub): - - ( - _data, - _indices, - _indptr, - _b, - ) = inputs - (_zout,) = outputs - if node.inputs[0].type.dtype in ("complex64", "complex128"): - raise NotImplementedError("Complex types are not supported for a") - if node.inputs[3].type.dtype in ("complex64", "complex128"): - raise NotImplementedError("Complex types are not supported for b") - - return """ - if (PyArray_NDIM(%(_b)s) != 2) { - PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 2"); - %(fail)s;} - if (PyArray_NDIM(%(_data)s) != 1) { - PyErr_SetString(PyExc_NotImplementedError, "rank(data) != 1"); - %(fail)s;} - if (PyArray_NDIM(%(_indices)s) != 1) { - PyErr_SetString(PyExc_NotImplementedError, "rank(indices) != 1"); - %(fail)s;} - if (PyArray_NDIM(%(_indptr)s) != 1) { - PyErr_SetString(PyExc_NotImplementedError, "rank(indptr) != 1"); - %(fail)s;} - - if( PyArray_TYPE(%(_indices)s) != NPY_INT32) { - PyErr_SetString(PyExc_NotImplementedError, "C"); %(fail)s;} - - if( PyArray_TYPE(%(_indptr)s) != NPY_INT32) - {PyErr_SetString(PyExc_NotImplementedError, "D"); %(fail)s;} - - if (!%(_zout)s || - (PyArray_DIMS(%(_zout)s)[0] != PyArray_DIMS(%(_indices)s)[0]) || - !(PyArray_ISCONTIGUOUS(%(_zout)s))) - { - Py_XDECREF(%(_zout)s); - %(_zout)s = (PyArrayObject*) PyArray_SimpleNew(1, - PyArray_DIMS(%(_indices)s), PyArray_TYPE(%(_b)s)); - if (!%(_zout)s) - { - PyErr_SetString(PyExc_MemoryError, - "Could not allocate output memory."); - %(fail)s; - } - } - - { //makes it compile even though labels jump over variable definitions. - const npy_intp nnz = PyArray_DIMS(%(_indices)s)[0]; - //TODO: error checking with this - const npy_intp N = PyArray_DIMS(%(_indptr)s)[0]-1; - - const dtype_%(_data)s * const __restrict__ data = (dtype_%(_data)s*)PyArray_DATA(%(_data)s); - const npy_int32 * const __restrict__ indptr = (npy_int32 *)PyArray_DATA(%(_indptr)s); - const npy_int32 * const __restrict__ indices = (npy_int32 *)PyArray_DATA(%(_indices)s); - - dtype_%(_zout)s * const __restrict__ zout = (dtype_%(_zout)s*)PyArray_DATA(%(_zout)s); - - const npy_intp Sb = PyArray_STRIDES(%(_b)s)[0]; - - // loop over columns - for (npy_intp j = 0; j < N; ++j) - { - // for each non-null value in the sparse column - for (npy_int32 i_idx = indptr[j]; i_idx < indptr[j+1]; ++i_idx) - { - // extract row index of non-null value - npy_int32 i = indices[i_idx]; - - // extract i-th row of dense matrix - const dtype_%(_b)s* __restrict__ b_row = (dtype_%(_b)s*)(PyArray_BYTES(%(_b)s) + Sb * i); - - // write resulting gradient to sparse output - zout[i_idx] = data[i_idx] * b_row[j]; - } - } - } - - """ % dict( - locals(), **sub - ) - - def __str__(self): - return self.__class__.__name__ - - -mul_s_d_csc = MulSDCSC() - - -class MulSDCSR(_NoPythonCOp): - """ - Multiplication of sparse matrix by a broadcasted dense vector - element wise. - - Parameters - ---------- - a_data - Sparse matrix data. - a_indices - Sparse matrix indices. - a_indptr - Sparse matrix indptr. - b - Tensor type matrix. - - Returns - ------- - The multiplication of the two matrix element wise. - - Notes - ----- - `a_data`, `a_indices` and `a_indptr` must be the properties - of a sparse matrix in csr format. - - The dtype of `a_data`, i.e. the dtype of the sparse matrix, - cannot be a complex type. - - This op is used as an optimization of mul_s_d. - - """ - - __props__ = () - - def make_node(self, a_data, a_indices, a_indptr, b): - assert b.type.ndim == 2 - return Apply( - self, [a_data, a_indices, a_indptr, b], [tensor(b.dtype, (False,))] - ) - - def c_code_cache_version(self): - return (3,) - - def c_code(self, node, name, inputs, outputs, sub): - - ( - _data, - _indices, - _indptr, - _b, - ) = inputs - (_zout,) = outputs - if node.inputs[0].type.dtype in ("complex64", "complex128"): - raise NotImplementedError("Complex types are not supported for a") - if node.inputs[3].type.dtype in ("complex64", "complex128"): - raise NotImplementedError("Complex types are not supported for b") - - return """ - if (PyArray_NDIM(%(_b)s) != 2) { - PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 2"); - %(fail)s;} - if (PyArray_NDIM(%(_data)s) != 1) { - PyErr_SetString(PyExc_NotImplementedError, "rank(data) != 1"); - %(fail)s;} - if (PyArray_NDIM(%(_indices)s) != 1) { - PyErr_SetString(PyExc_NotImplementedError, "rank(indices) != 1"); - %(fail)s;} - if (PyArray_NDIM(%(_indptr)s) != 1) { - PyErr_SetString(PyExc_NotImplementedError, "rank(indptr) != 1"); - %(fail)s;} - - if( PyArray_TYPE(%(_indices)s) != NPY_INT32) { - PyErr_SetString(PyExc_NotImplementedError, "C"); %(fail)s;} - - if( PyArray_TYPE(%(_indptr)s) != NPY_INT32) - {PyErr_SetString(PyExc_NotImplementedError, "D"); %(fail)s;} - - if (!%(_zout)s || - (PyArray_DIMS(%(_zout)s)[0] != PyArray_DIMS(%(_indices)s)[0]) || - !(PyArray_ISCONTIGUOUS(%(_zout)s))) - { - Py_XDECREF(%(_zout)s); - %(_zout)s = (PyArrayObject*) PyArray_SimpleNew(1, - PyArray_DIMS(%(_indices)s), PyArray_TYPE(%(_b)s)); - if (!%(_zout)s) - { - PyErr_SetString(PyExc_MemoryError, - "Could not allocate output memory."); - %(fail)s; - } - } - - { //makes it compile even though labels jump over variable definitions. - const npy_intp nnz = PyArray_DIMS(%(_indices)s)[0]; - //TODO: error checking with this - const npy_intp N = PyArray_DIMS(%(_indptr)s)[0]-1; - - const dtype_%(_data)s * const __restrict__ data = (dtype_%(_data)s*)PyArray_DATA(%(_data)s); - const npy_int32 * const __restrict__ indptr = (npy_int32 *)PyArray_DATA(%(_indptr)s); - const npy_int32 * const __restrict__ indices = (npy_int32 *)PyArray_DATA(%(_indices)s); - - dtype_%(_zout)s * const __restrict__ zout = (dtype_%(_zout)s*)PyArray_DATA(%(_zout)s); - - const npy_intp Sb = PyArray_STRIDES(%(_b)s)[0]; - - // loop over columns - for (npy_intp j = 0; j < N; ++j) - { - // extract i-th row of dense matrix - const dtype_%(_b)s* __restrict__ b_row = (dtype_%(_b)s*)(PyArray_BYTES(%(_b)s) + Sb * j); - - // for each non-null value in the sparse column - for (npy_int32 i_idx = indptr[j]; i_idx < indptr[j+1]; ++i_idx) - { - // extract row index of non-null value - npy_int32 i = indices[i_idx]; - - // write resulting gradient to sparse output - zout[i_idx] = data[i_idx] * b_row[i]; - } - } - } - - """ % dict( - locals(), **sub - ) - - def __str__(self): - return self.__class__.__name__ - - -mul_s_d_csr = MulSDCSR() - - -# register a specialization to replace MulSD -> MulSDCSX -@local_optimizer([sparse.mul_s_d]) -def local_mul_s_d(fgraph, node): - if node.op == sparse.mul_s_d: - x, y = node.inputs - - x_is_sparse_variable = _is_sparse_variable(x) - - if x_is_sparse_variable: - svar = x - dvar = y - else: - svar = y - dvar = x - - if dvar.type.ndim != 2: - return False - if svar.type.format == "csc": - CSx = sparse.CSC - mul_s_d_csx = mul_s_d_csc - elif svar.type.format == "csr": - CSx = sparse.CSR - mul_s_d_csx = mul_s_d_csr - else: - raise NotImplementedError - if x.dtype != y.dtype: - # mul_s_d_csx don't support that case - return - - c_data = mul_s_d_csx( - sparse.csm_data(svar), - sparse.csm_indices(svar), - sparse.csm_indptr(svar), - dvar, - ) - - return [ - CSx( - c_data, - sparse.csm_indices(svar), - sparse.csm_indptr(svar), - sparse.csm_shape(svar), - ) - ] - - return False - - -register_specialize(local_mul_s_d, "cxx_only") - - -class MulSVCSR(_NoPythonCOp): - """ - Multiplication of sparse matrix by a broadcasted dense vector - element wise. - - Parameters - ---------- - a_data - Sparse matrix data. - a_indices - Sparse matrix indices. - a_indptr - Sparse matrix indptr. - b - Tensor type matrix. - - Returns - ------- - The multiplication of the two matrix element wise. - - Notes - ----- - `a_data`, `a_indices` and `a_indptr` must be the properties - of a sparse matrix in csr format. - - The dtype of `a_data`, i.e. the dtype of the sparse matrix, - cannot be a complex type. - - This op is used as an optimization of MulSV. - - """ - - __props__ = () - - def make_node(self, a_data, a_indices, a_indptr, b): - assert b.type.ndim == 1 - return Apply( - self, [a_data, a_indices, a_indptr, b], [tensor(b.dtype, (False,))] - ) - - def c_code_cache_version(self): - return (2,) - - def c_code(self, node, name, inputs, outputs, sub): - ( - _data, - _indices, - _indptr, - _b, - ) = inputs - (_zout,) = outputs - if node.inputs[0].type.dtype in ("complex64", "complex128"): - raise NotImplementedError("Complex types are not supported for a") - if node.inputs[3].type.dtype in ("complex64", "complex128"): - raise NotImplementedError("Complex types are not supported for b") - - return """ - if (PyArray_NDIM(%(_b)s) != 1) { - PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 1"); - %(fail)s; - } - if (PyArray_NDIM(%(_data)s) != 1) { - PyErr_SetString(PyExc_NotImplementedError, "rank(data) != 1"); - %(fail)s; - } - if (PyArray_NDIM(%(_indices)s) != 1) { - PyErr_SetString(PyExc_NotImplementedError, "rank(indices) != 1"); - %(fail)s; - } - if (PyArray_NDIM(%(_indptr)s) != 1) { - PyErr_SetString(PyExc_NotImplementedError, "rank(indptr) != 1"); - %(fail)s; - } - - if( PyArray_TYPE(%(_indices)s) != NPY_INT32) { - PyErr_SetString(PyExc_NotImplementedError, "C"); %(fail)s;} - - if( PyArray_TYPE(%(_indptr)s) != NPY_INT32) - {PyErr_SetString(PyExc_NotImplementedError, "D"); %(fail)s;} - - if (!%(_zout)s - || PyArray_DIMS(%(_zout)s)[0] != PyArray_DIMS(%(_indices)s)[0] - || !PyArray_ISCONTIGUOUS(%(_zout)s)) - { - Py_XDECREF(%(_zout)s); - %(_zout)s = (PyArrayObject*) PyArray_SimpleNew(1, - PyArray_DIMS(%(_indices)s), PyArray_TYPE(%(_b)s)); - } - - { //makes it compile even though labels jump over variable definitions. - const npy_intp nnz = PyArray_DIMS(%(_indices)s)[0]; - //TODO: error checking with this - const npy_intp N = PyArray_DIMS(%(_indptr)s)[0]-1; - - const dtype_%(_data)s * const __restrict__ data = (dtype_%(_data)s*)PyArray_DATA(%(_data)s); - const npy_int32 * const __restrict__ indptr = (npy_int32 *)PyArray_DATA(%(_indptr)s); - const npy_int32 * const __restrict__ indices = (npy_int32 *)PyArray_DATA(%(_indices)s); - - const dtype_%(_b)s* __restrict__ Db = (dtype_%(_b)s*)PyArray_DATA(%(_b)s); - - dtype_%(_zout)s * const __restrict__ zout = (dtype_%(_zout)s*)PyArray_DATA(%(_zout)s); - - const npy_intp Sb = PyArray_STRIDES(%(_b)s)[0] / PyArray_DESCR(%(_b)s)->elsize; - - // loop over rows - for (npy_intp j = 0; j < N; ++j) - { - // for each non-null value in the sparse column - for (npy_int32 i_idx = indptr[j]; i_idx < indptr[j+1]; ++i_idx) - { - // extract row index of non-null value - npy_int32 i = indices[i_idx]; - - zout[i_idx] = data[i_idx] * Db[i * Sb]; - } - } - } - - """ % dict( - locals(), **sub - ) - - def __str__(self): - return self.__class__.__name__ - - -mul_s_v_csr = MulSVCSR() - - -# register a specialization to replace MulSV -> MulSVCSR -@local_optimizer([sparse.mul_s_v]) -def local_mul_s_v(fgraph, node): - if node.op == sparse.mul_s_v: - x, y = node.inputs - - x_is_sparse_variable = _is_sparse_variable(x) - - if x_is_sparse_variable: - svar = x - dvar = y - else: - svar = y - dvar = x - - if dvar.type.ndim != 1: - return False - elif svar.type.format == "csr": - CSx = sparse.CSR - mul_s_v_csx = mul_s_v_csr - else: - return False - - s_val, s_ind, s_ptr, s_shape = sparse.csm_properties(svar) - - c_data = mul_s_v_csx(s_val, s_ind, s_ptr, dvar) - - return [CSx(c_data, s_ind, s_ptr, s_shape)] - - return False - - -register_specialize(local_mul_s_v, "cxx_only") - - -class StructuredAddSVCSR(_NoPythonCOp): - """ - Structured addition of a sparse matrix and a dense vector. - The elements of the vector are are only added to the corresponding - non-zero elements. Therefore, this operation outputs another sparse - matrix. - - Parameters - ---------- - a_data - Sparse matrix data. - a_indices - Sparse matrix indices. - a_indptr - Sparse matrix indptr. - b - Tensor type vector. - - Returns - ------- - A sparse matrix containing the addition of the vector to the data of the - sparse matrix. - - Notes - ----- - The a_* are the properties of a sparse matrix in csr format. - - This op is used as an optimization for StructuredAddSV. - - """ - - __props__ = () - - def make_node(self, a_data, a_indices, a_indptr, b): - b = as_tensor_variable(b) - a_data = as_tensor_variable(a_data) - a_indices = as_tensor_variable(a_indices) - a_indptr = as_tensor_variable(a_indptr) - assert a_data.type.ndim == 1 - assert a_indices.type.ndim == 1 - assert a_indptr.type.ndim == 1 - assert b.type.ndim == 1 - return Apply( - self, [a_data, a_indices, a_indptr, b], [tensor(b.dtype, (False,))] - ) - - def c_code_cache_version(self): - return (3,) - - def c_code(self, node, name, inputs, outputs, sub): - ( - _data, - _indices, - _indptr, - _b, - ) = inputs - (_zout,) = outputs - if node.inputs[0].type.dtype in ("complex64", "complex128"): - raise NotImplementedError("Complex types are not supported for a") - if node.inputs[3].type.dtype in ("complex64", "complex128"): - raise NotImplementedError("Complex types are not supported for b") - - return """ - if (PyArray_NDIM(%(_b)s) != 1) { - PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 1"); - %(fail)s; - } - if (PyArray_NDIM(%(_data)s) != 1) { - PyErr_SetString(PyExc_NotImplementedError, "rank(data) != 1"); - %(fail)s; - } - if (PyArray_NDIM(%(_indices)s) != 1) { - PyErr_SetString(PyExc_NotImplementedError, "rank(indices) != 1"); - %(fail)s; - } - if (PyArray_NDIM(%(_indptr)s) != 1) { - PyErr_SetString(PyExc_NotImplementedError, "rank(indptr) != 1"); - %(fail)s; - } - - if( PyArray_TYPE(%(_indices)s) != NPY_INT32) { - PyErr_SetString(PyExc_NotImplementedError, "C"); %(fail)s;} - - if( PyArray_TYPE(%(_indptr)s) != NPY_INT32) - {PyErr_SetString(PyExc_NotImplementedError, "D"); %(fail)s;} - - if (!%(_zout)s - || (PyArray_DIMS(%(_zout)s)[0] != PyArray_DIMS(%(_indices)s)[0]) - || !(PyArray_ISCONTIGUOUS(%(_zout)s))) - { - Py_XDECREF(%(_zout)s); - %(_zout)s = (PyArrayObject*) PyArray_SimpleNew(1, - PyArray_DIMS(%(_indices)s), PyArray_TYPE(%(_b)s)); - if (!%(_zout)s) - { - PyErr_SetString(PyExc_MemoryError, - "Could not allocate output memory."); - %(fail)s; - } - } - - { //makes it compile even though labels jump over variable definitions. - const npy_intp nnz = PyArray_DIMS(%(_indices)s)[0]; - //TODO: error checking with this - const npy_intp N = PyArray_DIMS(%(_indptr)s)[0]-1; - - const dtype_%(_data)s * const __restrict__ data = (dtype_%(_data)s*)PyArray_DATA(%(_data)s); - const npy_int32 * const __restrict__ indptr = (npy_int32 *)PyArray_DATA(%(_indptr)s); - const npy_int32 * const __restrict__ indices = (npy_int32 *)PyArray_DATA(%(_indices)s); - - const dtype_%(_b)s* __restrict__ Db = (dtype_%(_b)s*)PyArray_DATA(%(_b)s); - - dtype_%(_zout)s * const __restrict__ zout = (dtype_%(_zout)s*)PyArray_DATA(%(_zout)s); - - const npy_intp Sb = PyArray_STRIDES(%(_b)s)[0] / PyArray_DESCR(%(_b)s)->elsize; - - // loop over columns - for (npy_intp j = 0; j < N; ++j) - { - // for each non-null value in the sparse column - for (npy_int32 i_idx = indptr[j]; i_idx < indptr[j+1]; ++i_idx) - { - // extract row index of non-null value - npy_int32 i = indices[i_idx]; - - // write resulting gradient to sparse output - zout[i_idx] = data[i_idx] + Db[i * Sb]; - } - } - } - - """ % dict( - locals(), **sub - ) - - def __str__(self): - return self.__class__.__name__ - - -structured_add_s_v_csr = StructuredAddSVCSR() - - -# register a specialization to replace -# structured_add_s_v -> structured_add_s_v_csr -@local_optimizer([sparse.structured_add_s_v]) -def local_structured_add_s_v(fgraph, node): - if node.op == sparse.structured_add_s_v: - x, y = node.inputs - - x_is_sparse_variable = _is_sparse_variable(x) - # y_is_sparse_variable = _is_sparse_variable(y) - - if x_is_sparse_variable: - svar = x - dvar = y - else: - svar = y - dvar = x - - if dvar.type.ndim != 1: - return False - elif svar.type.format == "csr": - CSx = sparse.CSR - structured_add_s_v_csx = structured_add_s_v_csr - else: - return False - - s_val, s_ind, s_ptr, s_shape = sparse.csm_properties(svar) - - c_data = structured_add_s_v_csx(s_val, s_ind, s_ptr, dvar) - - return [CSx(c_data, s_ind, s_ptr, s_shape)] - - return False - - -register_specialize(local_structured_add_s_v, "cxx_only") - - -class SamplingDotCSR(_NoPythonCOp): - r""" - Operand optimized for calculating the dot product :math:`x y^\top = z` - when you only want to calculate a subset of :math:`z`. - - It is equivalent to :math:`p \circ (x \cdot y^\top)` where :math:`\circ` is - the element-wise product, :math:`x` and :math:`y` operands of the dot - product, and :math:`p` is a matrix that contains 1 when the corresponding - element of :math:`z` should be calculated and 0 when it shouldn't. Note - that `SamplingDot` has a different interface than ``dot`` because - `SamplingDot` requires :math:`x` to be a :math:`m \times k` matrix while - :math:`y` is a :math:`n \times k` matrix instead of the usual :math:``k - \times n` matrix. - - Parameters - ---------- - x - Tensor matrix. - y - Tensor matrix. - p_data - Sparse matrix data. - p_ind - Sparse matrix indices. - p_ptr - Sparse matric indptr. - p_ncols - Sparse matrix number of columns. - - Returns - ------- - A dense matrix containing the dot product of :math:`x` by :math:`y^\top` only - where :math:`p` is 1. - - Notes - ----- - It will work if the pattern is not binary value, but if the - pattern doesn't have a high sparsity proportion it will be slower - then a more optimized dot followed by a normal elemwise - multiplication. - - If we have the input of mixed dtype, we insert cast elemwise - in the graph to be able to call BLAS function as they don't - allow mixed dtype. - - This `Op` is used as an optimization for `SamplingDot`. - - """ - - __props__ = () - - def make_node(self, x, y, p_data, p_ind, p_ptr, p_ncols): - x = as_tensor_variable(x) - y = as_tensor_variable(y) - p_data = as_tensor_variable(p_data) - p_ind = as_tensor_variable(p_ind) - p_ptr = as_tensor_variable(p_ptr) - p_ncols = as_tensor_variable(p_ncols) - - assert p_ncols.dtype == "int32" - - dtype_out = aes.upcast(x.type.dtype, y.type.dtype, p_data.type.dtype) - dot_out = aes.upcast(x.type.dtype, y.type.dtype) - - # We call blas ?dot function that take only param of the same type - x = cast(x, dot_out) - y = cast(y, dot_out) - - return Apply( - self, - [x, y, p_data, p_ind, p_ptr, p_ncols], - [ - tensor(dtype=dtype_out, shape=(False,)), - tensor(dtype=p_ind.type.dtype, shape=(False,)), - tensor(dtype=p_ptr.type.dtype, shape=(False,)), - ], - ) - - def c_code_cache_version(self): - return (4, blas.blas_header_version()) - - def c_support_code(self, **kwargs): - return blas.blas_header_text() - - def c_libraries(self, **kwargs): - return blas.ldflags() - - def c_compile_args(self, **kwargs): - return blas.ldflags(libs=False, flags=True) - - def c_lib_dirs(self, **kwargs): - return blas.ldflags(libs=False, libs_dir=True) - - def c_header_dirs(self, **kwargs): - return blas.ldflags(libs=False, include_dir=True) - - def c_code(self, node, name, inputs, outputs, sub): - x, y, p_data, p_ind, p_ptr, p_ncols = inputs - z_data, z_ind, z_ptr = outputs - if node.inputs[0].type.dtype in ("complex64", "complex128"): - raise NotImplementedError("Complex types are not supported for x") - if node.inputs[1].type.dtype in ("complex64", "complex128"): - raise NotImplementedError("Complex types are not supported for y") - if node.inputs[2].type.dtype in ("complex64", "complex128"): - raise NotImplementedError("Complex types are not supported for pattern") - - dot_out = aes.upcast(node.inputs[0].type.dtype, node.inputs[1].type.dtype) - - if dot_out == "float32": - conv_type = "float" - cdot = "sdot_" - else: - conv_type = "double" - cdot = "ddot_" - - # retrieve dtype number - typenum_x = node.inputs[0].type.dtype_specs()[2] - typenum_y = node.inputs[1].type.dtype_specs()[2] - typenum_p = node.inputs[2].type.dtype_specs()[2] - typenum_zd = TensorType(node.outputs[0].dtype, []).dtype_specs()[2] - typenum_zi = TensorType(node.outputs[1].dtype, []).dtype_specs()[2] - typenum_zp = TensorType(node.outputs[2].dtype, []).dtype_specs()[2] - - rval = """ - if (PyArray_NDIM(%(x)s) != 2) { -PyErr_SetString(PyExc_NotImplementedError, "rank(x) != 2"); %(fail)s;} - if (PyArray_NDIM(%(y)s) != 2) { -PyErr_SetString(PyExc_NotImplementedError, "rank(y) != 2"); %(fail)s;} - - if (PyArray_TYPE(%(x)s) != %(typenum_x)s) { - PyErr_SetString(PyExc_NotImplementedError, - "Invalid type for x"); - %(fail)s;} - - if (PyArray_TYPE(%(y)s) != %(typenum_y)s) { - PyErr_SetString(PyExc_NotImplementedError, - "Invalid type for y"); - %(fail)s;} - - if (PyArray_TYPE(%(p_data)s) != %(typenum_p)s) { - PyErr_SetString(PyExc_NotImplementedError, - "Invalid type for pattern"); - %(fail)s;} - - if (PyArray_DIMS(%(x)s)[1] != PyArray_DIMS(%(y)s)[1]) { - PyErr_SetString(PyExc_NotImplementedError, - "x's number of columns doesn't match y's rows! Note: sampling_dot is different from dot because y is assumed to be transposed."); - %(fail)s;} - - if (PyArray_DIMS(%(y)s)[0] != ((npy_int32 *)PyArray_DATA(%(p_ncols)s))[0] || - PyArray_DIMS(%(x)s)[0] != (PyArray_DIMS(%(p_ptr)s)[0] - 1)) - {PyErr_SetString(PyExc_NotImplementedError, - "The dimension of the pattern and the output must match"); %(fail)s;} - - // Allocate output - if (!%(z_data)s - || (PyArray_DIMS(%(z_data)s)[0] != PyArray_DIMS(%(p_data)s)[0]) - || (PyArray_TYPE(%(z_data)s) != %(typenum_zd)s) - || !(PyArray_ISCONTIGUOUS(%(z_data)s))) - { - {Py_XDECREF(%(z_data)s);} - npy_intp dims[] = {0}; - dims[0] = PyArray_DIMS(%(p_data)s)[0]; - %(z_data)s = (PyArrayObject*) PyArray_SimpleNew(1, dims, - %(typenum_zd)s); - } - if (!%(z_ind)s - || (PyArray_DIMS(%(z_ind)s)[0] != PyArray_DIMS(%(p_ind)s)[0]) - || (PyArray_TYPE(%(z_ind)s) != %(typenum_zi)s) - || !(PyArray_ISCONTIGUOUS(%(z_ind)s))) - { - {Py_XDECREF(%(z_ind)s);} - npy_intp dims[] = {0}; - dims[0] = PyArray_DIMS(%(p_ind)s)[0]; - %(z_ind)s = (PyArrayObject*) PyArray_SimpleNew(1, dims, - %(typenum_zi)s); - } - if (!%(z_ptr)s - || (PyArray_DIMS(%(z_ptr)s)[0] != PyArray_DIMS(%(p_ptr)s)[0]) - || (PyArray_TYPE(%(z_ptr)s) != %(typenum_zp)s) - || !(PyArray_ISCONTIGUOUS(%(z_ptr)s))) - { - {Py_XDECREF(%(z_ptr)s);} - npy_intp dims[] = {0}; - dims[0] = PyArray_DIMS(%(p_ptr)s)[0]; - %(z_ptr)s = (PyArrayObject*) PyArray_SimpleNew(1, dims, - %(typenum_zp)s); - } - - { - // Product of MxK and NxK, output MxN - npy_intp M = PyArray_DIMS(%(x)s)[0]; - npy_intp N = PyArray_DIMS(%(y)s)[0]; - npy_intp K = PyArray_DIMS(%(y)s)[1]; - - // pointers to access actual data in the arrays passed as params. - const dtype_%(x)s* __restrict__ Dx = (dtype_%(x)s*)PyArray_DATA(%(x)s); - const dtype_%(y)s* __restrict__ Dy = (dtype_%(y)s*)PyArray_DATA(%(y)s); - const dtype_%(p_data)s* __restrict__ Dpd = (dtype_%(p_data)s*)PyArray_DATA(%(p_data)s); - const dtype_%(p_ind)s* __restrict__ Dpi = (dtype_%(p_ind)s*)PyArray_DATA(%(p_ind)s); - const dtype_%(p_ptr)s* __restrict__ Dpp = (dtype_%(p_ptr)s*)PyArray_DATA(%(p_ptr)s); - dtype_%(z_data)s* __restrict__ Dzd = (dtype_%(z_data)s*)PyArray_DATA(%(z_data)s); - dtype_%(z_ind)s* __restrict__ Dzi = (dtype_%(z_ind)s*)PyArray_DATA(%(z_ind)s); - dtype_%(z_ptr)s* __restrict__ Dzp = (dtype_%(z_ptr)s*)PyArray_DATA(%(z_ptr)s); - - const npy_intp Sdx = PyArray_STRIDES(%(x)s)[1]/PyArray_DESCR(%(x)s)->elsize; - const npy_intp Sdy = PyArray_STRIDES(%(y)s)[1]/PyArray_DESCR(%(y)s)->elsize; - const npy_intp Sdpd = PyArray_STRIDES(%(p_data)s)[0] / PyArray_DESCR(%(p_data)s)->elsize; - const npy_intp Sdpi = PyArray_STRIDES(%(p_ind)s)[0] / PyArray_DESCR(%(p_ind)s)->elsize; - const npy_intp Sdpp = PyArray_STRIDES(%(p_ptr)s)[0] / PyArray_DESCR(%(p_ptr)s)->elsize; - const npy_intp Sdzd = PyArray_STRIDES(%(z_data)s)[0] / PyArray_DESCR(%(z_data)s)->elsize; - const npy_intp Sdzi = PyArray_STRIDES(%(z_ind)s)[0] / PyArray_DESCR(%(z_ind)s)->elsize; - const npy_intp Sdzp = PyArray_STRIDES(%(z_ptr)s)[0] / PyArray_DESCR(%(z_ptr)s)->elsize; - - memcpy(Dzi, Dpi, PyArray_DIMS(%(p_ind)s)[0]*sizeof(dtype_%(p_ind)s)); - memcpy(Dzp, Dpp, PyArray_DIMS(%(p_ptr)s)[0]*sizeof(dtype_%(p_ptr)s)); - - // blas expects ints; convert here (rather than just making K etc ints) to avoid potential overflow in the negative-stride correction - if ((K > 0x7fffffffL)||(Sdx > 0x7fffffffL)||(Sdy > 0x7fffffffL)||(Sdx < -0x7fffffffL)||(Sdy < -0x7fffffffL)) - {PyErr_SetString(PyExc_NotImplementedError, "array too big for BLAS (overflows int32 index)"); %(fail)s;} - int K32 = K; - int Sdx32 = Sdx; - int Sdy32 = Sdy; - - for (npy_intp m = 0; m < M; ++m) { - for (npy_int32 n_idx = Dpp[m * Sdpp]; n_idx < Dpp[(m+1)*Sdpp]; ++n_idx) { - const npy_int32 n = Dpi[n_idx * Sdpi]; // row index of non-null value for column K - - const dtype_%(x)s* x_row = (dtype_%(x)s*)(PyArray_BYTES(%(x)s) + PyArray_STRIDES(%(x)s)[0] * m); - - const dtype_%(y)s* y_col = (dtype_%(y)s*)(PyArray_BYTES(%(y)s) + PyArray_STRIDES(%(y)s)[0] * n); - // dot expects pointer to the beginning of memory arrays, - // so when the stride is negative, we need to get the - // last element - if (Sdx < 0) - x_row += (K - 1) * Sdx; - if (Sdy < 0) - y_col += (K - 1) * Sdy; - - Dzd[n_idx * Sdzd] = Dpd[n_idx * Sdpd] * %(cdot)s(&K32, (const %(conv_type)s*)x_row, &Sdx32, (const %(conv_type)s*)y_col, &Sdy32); - } - } - } - """ % dict( - locals(), **sub - ) - - return rval - - -sampling_dot_csr = SamplingDotCSR() - - -# register a specialization to replace SamplingDot -> SamplingDotCsr -@local_optimizer([sparse.sampling_dot]) -def local_sampling_dot_csr(fgraph, node): - if not config.blas__ldflags: - # The C implementation of SamplingDotCsr relies on BLAS routines - return - if node.op == sparse.sampling_dot: - x, y, p = node.inputs - if p.type.format == "csr": - p_data, p_ind, p_ptr, p_shape = sparse.csm_properties(p) - - z_data, z_ind, z_ptr = sampling_dot_csr( - x, y, p_data, p_ind, p_ptr, p_shape[1] - ) - # This is a hack that works around some missing `Type`-related - # static shape narrowing. More specifically, - # `TensorType.convert_variable` currently won't combine the static - # shape information from `old_out.type` and `new_out.type`, only - # the broadcast patterns, and, since `CSR.make_node` doesn't do - # that either, we use `specify_shape` to produce an output `Type` - # with the same level of static shape information as the original - # `old_out`. - old_out = node.outputs[0] - new_out = specify_shape( - sparse.CSR(z_data, z_ind, z_ptr, p_shape), shape(old_out) - ) - return [new_out] - return False - - -register_specialize(local_sampling_dot_csr, "cxx_only", name="local_sampling_dot_csr") +from aesara.sparse.rewriting import * # noqa: F401 E402 F403 diff --git a/aesara/sparse/rewriting.py b/aesara/sparse/rewriting.py new file mode 100644 index 0000000000..fde57a30ac --- /dev/null +++ b/aesara/sparse/rewriting.py @@ -0,0 +1,2065 @@ +import scipy + +import aesara +import aesara.scalar as aes +from aesara.configdefaults import config +from aesara.graph.basic import Apply +from aesara.graph.rewriting.basic import ( + PatternNodeRewriter, + WalkingGraphRewriter, + node_rewriter, +) +from aesara.link.c.op import COp, _NoPythonCOp +from aesara.misc.safe_asarray import _asarray +from aesara.sparse import basic as sparse +from aesara.sparse.basic import ( + CSC, + CSR, + csm_data, + csm_grad, + csm_indices, + csm_indptr, + csm_properties, + usmm, +) +from aesara.tensor import blas +from aesara.tensor.basic import as_tensor_variable, cast +from aesara.tensor.math import mul, neg, sub +from aesara.tensor.rewriting.basic import register_canonicalize, register_specialize +from aesara.tensor.shape import shape, specify_shape +from aesara.tensor.type import TensorType, tensor + + +_is_sparse_variable = sparse._is_sparse_variable +_is_dense = sparse._is_dense + + +@node_rewriter([csm_properties]) +def local_csm_properties_csm(fgraph, node): + """ + If we find csm_properties(CSM(*args)), then we can replace that with the + *args directly. + + """ + if node.op == csm_properties: + (csm,) = node.inputs + if csm.owner and (csm.owner.op == CSC or csm.owner.op == CSR): + return csm.owner.inputs + + return False + + +register_specialize(local_csm_properties_csm) + + +# This is tested in tests/test_basic.py:test_remove0 +@node_rewriter([sparse.Remove0]) +def local_inplace_remove0(fgraph, node): + """Rewrite to insert inplace versions of `Remove0`.""" + # If inplace is not enabled, enable it and replace that op with a + # new op which has inplace enabled + if isinstance(node.op, sparse.Remove0) and not node.op.inplace: + new_op = node.op.__class__(inplace=True) + new_node = new_op(*node.inputs) + return [new_node] + return False + + +aesara.compile.optdb.register( + "local_inplace_remove0", + WalkingGraphRewriter( + local_inplace_remove0, failure_callback=WalkingGraphRewriter.warn_inplace + ), + "fast_run", + "inplace", + position=60, +) + + +class AddSD_ccode(_NoPythonCOp): + """ + Add a sparse and a dense matrix. + + Parameters + ---------- + x + A sparse matrix. + y + A dense matrix + + Returns + ------- + matrix + `x`+`y` + + Notes + ----- + The grad implemented is structured on `x`. + + """ + + __props__ = ("format", "inplace") + + def __init__(self, format, inplace=False, *args, **kwargs): + super().__init__(*args, **kwargs) + # Should we do inplace addition or not ? + self.inplace = inplace + self.format = format + if self.inplace: + self.destroy_map = {0: [3]} + + def __str__(self): + inp = "" + if self.inplace: + inp = ",inplace" + return f"{self.__class__.__name__}{{{self.format}{inp}}}" + + def make_node(self, x, y): + x, y = sparse.as_sparse_variable(x), as_tensor_variable(y) + out_dtype = aes.upcast(x.type.dtype, y.type.dtype) + if self.inplace: + assert out_dtype == y.dtype + + indices, indptr, data = csm_indices(x), csm_indptr(x), csm_data(x) + # We either use CSC or CSR depending on the format of input + assert self.format == x.type.format + # The magic number two here arises because L{scipy.sparse} + # objects must be matrices (have dimension 2) + assert y.type.ndim == 2 + out = TensorType(dtype=out_dtype, shape=y.type.broadcastable)() + return Apply(self, [data, indices, indptr, y], [out]) + + def c_code(self, node, name, inputs, outputs, sub): + (_data, _indices, _indptr, y) = inputs + (z,) = outputs + inplace = int(self.inplace) + format = {"csc": 0, "csr": 1}[self.format] + out_typenum = node.outputs[0].type.dtype_specs()[2] + code = """ + Py_XDECREF(%(z)s); + if (!%(inplace)s){ + if(PyArray_TYPE(%(y)s) != %(out_typenum)s){ + %(z)s = (PyArrayObject *) PyArray_FromArray(%(y)s, PyArray_DescrFromType(%(out_typenum)s), 0); + }else{ + %(z)s = (PyArrayObject *) PyArray_NewCopy(%(y)s, NPY_CORDER); + } + }else{ + %(z)s = %(y)s; + Py_XINCREF(%(z)s); + } + + npy_intp N = PyArray_DIMS(%(_indptr)s)[0]-1; + + const dtype_%(_indptr)s* __restrict__ indptr = (dtype_%(_indptr)s*)PyArray_DATA(%(_indptr)s); + const dtype_%(_indices)s* __restrict__ indices = (dtype_%(_indices)s*)PyArray_DATA(%(_indices)s); + const dtype_%(_data)s* __restrict__ data = (dtype_%(_data)s*)PyArray_DATA(%(_data)s); + + dtype_%(y)s* ydata = (dtype_%(y)s*)PyArray_DATA(%(y)s); + dtype_%(z)s* zdata = (dtype_%(z)s*)PyArray_DATA(%(z)s); + npy_intp Yi = PyArray_STRIDES(%(y)s)[0]/PyArray_DESCR(%(y)s)->elsize; + npy_intp Yj = PyArray_STRIDES(%(y)s)[1]/PyArray_DESCR(%(y)s)->elsize; + + npy_intp pos; + if (%(format)s == 0){ + for (npy_intp col = 0; col < N; ++col){ + for (dtype_%(_indptr)s ind = indptr[col]; ind < indptr[col+1]; ++ind){ + npy_intp row = indices[ind]; + pos = row * Yi + col * Yj; + zdata[pos] = ydata[pos] + data[ind]; + } + } + }else{ + for (npy_intp row = 0; row < N; ++row){ + for (dtype_%(_indptr)s ind = indptr[row]; ind < indptr[row+1]; ++ind){ + npy_intp col = indices[ind]; + pos = row * Yi + col * Yj; + zdata[pos] = ydata[pos] + data[ind]; + } + } + } + """ % dict( + locals(), **sub + ) + return code + + def infer_shape(self, fgraph, node, shapes): + return [shapes[3]] + + def c_code_cache_version(self): + return (2,) + + +@node_rewriter([sparse.AddSD]) +def local_inplace_addsd_ccode(fgraph, node): + """Rewrite to insert inplace versions of `AddSD`.""" + if isinstance(node.op, sparse.AddSD) and config.cxx: + out_dtype = aes.upcast(*node.inputs) + if out_dtype != node.inputs[1].dtype: + return + new_node = AddSD_ccode(format=node.inputs[0].type.format, inplace=True)( + *node.inputs + ) + return [new_node] + return False + + +aesara.compile.optdb.register( + "local_inplace_addsd_ccode", + WalkingGraphRewriter( + local_inplace_addsd_ccode, failure_callback=WalkingGraphRewriter.warn_inplace + ), + "fast_run", + "inplace", + position=60, +) + + +@register_canonicalize("fast_compile") +@register_specialize +@node_rewriter([sparse.DenseFromSparse]) +def local_dense_from_sparse_sparse_from_dense(fgraph, node): + if isinstance(node.op, sparse.DenseFromSparse): + inp = node.inputs[0] + if inp.owner and isinstance(inp.owner.op, sparse.SparseFromDense): + return inp.owner.inputs + + +@node_rewriter([sparse.AddSD]) +def local_addsd_ccode(fgraph, node): + """ + Convert AddSD to faster AddSD_ccode. + + """ + if isinstance(node.op, sparse.AddSD) and config.cxx: + new_node = AddSD_ccode(format=node.inputs[0].type.format)(*node.inputs) + return [new_node] + return False + + +aesara.compile.optdb.register( + "local_addsd_ccode", + WalkingGraphRewriter(local_addsd_ccode), + # Must be after local_inplace_addsd_ccode at 60 + "fast_run", + position=61, +) + + +class StructuredDotCSC(COp): + """ + Structured Dot CSC is like `dot`, except that only the gradient wrt non-zero + elements of a sparse matrix are calculated and propagated. + + The output is presumed to be a dense matrix, and is represented by a + `TensorType` instance. + + Notes + ----- + The gradient. implemented is structured. + + This `Op` is used as a rewritten form of `StructuredDot`. + + """ + + __props__ = () + + def make_node(self, a_val, a_ind, a_ptr, a_nrows, b): + dtype_out = aes.upcast(a_val.type.dtype, b.type.dtype) + r = Apply( + self, + [a_val, a_ind, a_ptr, a_nrows, b], + [tensor(dtype_out, (False, b.type.broadcastable[1]))], + ) + return r + + def perform(self, node, inputs, outputs): + (a_val, a_ind, a_ptr, a_nrows, b) = inputs + (out,) = outputs + a = scipy.sparse.csc_matrix( + (a_val, a_ind, a_ptr), (a_nrows, b.shape[0]), copy=False + ) + # out[0] = a.dot(b) + out[0] = _asarray(a * b, dtype=node.outputs[0].type.dtype) + assert _is_dense(out[0]) # scipy 0.7 automatically converts to dense + + def c_code(self, node, name, inputs, outputs, sub): + # C-implementation of the dot product of the sparse matrix A and matrix + # B. + # @param a_val: non-zero values of the sparse matrix + # @param a_ind: column indices of the non-null values (.indices of a + # scipy.csc_matrix) + # @param a_ptr: a_ptr indicates col indices for col. i are in the range + # a_ptr[i]:a_ptr[i+1] + # @param n_rows: number of rows of sparse matrix + # @param b: dense matrix to perform dot product with, as in dot(a, b) + # @param z: return value + # @param sub: TODO, not too sure, something to do with weave probably + + (a_val, a_ind, a_ptr, a_nrows, b) = inputs + (z,) = outputs + if node.inputs[0].type.dtype in ("complex64", "complex128"): + raise NotImplementedError("Complex types are not supported for a_val") + if node.inputs[4].type.dtype in ("complex64", "complex128"): + raise NotImplementedError("Complex types are not supported for b") + + typenum_z = node.outputs[0].type.dtype_specs()[2] # retrieve dtype number + typenum_a_val = node.inputs[0].type.dtype_specs()[2] # retrieve dtype number + typenum_b = node.inputs[4].type.dtype_specs()[2] # retrieve dtype number + + rval = """ + + if (PyArray_NDIM(%(a_val)s) != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(a_val) != 1"); %(fail)s;} + if (PyArray_NDIM(%(a_ind)s) != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(a_ind) != 1"); %(fail)s;} + if (PyArray_NDIM(%(a_ptr)s) != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(a_ptr) != 1"); %(fail)s;} + if (PyArray_NDIM(%(a_nrows)s) != 0) {PyErr_SetString(PyExc_NotImplementedError, "rank(nrows) != 0"); %(fail)s;} + if (PyArray_NDIM(%(b)s) != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 2"); %(fail)s;} + + if (PyArray_TYPE(%(a_val)s) != %(typenum_a_val)s) { + PyErr_SetString(PyExc_NotImplementedError, "Invalid type for a_val"); %(fail)s;} + + if (PyArray_TYPE(%(b)s) != %(typenum_b)s) { + PyErr_SetString(PyExc_NotImplementedError, "Invalid type for b"); %(fail)s;} + + if (PyArray_TYPE(%(a_ind)s) != NPY_INT32) { + PyErr_SetString(PyExc_NotImplementedError, "a_ind dtype not INT32"); %(fail)s;} + + if (PyArray_TYPE(%(a_ptr)s) != NPY_INT32) + {PyErr_SetString(PyExc_NotImplementedError, "a_ptr dtype not INT32"); %(fail)s;} + + if (PyArray_TYPE(%(a_nrows)s) != NPY_INT32) + {PyErr_SetString(PyExc_NotImplementedError, "a_nrows dtype not INT32"); %(fail)s;} + + if (PyArray_DIMS(%(a_val)s)[0] != PyArray_DIMS(%(a_ind)s)[0]) + {PyErr_SetString(PyExc_NotImplementedError, "a_val and a_ind have different lengths"); %(fail)s;} + + if (PyArray_DIMS(%(a_ptr)s)[0] != PyArray_DIMS(%(b)s)[0]+1) + {PyErr_SetString(PyExc_NotImplementedError, "a's number of columns doesn't match b's rows"); %(fail)s;} + + if ((!%(z)s) + || (PyArray_DIMS(%(z)s)[0] != ((npy_int32 *)PyArray_DATA(%(a_nrows)s))[0]) + || (PyArray_DIMS(%(z)s)[1] != PyArray_DIMS(%(b)s)[1]) + ) + { + {Py_XDECREF(%(z)s);} + npy_intp dims[] = {0, 0}; + dims[0] = ((npy_int32 *)PyArray_DATA(%(a_nrows)s))[0]; + dims[1] = PyArray_DIMS(%(b)s)[1]; + %(z)s = (PyArrayObject*) PyArray_SimpleNew(2, dims, %(typenum_z)s); + } + + { + // sparse array has size MxK, dense KxN, output MxN + npy_intp M = PyArray_DIMS(%(z)s)[0]; + npy_intp N = PyArray_DIMS(%(z)s)[1]; + npy_intp K = PyArray_DIMS(%(b)s)[0]; + if (N > 0x7fffffffL) + {PyErr_SetString(PyExc_NotImplementedError, "array too big (overflows int32 index)"); %(fail)s;} + + // strides tell you how many bytes to skip to go to next column/row entry + npy_intp Szm = PyArray_STRIDES(%(z)s)[0] / PyArray_DESCR(%(z)s)->elsize; + npy_intp Szn = PyArray_STRIDES(%(z)s)[1] / PyArray_DESCR(%(z)s)->elsize; + //npy_intp Sbm = PyArray_STRIDES(%(b)s)[0] / PyArray_DESCR(%(b)s)->elsize; + npy_intp Sbn = PyArray_STRIDES(%(b)s)[1] / PyArray_DESCR(%(b)s)->elsize; + npy_intp Sval = PyArray_STRIDES(%(a_val)s)[0] / PyArray_DESCR(%(a_val)s)->elsize; + npy_intp Sind = PyArray_STRIDES(%(a_ind)s)[0] / PyArray_DESCR(%(a_ind)s)->elsize; + npy_intp Sptr = PyArray_STRIDES(%(a_ptr)s)[0] / PyArray_DESCR(%(a_ptr)s)->elsize; + + // pointers to access actual data in the arrays passed as params. + dtype_%(z)s* __restrict__ Dz = (dtype_%(z)s*)PyArray_DATA(%(z)s); + const dtype_%(a_val)s* __restrict__ Dval = (dtype_%(a_val)s*)PyArray_DATA(%(a_val)s); + const npy_int32 * __restrict__ Dind = (npy_int32*)PyArray_DATA(%(a_ind)s); + const npy_int32 * __restrict__ Dptr = (npy_int32*)PyArray_DATA(%(a_ptr)s); + + //npy_intp nnz = PyArray_DIMS(%(a_ind)s)[0]; + + //clear the output array + memset(Dz, 0, M*N*sizeof(dtype_%(z)s)); + + //iterate over the sparse array, making the most of an entry wherever we find it. + // + // Normal matrix matrix multiply: A MxK, B KxN => Z = AB + // for m + // for n + // for k + // z[m, n] += a[m, k] * b[k, n] + // Here instead: Z = + // for k + // for m (sparse) + // for n + // z[m, n] += a[m, k] * b[k, n] + + // loop over inner dimension + for (npy_int32 k = 0; k < K; ++k) + { + // get pointer to k-th row of dense matrix + const dtype_%(b)s* __restrict__ bk = (dtype_%(b)s*)(PyArray_BYTES(%(b)s) + PyArray_STRIDES(%(b)s)[0] * k); + + // loop over sparse column indices through index pointer array + // (amounts to looping over rows M of sparse matrix) + + for (npy_int32 m_idx = Dptr[k * Sptr]; m_idx < Dptr[(k+1) * Sptr]; ++m_idx) + { + npy_int32 m = Dind[m_idx * Sind]; // row index of non-null value for column K + const dtype_%(a_val)s Amk = Dval[m_idx * Sval]; // actual value at that location + + // pointer to m-th row of the output matrix Z + dtype_%(z)s* __restrict__ zm = (dtype_%(z)s*)(PyArray_BYTES(%(z)s) + PyArray_STRIDES(%(z)s)[0] * m); + + //RESOLVE: a.shape[0] equals z.shape[0], why is this not an equality constraint? + if (m >= PyArray_DIMS(%(z)s)[0]) + {PyErr_SetString(PyExc_NotImplementedError, "illegal row index in a"); %(fail)s;} + + // loop over final dimension (cols of dense matrix) and perform dot product + if ((Szn == 1) && (Sbn == 1)) { + for(npy_int32 n = 0; n < N; ++n) + { + zm[n] += Amk * bk[n]; + } + } + else + { + for(npy_int32 n = 0; n < N; ++n) + { + zm[n*Szn] += Amk * bk[n*Sbn]; + } + } + } + } + } + """ % dict( + locals(), **sub + ) + + return rval + + def c_code_cache_version(self): + return (3,) + + +sd_csc = StructuredDotCSC() + + +class StructuredDotCSR(COp): + """ + Structured Dot CSR is like dot, except that only the + gradient wrt non-zero elements of a sparse matrix + are calculated and propagated. + + The output is presumed to be a dense matrix, and is represented by a + `TensorType` instance. + + Notes + ----- + The gradient implemented is structured. + + This `Op` is used as a rewritten form of `StructuredDot`. + + """ + + __props__ = () + + def make_node(self, a_val, a_ind, a_ptr, b): + self.dtype_out = aes.upcast(a_val.type.dtype, b.type.dtype) + r = Apply( + self, + [a_val, a_ind, a_ptr, b], + [tensor(self.dtype_out, (False, b.type.broadcastable[1]))], + ) + return r + + def perform(self, node, inputs, outputs): + (a_val, a_ind, a_ptr, b) = inputs + (out,) = outputs + a = scipy.sparse.csr_matrix( + (a_val, a_ind, a_ptr), (len(a_ptr) - 1, b.shape[0]), copy=True + ) # use view_map before setting this to False + # out[0] = a.dot(b) + out[0] = a * b + # scipy 0.7 automatically converts to dense, but not .6 sometimes + assert _is_dense(out[0]) + + def c_code(self, node, name, inputs, outputs, sub): + """ + C-implementation of the dot product of the sparse matrix A and matrix B. + + Parameters + ---------- + a_val + Non-zero values of the sparse matrix. + a_ind + Column indices of the non-null values (.indices of a + scipy.csc_matrix). + a_ptr + Indicates col indices for col. i are in the range + a_ptr[i]:a_ptr[i+1]. + n_cols + Number of columns of sparse matrix. + b + Dense matrix to perform dot product with, as in dot(a, b). + z + Return value. + sub + TODO, not too sure, something to do with weave probably. + + """ + (a_val, a_ind, a_ptr, b) = inputs + (z,) = outputs + typenum_z = TensorType(self.dtype_out, []).dtype_specs()[2] + if node.inputs[0].type.dtype in ("complex64", "complex128"): + raise NotImplementedError("Complex types are not supported for a_val") + if node.inputs[3].type.dtype in ("complex64", "complex128"): + raise NotImplementedError("Complex types are not supported for b") + + return """ + if (PyArray_NDIM(%(a_val)s) != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(a_val) != 1"); %(fail)s;} + if (PyArray_NDIM(%(a_ind)s) != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(a_ind) != 1"); %(fail)s;} + if (PyArray_NDIM(%(a_ptr)s) != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(a_ptr) != 1"); %(fail)s;} + if (PyArray_NDIM(%(b)s) != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 2"); %(fail)s;} + + if (PyArray_TYPE(%(a_ind)s) != NPY_INT32) { + PyErr_SetString(PyExc_NotImplementedError, "a_ind dtype not INT32"); %(fail)s;} + + if (PyArray_TYPE(%(a_ptr)s) != NPY_INT32) + {PyErr_SetString(PyExc_NotImplementedError, "a_ptr dtype not INT32"); %(fail)s;} + + if (PyArray_DIMS(%(a_val)s)[0] != PyArray_DIMS(%(a_ind)s)[0]) + {PyErr_SetString(PyExc_NotImplementedError, "a_val and a_ind have different lengths"); %(fail)s;} + + if ((!%(z)s) + || (PyArray_DIMS(%(z)s)[0] != PyArray_DIMS(%(a_ptr)s)[0]-1) //a's rows + || (PyArray_DIMS(%(z)s)[1] != PyArray_DIMS(%(b)s)[1]) //b's columns + ) + { + {Py_XDECREF(%(z)s);} + npy_intp dims[] = {0, 0}; + dims[0] = PyArray_DIMS(%(a_ptr)s)[0]-1; + dims[1] = PyArray_DIMS(%(b)s)[1]; + %(z)s = (PyArrayObject*) PyArray_SimpleNew(2, dims, %(typenum_z)s); + } + + { + // sparse array has size MxK, dense KxN, output MxN + npy_intp M = PyArray_DIMS(%(z)s)[0]; + npy_intp N = PyArray_DIMS(%(z)s)[1]; + npy_intp K = PyArray_DIMS(%(b)s)[0]; + if (N > 0x7fffffffL) + {PyErr_SetString(PyExc_NotImplementedError, "array too big (overflows int32 index)"); %(fail)s;} + + // strides tell you how many bytes to skip to go to next column/row entry + npy_intp Szm = PyArray_STRIDES(%(z)s)[0] / PyArray_DESCR(%(z)s)->elsize; + npy_intp Szn = PyArray_STRIDES(%(z)s)[1] / PyArray_DESCR(%(z)s)->elsize; + npy_intp Sbm = PyArray_STRIDES(%(b)s)[0] / PyArray_DESCR(%(b)s)->elsize; + npy_intp Sbn = PyArray_STRIDES(%(b)s)[1] / PyArray_DESCR(%(b)s)->elsize; + npy_intp Sval = PyArray_STRIDES(%(a_val)s)[0] / PyArray_DESCR(%(a_val)s)->elsize; + npy_intp Sind = PyArray_STRIDES(%(a_ind)s)[0] / PyArray_DESCR(%(a_ind)s)->elsize; + npy_intp Sptr = PyArray_STRIDES(%(a_ptr)s)[0] / PyArray_DESCR(%(a_ptr)s)->elsize; + + // pointers to access actual data in the arrays passed as params. + dtype_%(z)s* __restrict__ Dz = (dtype_%(z)s*)PyArray_DATA(%(z)s); + const dtype_%(a_val)s* __restrict__ Dval = (dtype_%(a_val)s*)PyArray_DATA(%(a_val)s); + const npy_int32 * __restrict__ Dind = (npy_int32*)PyArray_DATA(%(a_ind)s); + const npy_int32 * __restrict__ Dptr = (npy_int32*)PyArray_DATA(%(a_ptr)s); + + //npy_intp nnz = PyArray_DIMS(%(a_ind)s)[0]; + + //clear the output array + memset(Dz, 0, M*N*sizeof(dtype_%(z)s)); + + //iterate over the sparse array, making the most of an entry wherever we find it. + // Normal matrix matrix multiply: + // for m + // for n + // for k + // z[m, n] += a[m, k] * b[k, n] + // Here instead: + // for m + // for k (sparse) + // for n + // z[m, n] += a[m, k] * b[k, n] + + // loop over inner dimension + for (npy_int64 m = 0; m < M; ++m) + { + // pointer to m-th row of the output matrix Z + dtype_%(z)s* __restrict__ zm = (dtype_%(z)s*)(PyArray_BYTES(%(z)s) + PyArray_STRIDES(%(z)s)[0] * m); + + // loop over sparse rows indices through index pointer array + // (amounts to looping over cols k of sparse matrix) + for (npy_int32 k_idx = Dptr[m * Sptr]; k_idx < Dptr[(m+1) * Sptr]; ++k_idx) + { + npy_int32 k = Dind[k_idx * Sind]; // col index of non-null value for row m + const dtype_%(a_val)s Amk = Dval[k_idx * Sval]; // actual value at that location + + // get pointer to k-th row of dense matrix + const dtype_%(b)s* __restrict__ bk = (dtype_%(b)s*)(PyArray_BYTES(%(b)s) + PyArray_STRIDES(%(b)s)[0] * k); + + // loop over final dimension (cols of dense matrix) and perform dot product + for(npy_int32 n = 0; n < N; ++n) + { + zm[n*Szn] += Amk * bk[n*Sbn]; + } + } + } + } + + """ % dict( + locals(), **sub + ) + + def c_code_cache_version(self): + return (2,) + + +sd_csr = StructuredDotCSR() + + +# register a specialization to replace StructuredDot -> StructuredDotCSx +# This is tested in tests/test_basic.py:792 +@node_rewriter([sparse._structured_dot]) +def local_structured_dot(fgraph, node): + if node.op == sparse._structured_dot: + a, b = node.inputs + if a.type.format == "csc": + a_val, a_ind, a_ptr, a_shape = csm_properties(a) + a_nsparse = a_shape[0] + return [sd_csc(a_val, a_ind, a_ptr, a_nsparse, b)] + if a.type.format == "csr": + a_val, a_ind, a_ptr, a_shape = csm_properties(a) + return [sd_csr(a_val, a_ind, a_ptr, b)] + return False + + +# Commented out because +# a) it is only slightly faster than scipy these days, and sometimes a little +# slower, and +# b) the resulting graphs make it very difficult for an op to do size checking +# on the matrices involved. dimension mismatches are hard to detect sensibly. +# register_specialize(local_structured_dot) + + +class UsmmCscDense(_NoPythonCOp): + """Performs ``alpha * x @ y + z``. + + ``x`` and ``y`` are a matrices, ``z`` is a dense matrix, and ``alpha`` is a + scalar. The result is a dense matrix. + + Notes + ----- + The gradient is not implemented for this `Op`. + + This is an optimized version of `Usmm` when ``x`` is in CSC format and ``y`` is dense. + + """ + + __props__ = ("inplace",) + + def __init__(self, inplace): + self.inplace = inplace + if inplace: + self.destroy_map = {0: [6]} + + def __str__(self): + if self.inplace: + return "UsmmCscDense{inplace}" + else: + return "UsmmCscDense{no_inplace}" + + def make_node(self, alpha, x_val, x_ind, x_ptr, x_nrows, y, z): + alpha = as_tensor_variable(alpha) + x_val = as_tensor_variable(x_val) + x_ind = as_tensor_variable(x_ind) + x_ptr = as_tensor_variable(x_ptr) + x_nrows = as_tensor_variable(x_nrows) + y = as_tensor_variable(y) + z = as_tensor_variable(z) + assert x_ind.dtype == "int32" + assert x_ptr.dtype == "int32" + assert x_nrows.dtype == "int32" + assert alpha.ndim == 2 and alpha.type.broadcastable == (True, True) + assert x_val.ndim == 1 + assert y.ndim == 2 + assert z.ndim == 2 + + dtype_out = aes.upcast( + alpha.type.dtype, x_val.type.dtype, y.type.dtype, z.type.dtype + ) + + if dtype_out not in ("float32", "float64"): + raise NotImplementedError("only float types are supported in " "operands") + + if self.inplace: + assert z.type.dtype == dtype_out + + # axpy work only with the same dtype, so we should upcast the input + if dtype_out != alpha.type.dtype: + alpha = cast(alpha, dtype_out) + if dtype_out != x_val.type.dtype: + x_val = cast(x_val, dtype_out) + if dtype_out != y.type.dtype: + y = cast(y, dtype_out) + if dtype_out != z.type.dtype: + z = cast(z, dtype_out) + + r = Apply( + self, + [alpha, x_val, x_ind, x_ptr, x_nrows, y, z], + [tensor(dtype_out, (False, y.type.broadcastable[1]))], + ) + return r + + def c_support_code(self, **kwargs): + return blas.blas_header_text() + + def c_libraries(self, **kwargs): + return blas.ldflags() + + def c_compile_args(self, **kwargs): + return blas.ldflags(libs=False, flags=True) + + def c_lib_dirs(self, **kwargs): + return blas.ldflags(libs=False, libs_dir=True) + + def c_header_dirs(self, **kwargs): + return blas.ldflags(libs=False, include_dir=True) + + def c_code(self, node, name, inputs, outputs, sub): + alpha, x_val, x_ind, x_ptr, x_nrows, y, z = inputs + zn = outputs[0] + if node.inputs[1].type.dtype in ("complex64", "complex128"): + raise NotImplementedError("Complex types are not supported for " "x_val") + if node.inputs[5].type.dtype in ("complex64", "complex128"): + raise NotImplementedError("Complex types are not supported for y") + if node.inputs[6].type.dtype != node.outputs[0].type.dtype: + raise NotImplementedError("z and output must have same type") + + if node.inputs[1].type.dtype == "float32": + conv_type = "float" + axpy = "saxpy_" + else: + conv_type = "double" + axpy = "daxpy_" + # retrieve dtype numbers + typenum_alpha = node.inputs[0].type.dtype_specs()[2] + typenum_x_val = node.inputs[1].type.dtype_specs()[2] + typenum_y = node.inputs[5].type.dtype_specs()[2] + typenum_z = node.inputs[6].type.dtype_specs()[2] + typenum_zn = node.outputs[0].type.dtype_specs()[2] + + inplace = int(self.inplace) + + rval = """ + + if (PyArray_NDIM(%(x_val)s) != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(x_val) != 1"); %(fail)s;} + if (PyArray_NDIM(%(x_ind)s) != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(x_ind) != 1"); %(fail)s;} + if (PyArray_NDIM(%(x_ptr)s) != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(x_ptr) != 1"); %(fail)s;} + if (PyArray_NDIM(%(x_nrows)s) != 0) {PyErr_SetString(PyExc_NotImplementedError, "rank(nrows) != 0"); %(fail)s;} + if (PyArray_NDIM(%(y)s) != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(y) != 2"); %(fail)s;} + + if (PyArray_TYPE(%(x_val)s) != %(typenum_x_val)s) { + PyErr_SetString(PyExc_NotImplementedError, "Invalid type for x_val"); %(fail)s;} + + if (PyArray_TYPE(%(y)s) != %(typenum_y)s) { + PyErr_SetString(PyExc_NotImplementedError, "Invalid type for y"); %(fail)s;} + + if (PyArray_TYPE(%(z)s) != %(typenum_z)s) { + PyErr_SetString(PyExc_NotImplementedError, "Invalid type for z"); %(fail)s;} + + if (PyArray_TYPE(%(alpha)s) != %(typenum_alpha)s) { + PyErr_SetString(PyExc_NotImplementedError, "Invalid type for alpha"); %(fail)s;} + + if (PyArray_TYPE(%(x_ind)s) != NPY_INT32) { + PyErr_SetString(PyExc_NotImplementedError, "x_ind dtype not INT32"); %(fail)s;} + + if (PyArray_TYPE(%(x_ptr)s) != NPY_INT32) + {PyErr_SetString(PyExc_NotImplementedError, "x_ptr dtype not INT32"); %(fail)s;} + + if (PyArray_TYPE(%(x_nrows)s) != NPY_INT32) + {PyErr_SetString(PyExc_NotImplementedError, "x_nrows dtype not INT32"); %(fail)s;} + + if (PyArray_DIMS(%(x_val)s)[0] != PyArray_DIMS(%(x_ind)s)[0]) + {PyErr_SetString(PyExc_NotImplementedError, "x_val and x_ind have different lengths"); %(fail)s;} + + if (PyArray_DIMS(%(x_ptr)s)[0] != PyArray_DIMS(%(y)s)[0]+1) + {PyErr_SetString(PyExc_NotImplementedError, "x's number of columns doesn't match y's rows"); %(fail)s;} + + if (PyArray_DIMS(%(z)s)[0] != ((npy_int32 *)PyArray_DATA(%(x_nrows)s))[0] || PyArray_DIMS(%(z)s)[1] != PyArray_DIMS(%(y)s)[1]) + {PyErr_SetString(PyExc_NotImplementedError, "The dimension of the allocated output doesn't match the correct output size."); %(fail)s;} + + if (PyArray_SIZE(%(alpha)s) != 1) + {PyErr_SetString(PyExc_NotImplementedError, "The number of element in alpha must be 1"); %(fail)s;} + + if (PyArray_NDIM(%(alpha)s) != 2) + {PyErr_SetString(PyExc_NotImplementedError, "The number dimension of alpha must be 2"); %(fail)s;} + + if (PyArray_NDIM(%(x_val)s) != 1) + {PyErr_SetString(PyExc_NotImplementedError, "The number dimension of x_val must be 1"); %(fail)s;} + + if (PyArray_NDIM(%(y)s) != 2) + {PyErr_SetString(PyExc_NotImplementedError, "The number dimension of y must be 2"); %(fail)s;} + + if (PyArray_NDIM(%(z)s) != 2) + {PyErr_SetString(PyExc_NotImplementedError, "The number dimension of z must be 2"); %(fail)s;} + + if (%(inplace)s) + { + if (%(typenum_zn)s != %(typenum_z)s) { + PyErr_SetString(PyExc_NotImplementedError, "When inplace the output dtype must be the same as the input"); %(fail)s;} + + Py_XDECREF(%(zn)s); + %(zn)s = %(z)s; + Py_INCREF(%(zn)s); + } + else if (!%(zn)s + || (PyArray_DIMS(%(zn)s)[0] != ((npy_int32 *)PyArray_DATA(%(x_nrows)s))[0]) + || (PyArray_DIMS(%(zn)s)[1] != PyArray_DIMS(%(y)s)[1]) + ) + { + {Py_XDECREF(%(zn)s);} + npy_intp dims[] = {0, 0}; + dims[0] = ((npy_int32 *)PyArray_DATA(%(x_nrows)s))[0]; + dims[1] = PyArray_DIMS(%(y)s)[1]; + %(zn)s = (PyArrayObject*) PyArray_SimpleNew(2, dims, %(typenum_zn)s); + } + + { + // sparse array has size MxK, dense KxN, output MxN + npy_intp M = PyArray_DIMS(%(zn)s)[0]; + npy_intp N = PyArray_DIMS(%(zn)s)[1]; + npy_intp K = PyArray_DIMS(%(y)s)[0]; + + // pointers to access actual data in the arrays passed as params. + const dtype_%(x_val)s* __restrict__ Dval = (dtype_%(x_val)s*)PyArray_DATA(%(x_val)s); + const npy_int32 * __restrict__ Dind = (npy_int32*)PyArray_DATA(%(x_ind)s); + const npy_int32 * __restrict__ Dptr = (npy_int32*)PyArray_DATA(%(x_ptr)s); + const dtype_%(alpha)s alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0]; + + npy_intp Sz = PyArray_STRIDES(%(z)s)[1] / PyArray_DESCR(%(z)s)->elsize; + npy_intp Szn = PyArray_STRIDES(%(zn)s)[1] / PyArray_DESCR(%(zn)s)->elsize; + npy_intp Sval = PyArray_STRIDES(%(x_val)s)[0] / PyArray_DESCR(%(x_val)s)->elsize; + npy_intp Sind = PyArray_STRIDES(%(x_ind)s)[0] / PyArray_DESCR(%(x_ind)s)->elsize; + npy_intp Sptr = PyArray_STRIDES(%(x_ptr)s)[0] / PyArray_DESCR(%(x_ptr)s)->elsize; + npy_intp Sy = PyArray_STRIDES(%(y)s)[1] / PyArray_DESCR(%(y)s)->elsize; + + // blas expects ints; convert here (rather than just making N etc ints) to avoid potential overflow in the negative-stride correction + if ((N > 0x7fffffffL)||(Sy > 0x7fffffffL)||(Szn > 0x7fffffffL)||(Sy < -0x7fffffffL)||(Szn < -0x7fffffffL)) + {PyErr_SetString(PyExc_NotImplementedError, "array too big for BLAS (overflows int32 index)"); %(fail)s;} + int N32 = N; + int Sy32 = Sy; + int Szn32 = Szn; + + if (!(%(inplace)s)) + { + if (PyArray_CopyInto(%(zn)s, %(z)s)) + { + Py_XDECREF(%(zn)s); + %(fail)s; + } + } + + for (npy_intp k = 0; k < K; ++k) + { + for (npy_int32 m_idx = Dptr[k * Sptr]; m_idx < Dptr[(k+1)*Sptr]; ++m_idx) + { + const npy_int32 m = Dind[m_idx * Sind]; // row index of non-null value for column K + + const dtype_%(x_val)s Amk = alpha * Dval[m_idx * Sval]; // actual value at that location + + dtype_%(y)s* y_row = (dtype_%(y)s*)(PyArray_BYTES(%(y)s) + PyArray_STRIDES(%(y)s)[0] * k); + // axpy expects pointer to the beginning of memory arrays, + // so when the stride is negative, we need to get the + // last element + if (Sy < 0) + y_row += (K - 1) * Sy; + + dtype_%(zn)s* z_row = (dtype_%(zn)s*)(PyArray_BYTES(%(zn)s) + PyArray_STRIDES(%(zn)s)[0] * m); + if (Szn < 0) + z_row += (N - 1) * Szn; + + %(axpy)s(&N32, (%(conv_type)s*)&Amk, (%(conv_type)s*)y_row, &Sy32, (%(conv_type)s*)z_row, &Szn32); + } + } + } + """ % dict( + locals(), **sub + ) + + return rval + + def c_code_cache_version(self): + return (3, blas.blas_header_version()) + + +usmm_csc_dense = UsmmCscDense(inplace=False) +usmm_csc_dense_inplace = UsmmCscDense(inplace=True) + + +# This is tested in tests/test_basic.py:UsmmTests +local_usmm = PatternNodeRewriter( + ( + sub, + "z", + ( + mul, + { + "pattern": "alpha", + "constraint": lambda expr: ( + all(expr.type.broadcastable) and config.blas__ldflags + ), + }, + (sparse._dot, "x", "y"), + ), + ), + (usmm, (neg, "alpha"), "x", "y", "z"), +) +register_specialize(local_usmm, name="local_usmm") + + +# register a specialization to replace usmm_csc_dense -> usmm_csc_dense_inplace +# This is tested in tests/test_basic.py:UsmmTests +@node_rewriter([usmm_csc_dense]) +def local_usmm_csc_dense_inplace(fgraph, node): + if node.op == usmm_csc_dense: + return [usmm_csc_dense_inplace(*node.inputs)] + + +register_specialize(local_usmm_csc_dense_inplace, "cxx_only", "inplace") + + +# This is tested in tests/test_basic.py:UsmmTests +@node_rewriter([usmm]) +def local_usmm_csx(fgraph, node): + """ + usmm -> usmm_csc_dense + + """ + if node.op == usmm: + alpha, x, y, z = node.inputs + + x_is_sparse_variable = _is_sparse_variable(x) + y_is_sparse_variable = _is_sparse_variable(y) + + if x_is_sparse_variable and not y_is_sparse_variable: + if x.type.format == "csc": + x_val, x_ind, x_ptr, x_shape = csm_properties(x) + x_nsparse = x_shape[0] + dtype_out = aes.upcast( + alpha.type.dtype, x.type.dtype, y.type.dtype, z.type.dtype + ) + if dtype_out not in ("float32", "float64"): + return False + # Sparse cast is not implemented. + if y.type.dtype != dtype_out: + return False + + return [usmm_csc_dense(alpha, x_val, x_ind, x_ptr, x_nsparse, y, z)] + return False + + +register_specialize(local_usmm_csx, "cxx_only") + + +class CSMGradC(_NoPythonCOp): + + __props__ = () + + def make_node(self, a_val, a_ind, a_ptr, a_dim, b_val, b_ind, b_ptr, b_dim): + return Apply( + self, + [a_val, a_ind, a_ptr, a_dim, b_val, b_ind, b_ptr, b_dim], + [b_val.type()], + ) + + def c_code(self, node, name, inputs, outputs, sub): + # retrieve dtype number + (a_val, a_ind, a_ptr, a_dim, b_val, b_ind, b_ptr, b_dim) = inputs + (z,) = outputs + typenum_z = node.outputs[0].type.dtype_specs()[2] + if node.inputs[0].type.dtype in ("complex64", "complex128"): + raise NotImplementedError("Complex types are not supported for a_val") + if node.inputs[3].type.dtype in ("complex64", "complex128"): + raise NotImplementedError("Complex types are not supported for b_val") + + return """ + if (PyArray_NDIM(%(a_val)s) != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(a_val) != 1"); %(fail)s;} + if (PyArray_NDIM(%(a_ind)s) != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(a_ind) != 1"); %(fail)s;} + if (PyArray_NDIM(%(a_ptr)s) != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(a_ptr) != 1"); %(fail)s;} + if (PyArray_NDIM(%(b_val)s) != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(b_val) != 1"); %(fail)s;} + if (PyArray_NDIM(%(b_ind)s) != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(b_ind) != 1"); %(fail)s;} + if (PyArray_NDIM(%(b_ptr)s) != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(b_ptr) != 1"); %(fail)s;} + + if (PyArray_TYPE(%(a_ind)s) != NPY_INT32) { + PyErr_SetString(PyExc_NotImplementedError, "a_ind dtype not INT32"); %(fail)s;} + + if (PyArray_TYPE(%(a_ptr)s) != NPY_INT32) + {PyErr_SetString(PyExc_NotImplementedError, "a_ptr dtype not INT32"); %(fail)s;} + + if (PyArray_TYPE(%(b_ind)s) != NPY_INT32) { + PyErr_SetString(PyExc_NotImplementedError, "b_ind dtype not INT32"); %(fail)s;} + + if (PyArray_TYPE(%(b_ptr)s) != NPY_INT32) + {PyErr_SetString(PyExc_NotImplementedError, "b_ptr dtype not INT32"); %(fail)s;} + + if (PyArray_DIMS(%(a_val)s)[0] != PyArray_DIMS(%(a_ind)s)[0]) + {PyErr_SetString(PyExc_NotImplementedError, "a_val and a_ind have different lengths"); %(fail)s;} + + if (PyArray_DIMS(%(b_val)s)[0] != PyArray_DIMS(%(b_ind)s)[0]) + {PyErr_SetString(PyExc_NotImplementedError, "b_val and b_ind have different lengths"); %(fail)s;} + + if (PyArray_DIMS(%(a_ptr)s)[0] != PyArray_DIMS(%(b_ptr)s)[0]) + {PyErr_SetString(PyExc_NotImplementedError, "a_ptr and b_ptr have different lengths"); %(fail)s;} + + if ((!%(z)s) || (PyArray_DIMS(%(z)s)[0] != PyArray_DIMS(%(a_val)s)[0])) + { + {Py_XDECREF(%(z)s);} + npy_intp dims[] = {0}; + dims[0] = PyArray_DIMS(%(a_val)s)[0]; + %(z)s = (PyArrayObject*) PyArray_SimpleNew(1, dims, %(typenum_z)s); + } + + { + // sparse array has size MxK, dense KxN, output MxN + npy_intp M = PyArray_DIMS(%(a_ptr)s)[0] - 1; + npy_intp a_dim_0 = ((npy_int32 *)PyArray_DATA(%(a_dim)s))[0]; + npy_intp a_dim_1 = ((npy_int32 *)PyArray_DATA(%(a_dim)s))[1]; + + npy_intp sp_dim = (M == a_dim_0)?a_dim_1:a_dim_0; + + // strides tell you how many bytes to skip to go to next column/row entry + npy_intp Sz = PyArray_STRIDES(%(z)s)[0] / PyArray_DESCR(%(z)s)->elsize; + npy_intp Sa_val = PyArray_STRIDES(%(a_val)s)[0] / PyArray_DESCR(%(a_val)s)->elsize; + npy_intp Sa_ind = PyArray_STRIDES(%(a_ind)s)[0] / PyArray_DESCR(%(a_ind)s)->elsize; + npy_intp Sa_ptr = PyArray_STRIDES(%(a_ptr)s)[0] / PyArray_DESCR(%(a_ptr)s)->elsize; + npy_intp Sb_val = PyArray_STRIDES(%(b_val)s)[0] / PyArray_DESCR(%(b_val)s)->elsize; + npy_intp Sb_ind = PyArray_STRIDES(%(b_ind)s)[0] / PyArray_DESCR(%(b_ind)s)->elsize; + npy_intp Sb_ptr = PyArray_STRIDES(%(b_ptr)s)[0] / PyArray_DESCR(%(b_ptr)s)->elsize; + + // pointers to access actual data in the arrays passed as params. + dtype_%(z)s* __restrict__ Dz = (dtype_%(z)s*)PyArray_DATA(%(z)s); + const dtype_%(a_val)s* __restrict__ Da_val = (dtype_%(a_val)s*)PyArray_DATA(%(a_val)s); + const npy_int32 * __restrict__ Da_ind = (npy_int32*)PyArray_DATA(%(a_ind)s); + const npy_int32 * __restrict__ Da_ptr = (npy_int32*)PyArray_DATA(%(a_ptr)s); + const dtype_%(b_val)s* __restrict__ Db_val = (dtype_%(b_val)s*)PyArray_DATA(%(b_val)s); + const npy_int32 * __restrict__ Db_ind = (npy_int32*)PyArray_DATA(%(b_ind)s); + const npy_int32 * __restrict__ Db_ptr = (npy_int32*)PyArray_DATA(%(b_ptr)s); + + npy_intp nnz = PyArray_DIMS(%(a_ind)s)[0]; + + dtype_%(b_val)s b_row[sp_dim]; + + //clear the output array + for (npy_int64 i = 0; i < nnz; ++i) + { + Dz[i*Sz] = 0; + } + memset(b_row, 0, sp_dim*sizeof(dtype_%(b_val)s)); + + // loop over inner dimension + for (npy_int64 m = 0; m < M; ++m) + { + for (npy_int32 j_ptr = Db_ptr[m * Sb_ptr]; + j_ptr < Db_ptr[(m + 1) * Sb_ptr]; j_ptr++) { + b_row[Db_ind[j_ptr * Sb_ind]] += Db_val[j_ptr*Sb_val]; + } + + for (npy_int32 j_ptr = Da_ptr[m * Sa_ptr]; + j_ptr < Da_ptr[(m + 1) * Sa_ptr]; j_ptr++) { + Dz[j_ptr*Sz] = b_row[Da_ind[j_ptr * Sa_ind]]; + } + + for (npy_int32 j_ptr = Db_ptr[m * Sb_ptr]; + j_ptr < Db_ptr[(m + 1) * Sb_ptr]; j_ptr++) { + b_row[Db_ind[j_ptr * Sb_ind]] = 0; + } + } + } + + """ % dict( + locals(), **sub + ) + + def c_code_cache_version(self): + return (3,) + + +csm_grad_c = CSMGradC() + + +@node_rewriter([csm_grad(None)]) +def local_csm_grad_c(fgraph, node): + """ + csm_grad(None) -> csm_grad_c + + """ + if node.op == csm_grad(None): + return [csm_grad_c(*node.inputs)] + return False + + +# DISABLED AS IT IS BROKEN FOR UNSORTED INDICES! +# register_specialize(local_csm_grad_c, 'cxx_only') + + +class MulSDCSC(_NoPythonCOp): + """Multiplication of sparse matrix by a broadcasted dense vector element-wise. + + Notes + ----- + + This `Op` is used as a rewritten form of `mul_s_d`. + + """ + + __props__ = () + + def make_node(self, a_data, a_indices, a_indptr, b): + """ + + Parameters + ---------- + a_data + Sparse matrix data. + a_indices + Sparse matrix indices. + a_indptr + Sparse matrix indptr. + b + Tensor type matrix. + + Returns + ------- + The multiplication of the two matrices element-wise. + + Notes + ----- + `a_data`, `a_indices` and `a_indptr` must be the properties of a sparse + matrix in csc format. + + The dtype of `a_data`, i.e. the dtype of the sparse matrix, cannot be a + complex type. + + """ + assert b.type.ndim == 2 + return Apply( + self, [a_data, a_indices, a_indptr, b], [tensor(b.dtype, (False,))] + ) + + def c_code_cache_version(self): + return (3,) + + def c_code(self, node, name, inputs, outputs, sub): + + ( + _data, + _indices, + _indptr, + _b, + ) = inputs + (_zout,) = outputs + if node.inputs[0].type.dtype in ("complex64", "complex128"): + raise NotImplementedError("Complex types are not supported for a") + if node.inputs[3].type.dtype in ("complex64", "complex128"): + raise NotImplementedError("Complex types are not supported for b") + + return """ + if (PyArray_NDIM(%(_b)s) != 2) { + PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 2"); + %(fail)s;} + if (PyArray_NDIM(%(_data)s) != 1) { + PyErr_SetString(PyExc_NotImplementedError, "rank(data) != 1"); + %(fail)s;} + if (PyArray_NDIM(%(_indices)s) != 1) { + PyErr_SetString(PyExc_NotImplementedError, "rank(indices) != 1"); + %(fail)s;} + if (PyArray_NDIM(%(_indptr)s) != 1) { + PyErr_SetString(PyExc_NotImplementedError, "rank(indptr) != 1"); + %(fail)s;} + + if( PyArray_TYPE(%(_indices)s) != NPY_INT32) { + PyErr_SetString(PyExc_NotImplementedError, "C"); %(fail)s;} + + if( PyArray_TYPE(%(_indptr)s) != NPY_INT32) + {PyErr_SetString(PyExc_NotImplementedError, "D"); %(fail)s;} + + if (!%(_zout)s || + (PyArray_DIMS(%(_zout)s)[0] != PyArray_DIMS(%(_indices)s)[0]) || + !(PyArray_ISCONTIGUOUS(%(_zout)s))) + { + Py_XDECREF(%(_zout)s); + %(_zout)s = (PyArrayObject*) PyArray_SimpleNew(1, + PyArray_DIMS(%(_indices)s), PyArray_TYPE(%(_b)s)); + if (!%(_zout)s) + { + PyErr_SetString(PyExc_MemoryError, + "Could not allocate output memory."); + %(fail)s; + } + } + + { //makes it compile even though labels jump over variable definitions. + const npy_intp nnz = PyArray_DIMS(%(_indices)s)[0]; + //TODO: error checking with this + const npy_intp N = PyArray_DIMS(%(_indptr)s)[0]-1; + + const dtype_%(_data)s * const __restrict__ data = (dtype_%(_data)s*)PyArray_DATA(%(_data)s); + const npy_int32 * const __restrict__ indptr = (npy_int32 *)PyArray_DATA(%(_indptr)s); + const npy_int32 * const __restrict__ indices = (npy_int32 *)PyArray_DATA(%(_indices)s); + + dtype_%(_zout)s * const __restrict__ zout = (dtype_%(_zout)s*)PyArray_DATA(%(_zout)s); + + const npy_intp Sb = PyArray_STRIDES(%(_b)s)[0]; + + // loop over columns + for (npy_intp j = 0; j < N; ++j) + { + // for each non-null value in the sparse column + for (npy_int32 i_idx = indptr[j]; i_idx < indptr[j+1]; ++i_idx) + { + // extract row index of non-null value + npy_int32 i = indices[i_idx]; + + // extract i-th row of dense matrix + const dtype_%(_b)s* __restrict__ b_row = (dtype_%(_b)s*)(PyArray_BYTES(%(_b)s) + Sb * i); + + // write resulting gradient to sparse output + zout[i_idx] = data[i_idx] * b_row[j]; + } + } + } + + """ % dict( + locals(), **sub + ) + + def __str__(self): + return self.__class__.__name__ + + +mul_s_d_csc = MulSDCSC() + + +class MulSDCSR(_NoPythonCOp): + """Multiplication of sparse matrix by a broadcasted dense vector element-wise. + + Notes + ----- + + This `Op` is used as a rewritten form of `mul_s_d`. + + """ + + __props__ = () + + def make_node(self, a_data, a_indices, a_indptr, b): + """ + + Parameters + ---------- + a_data + Sparse matrix data. + a_indices + Sparse matrix indices. + a_indptr + Sparse matrix indptr. + b + Tensor type matrix. + + Returns + ------- + The multiplication of the two matrix element wise. + + Notes + ----- + `a_data`, `a_indices` and `a_indptr` must be the properties + of a sparse matrix in csr format. + + The dtype of `a_data`, i.e. the dtype of the sparse matrix, + cannot be a complex type. + + """ + assert b.type.ndim == 2 + return Apply( + self, [a_data, a_indices, a_indptr, b], [tensor(b.dtype, (False,))] + ) + + def c_code_cache_version(self): + return (3,) + + def c_code(self, node, name, inputs, outputs, sub): + + ( + _data, + _indices, + _indptr, + _b, + ) = inputs + (_zout,) = outputs + if node.inputs[0].type.dtype in ("complex64", "complex128"): + raise NotImplementedError("Complex types are not supported for a") + if node.inputs[3].type.dtype in ("complex64", "complex128"): + raise NotImplementedError("Complex types are not supported for b") + + return """ + if (PyArray_NDIM(%(_b)s) != 2) { + PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 2"); + %(fail)s;} + if (PyArray_NDIM(%(_data)s) != 1) { + PyErr_SetString(PyExc_NotImplementedError, "rank(data) != 1"); + %(fail)s;} + if (PyArray_NDIM(%(_indices)s) != 1) { + PyErr_SetString(PyExc_NotImplementedError, "rank(indices) != 1"); + %(fail)s;} + if (PyArray_NDIM(%(_indptr)s) != 1) { + PyErr_SetString(PyExc_NotImplementedError, "rank(indptr) != 1"); + %(fail)s;} + + if( PyArray_TYPE(%(_indices)s) != NPY_INT32) { + PyErr_SetString(PyExc_NotImplementedError, "C"); %(fail)s;} + + if( PyArray_TYPE(%(_indptr)s) != NPY_INT32) + {PyErr_SetString(PyExc_NotImplementedError, "D"); %(fail)s;} + + if (!%(_zout)s || + (PyArray_DIMS(%(_zout)s)[0] != PyArray_DIMS(%(_indices)s)[0]) || + !(PyArray_ISCONTIGUOUS(%(_zout)s))) + { + Py_XDECREF(%(_zout)s); + %(_zout)s = (PyArrayObject*) PyArray_SimpleNew(1, + PyArray_DIMS(%(_indices)s), PyArray_TYPE(%(_b)s)); + if (!%(_zout)s) + { + PyErr_SetString(PyExc_MemoryError, + "Could not allocate output memory."); + %(fail)s; + } + } + + { //makes it compile even though labels jump over variable definitions. + const npy_intp nnz = PyArray_DIMS(%(_indices)s)[0]; + //TODO: error checking with this + const npy_intp N = PyArray_DIMS(%(_indptr)s)[0]-1; + + const dtype_%(_data)s * const __restrict__ data = (dtype_%(_data)s*)PyArray_DATA(%(_data)s); + const npy_int32 * const __restrict__ indptr = (npy_int32 *)PyArray_DATA(%(_indptr)s); + const npy_int32 * const __restrict__ indices = (npy_int32 *)PyArray_DATA(%(_indices)s); + + dtype_%(_zout)s * const __restrict__ zout = (dtype_%(_zout)s*)PyArray_DATA(%(_zout)s); + + const npy_intp Sb = PyArray_STRIDES(%(_b)s)[0]; + + // loop over columns + for (npy_intp j = 0; j < N; ++j) + { + // extract i-th row of dense matrix + const dtype_%(_b)s* __restrict__ b_row = (dtype_%(_b)s*)(PyArray_BYTES(%(_b)s) + Sb * j); + + // for each non-null value in the sparse column + for (npy_int32 i_idx = indptr[j]; i_idx < indptr[j+1]; ++i_idx) + { + // extract row index of non-null value + npy_int32 i = indices[i_idx]; + + // write resulting gradient to sparse output + zout[i_idx] = data[i_idx] * b_row[i]; + } + } + } + + """ % dict( + locals(), **sub + ) + + def __str__(self): + return self.__class__.__name__ + + +mul_s_d_csr = MulSDCSR() + + +# register a specialization to replace MulSD -> MulSDCSX +@node_rewriter([sparse.mul_s_d]) +def local_mul_s_d(fgraph, node): + if node.op == sparse.mul_s_d: + x, y = node.inputs + + x_is_sparse_variable = _is_sparse_variable(x) + + if x_is_sparse_variable: + svar = x + dvar = y + else: + svar = y + dvar = x + + if dvar.type.ndim != 2: + return False + if svar.type.format == "csc": + CSx = sparse.CSC + mul_s_d_csx = mul_s_d_csc + elif svar.type.format == "csr": + CSx = sparse.CSR + mul_s_d_csx = mul_s_d_csr + else: + raise NotImplementedError + if x.dtype != y.dtype: + # mul_s_d_csx don't support that case + return + + c_data = mul_s_d_csx( + sparse.csm_data(svar), + sparse.csm_indices(svar), + sparse.csm_indptr(svar), + dvar, + ) + + return [ + CSx( + c_data, + sparse.csm_indices(svar), + sparse.csm_indptr(svar), + sparse.csm_shape(svar), + ) + ] + + return False + + +register_specialize(local_mul_s_d, "cxx_only") + + +class MulSVCSR(_NoPythonCOp): + """Multiplication of sparse matrix by a broadcasted dense vector element-wise. + + + Notes + ----- + + This `Op` is used as a rewritten form of `MulSV`. + + """ + + __props__ = () + + def make_node(self, a_data, a_indices, a_indptr, b): + """ + + Parameters + ---------- + a_data + Sparse matrix data. + a_indices + Sparse matrix indices. + a_indptr + Sparse matrix indptr. + b + Tensor type matrix. + + Returns + ------- + The multiplication of the two matrix element wise. + + Notes + ----- + `a_data`, `a_indices` and `a_indptr` must be the properties + of a sparse matrix in csr format. + + The dtype of `a_data`, i.e. the dtype of the sparse matrix, + cannot be a complex type. + + """ + assert b.type.ndim == 1 + return Apply( + self, [a_data, a_indices, a_indptr, b], [tensor(b.dtype, (False,))] + ) + + def c_code_cache_version(self): + return (2,) + + def c_code(self, node, name, inputs, outputs, sub): + ( + _data, + _indices, + _indptr, + _b, + ) = inputs + (_zout,) = outputs + if node.inputs[0].type.dtype in ("complex64", "complex128"): + raise NotImplementedError("Complex types are not supported for a") + if node.inputs[3].type.dtype in ("complex64", "complex128"): + raise NotImplementedError("Complex types are not supported for b") + + return """ + if (PyArray_NDIM(%(_b)s) != 1) { + PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 1"); + %(fail)s; + } + if (PyArray_NDIM(%(_data)s) != 1) { + PyErr_SetString(PyExc_NotImplementedError, "rank(data) != 1"); + %(fail)s; + } + if (PyArray_NDIM(%(_indices)s) != 1) { + PyErr_SetString(PyExc_NotImplementedError, "rank(indices) != 1"); + %(fail)s; + } + if (PyArray_NDIM(%(_indptr)s) != 1) { + PyErr_SetString(PyExc_NotImplementedError, "rank(indptr) != 1"); + %(fail)s; + } + + if( PyArray_TYPE(%(_indices)s) != NPY_INT32) { + PyErr_SetString(PyExc_NotImplementedError, "C"); %(fail)s;} + + if( PyArray_TYPE(%(_indptr)s) != NPY_INT32) + {PyErr_SetString(PyExc_NotImplementedError, "D"); %(fail)s;} + + if (!%(_zout)s + || PyArray_DIMS(%(_zout)s)[0] != PyArray_DIMS(%(_indices)s)[0] + || !PyArray_ISCONTIGUOUS(%(_zout)s)) + { + Py_XDECREF(%(_zout)s); + %(_zout)s = (PyArrayObject*) PyArray_SimpleNew(1, + PyArray_DIMS(%(_indices)s), PyArray_TYPE(%(_b)s)); + } + + { //makes it compile even though labels jump over variable definitions. + const npy_intp nnz = PyArray_DIMS(%(_indices)s)[0]; + //TODO: error checking with this + const npy_intp N = PyArray_DIMS(%(_indptr)s)[0]-1; + + const dtype_%(_data)s * const __restrict__ data = (dtype_%(_data)s*)PyArray_DATA(%(_data)s); + const npy_int32 * const __restrict__ indptr = (npy_int32 *)PyArray_DATA(%(_indptr)s); + const npy_int32 * const __restrict__ indices = (npy_int32 *)PyArray_DATA(%(_indices)s); + + const dtype_%(_b)s* __restrict__ Db = (dtype_%(_b)s*)PyArray_DATA(%(_b)s); + + dtype_%(_zout)s * const __restrict__ zout = (dtype_%(_zout)s*)PyArray_DATA(%(_zout)s); + + const npy_intp Sb = PyArray_STRIDES(%(_b)s)[0] / PyArray_DESCR(%(_b)s)->elsize; + + // loop over rows + for (npy_intp j = 0; j < N; ++j) + { + // for each non-null value in the sparse column + for (npy_int32 i_idx = indptr[j]; i_idx < indptr[j+1]; ++i_idx) + { + // extract row index of non-null value + npy_int32 i = indices[i_idx]; + + zout[i_idx] = data[i_idx] * Db[i * Sb]; + } + } + } + + """ % dict( + locals(), **sub + ) + + def __str__(self): + return self.__class__.__name__ + + +mul_s_v_csr = MulSVCSR() + + +# register a specialization to replace MulSV -> MulSVCSR +@node_rewriter([sparse.mul_s_v]) +def local_mul_s_v(fgraph, node): + if node.op == sparse.mul_s_v: + x, y = node.inputs + + x_is_sparse_variable = _is_sparse_variable(x) + + if x_is_sparse_variable: + svar = x + dvar = y + else: + svar = y + dvar = x + + if dvar.type.ndim != 1: + return False + elif svar.type.format == "csr": + CSx = sparse.CSR + mul_s_v_csx = mul_s_v_csr + else: + return False + + s_val, s_ind, s_ptr, s_shape = sparse.csm_properties(svar) + + c_data = mul_s_v_csx(s_val, s_ind, s_ptr, dvar) + + return [CSx(c_data, s_ind, s_ptr, s_shape)] + + return False + + +register_specialize(local_mul_s_v, "cxx_only") + + +class StructuredAddSVCSR(_NoPythonCOp): + """Structured addition of a sparse matrix and a dense vector. + + The elements of the vector are are only added to the corresponding + non-zero elements. Therefore, this operation outputs another sparse + matrix. + + Notes + ----- + + This `Op` is used as a rewritten form of `StructuredAddSV`. + + """ + + __props__ = () + + def make_node(self, a_data, a_indices, a_indptr, b): + """ + + Parameters + ---------- + a_data + Sparse matrix data. + a_indices + Sparse matrix indices. + a_indptr + Sparse matrix indptr. + b + Tensor type vector. + + Returns + ------- + A sparse matrix containing the addition of the vector to the data of the + sparse matrix. + + """ + b = as_tensor_variable(b) + a_data = as_tensor_variable(a_data) + a_indices = as_tensor_variable(a_indices) + a_indptr = as_tensor_variable(a_indptr) + assert a_data.type.ndim == 1 + assert a_indices.type.ndim == 1 + assert a_indptr.type.ndim == 1 + assert b.type.ndim == 1 + return Apply( + self, [a_data, a_indices, a_indptr, b], [tensor(b.dtype, (False,))] + ) + + def c_code_cache_version(self): + return (3,) + + def c_code(self, node, name, inputs, outputs, sub): + ( + _data, + _indices, + _indptr, + _b, + ) = inputs + (_zout,) = outputs + if node.inputs[0].type.dtype in ("complex64", "complex128"): + raise NotImplementedError("Complex types are not supported for a") + if node.inputs[3].type.dtype in ("complex64", "complex128"): + raise NotImplementedError("Complex types are not supported for b") + + return """ + if (PyArray_NDIM(%(_b)s) != 1) { + PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 1"); + %(fail)s; + } + if (PyArray_NDIM(%(_data)s) != 1) { + PyErr_SetString(PyExc_NotImplementedError, "rank(data) != 1"); + %(fail)s; + } + if (PyArray_NDIM(%(_indices)s) != 1) { + PyErr_SetString(PyExc_NotImplementedError, "rank(indices) != 1"); + %(fail)s; + } + if (PyArray_NDIM(%(_indptr)s) != 1) { + PyErr_SetString(PyExc_NotImplementedError, "rank(indptr) != 1"); + %(fail)s; + } + + if( PyArray_TYPE(%(_indices)s) != NPY_INT32) { + PyErr_SetString(PyExc_NotImplementedError, "C"); %(fail)s;} + + if( PyArray_TYPE(%(_indptr)s) != NPY_INT32) + {PyErr_SetString(PyExc_NotImplementedError, "D"); %(fail)s;} + + if (!%(_zout)s + || (PyArray_DIMS(%(_zout)s)[0] != PyArray_DIMS(%(_indices)s)[0]) + || !(PyArray_ISCONTIGUOUS(%(_zout)s))) + { + Py_XDECREF(%(_zout)s); + %(_zout)s = (PyArrayObject*) PyArray_SimpleNew(1, + PyArray_DIMS(%(_indices)s), PyArray_TYPE(%(_b)s)); + if (!%(_zout)s) + { + PyErr_SetString(PyExc_MemoryError, + "Could not allocate output memory."); + %(fail)s; + } + } + + { //makes it compile even though labels jump over variable definitions. + const npy_intp nnz = PyArray_DIMS(%(_indices)s)[0]; + //TODO: error checking with this + const npy_intp N = PyArray_DIMS(%(_indptr)s)[0]-1; + + const dtype_%(_data)s * const __restrict__ data = (dtype_%(_data)s*)PyArray_DATA(%(_data)s); + const npy_int32 * const __restrict__ indptr = (npy_int32 *)PyArray_DATA(%(_indptr)s); + const npy_int32 * const __restrict__ indices = (npy_int32 *)PyArray_DATA(%(_indices)s); + + const dtype_%(_b)s* __restrict__ Db = (dtype_%(_b)s*)PyArray_DATA(%(_b)s); + + dtype_%(_zout)s * const __restrict__ zout = (dtype_%(_zout)s*)PyArray_DATA(%(_zout)s); + + const npy_intp Sb = PyArray_STRIDES(%(_b)s)[0] / PyArray_DESCR(%(_b)s)->elsize; + + // loop over columns + for (npy_intp j = 0; j < N; ++j) + { + // for each non-null value in the sparse column + for (npy_int32 i_idx = indptr[j]; i_idx < indptr[j+1]; ++i_idx) + { + // extract row index of non-null value + npy_int32 i = indices[i_idx]; + + // write resulting gradient to sparse output + zout[i_idx] = data[i_idx] + Db[i * Sb]; + } + } + } + + """ % dict( + locals(), **sub + ) + + def __str__(self): + return self.__class__.__name__ + + +structured_add_s_v_csr = StructuredAddSVCSR() + + +# register a specialization to replace +# structured_add_s_v -> structured_add_s_v_csr +@node_rewriter([sparse.structured_add_s_v]) +def local_structured_add_s_v(fgraph, node): + if node.op == sparse.structured_add_s_v: + x, y = node.inputs + + x_is_sparse_variable = _is_sparse_variable(x) + # y_is_sparse_variable = _is_sparse_variable(y) + + if x_is_sparse_variable: + svar = x + dvar = y + else: + svar = y + dvar = x + + if dvar.type.ndim != 1: + return False + elif svar.type.format == "csr": + CSx = sparse.CSR + structured_add_s_v_csx = structured_add_s_v_csr + else: + return False + + s_val, s_ind, s_ptr, s_shape = sparse.csm_properties(svar) + + c_data = structured_add_s_v_csx(s_val, s_ind, s_ptr, dvar) + + return [CSx(c_data, s_ind, s_ptr, s_shape)] + + return False + + +register_specialize(local_structured_add_s_v, "cxx_only") + + +class SamplingDotCSR(_NoPythonCOp): + r""" + An operator optimized for calculating the dot product :math:`x y^\top = z` + when one only wants to calculate a subset of :math:`z`. + + This is equivalent to :math:`p \circ (x \cdot y^\top)` where :math:`\circ` is + the element-wise product, :math:`x` and :math:`y` operands of the dot + product, and :math:`p` is a matrix that contains 1 when the corresponding + element of :math:`z` should be calculated and 0 when it shouldn't. Note + that `SamplingDot` has a different interface than ``dot`` because + `SamplingDot` requires :math:`x` to be a :math:`m \times k` matrix while + :math:`y` is a :math:`n \times k` matrix instead of the usual :math:``k + \times n` matrix. + + Notes + ----- + It will work if the pattern is not binary value, but if the + pattern doesn't have a high sparsity proportion it will be slower + then a more optimized dot followed by a normal element-wise + multiplication. + + If we have the input of mixed dtype, we insert cast element-wise + in the graph to be able to call BLAS function as they don't + allow mixed dtype. + + This `Op` is used as a rewritten form of `SamplingDot`. + + """ + + __props__ = () + + def make_node(self, x, y, p_data, p_ind, p_ptr, p_ncols): + """ + + Parameters + ---------- + x + Tensor matrix. + y + Tensor matrix. + p_data + Sparse matrix data. + p_ind + Sparse matrix indices. + p_ptr + Sparse matric indptr. + p_ncols + Sparse matrix number of columns. + + Returns + ------- + A dense matrix containing the dot product of :math:`x` by :math:`y^\top` only + where :math:`p` is 1. + + """ + x = as_tensor_variable(x) + y = as_tensor_variable(y) + p_data = as_tensor_variable(p_data) + p_ind = as_tensor_variable(p_ind) + p_ptr = as_tensor_variable(p_ptr) + p_ncols = as_tensor_variable(p_ncols) + + assert p_ncols.dtype == "int32" + + dtype_out = aes.upcast(x.type.dtype, y.type.dtype, p_data.type.dtype) + dot_out = aes.upcast(x.type.dtype, y.type.dtype) + + # We call blas ?dot function that take only param of the same type + x = cast(x, dot_out) + y = cast(y, dot_out) + + return Apply( + self, + [x, y, p_data, p_ind, p_ptr, p_ncols], + [ + tensor(dtype=dtype_out, shape=(False,)), + tensor(dtype=p_ind.type.dtype, shape=(False,)), + tensor(dtype=p_ptr.type.dtype, shape=(False,)), + ], + ) + + def c_code_cache_version(self): + return (4, blas.blas_header_version()) + + def c_support_code(self, **kwargs): + return blas.blas_header_text() + + def c_libraries(self, **kwargs): + return blas.ldflags() + + def c_compile_args(self, **kwargs): + return blas.ldflags(libs=False, flags=True) + + def c_lib_dirs(self, **kwargs): + return blas.ldflags(libs=False, libs_dir=True) + + def c_header_dirs(self, **kwargs): + return blas.ldflags(libs=False, include_dir=True) + + def c_code(self, node, name, inputs, outputs, sub): + x, y, p_data, p_ind, p_ptr, p_ncols = inputs + z_data, z_ind, z_ptr = outputs + if node.inputs[0].type.dtype in ("complex64", "complex128"): + raise NotImplementedError("Complex types are not supported for x") + if node.inputs[1].type.dtype in ("complex64", "complex128"): + raise NotImplementedError("Complex types are not supported for y") + if node.inputs[2].type.dtype in ("complex64", "complex128"): + raise NotImplementedError("Complex types are not supported for pattern") + + dot_out = aes.upcast(node.inputs[0].type.dtype, node.inputs[1].type.dtype) + + if dot_out == "float32": + conv_type = "float" + cdot = "sdot_" + else: + conv_type = "double" + cdot = "ddot_" + + # retrieve dtype number + typenum_x = node.inputs[0].type.dtype_specs()[2] + typenum_y = node.inputs[1].type.dtype_specs()[2] + typenum_p = node.inputs[2].type.dtype_specs()[2] + typenum_zd = TensorType(node.outputs[0].dtype, []).dtype_specs()[2] + typenum_zi = TensorType(node.outputs[1].dtype, []).dtype_specs()[2] + typenum_zp = TensorType(node.outputs[2].dtype, []).dtype_specs()[2] + + rval = """ + if (PyArray_NDIM(%(x)s) != 2) { +PyErr_SetString(PyExc_NotImplementedError, "rank(x) != 2"); %(fail)s;} + if (PyArray_NDIM(%(y)s) != 2) { +PyErr_SetString(PyExc_NotImplementedError, "rank(y) != 2"); %(fail)s;} + + if (PyArray_TYPE(%(x)s) != %(typenum_x)s) { + PyErr_SetString(PyExc_NotImplementedError, + "Invalid type for x"); + %(fail)s;} + + if (PyArray_TYPE(%(y)s) != %(typenum_y)s) { + PyErr_SetString(PyExc_NotImplementedError, + "Invalid type for y"); + %(fail)s;} + + if (PyArray_TYPE(%(p_data)s) != %(typenum_p)s) { + PyErr_SetString(PyExc_NotImplementedError, + "Invalid type for pattern"); + %(fail)s;} + + if (PyArray_DIMS(%(x)s)[1] != PyArray_DIMS(%(y)s)[1]) { + PyErr_SetString(PyExc_NotImplementedError, + "x's number of columns doesn't match y's rows! Note: sampling_dot is different from dot because y is assumed to be transposed."); + %(fail)s;} + + if (PyArray_DIMS(%(y)s)[0] != ((npy_int32 *)PyArray_DATA(%(p_ncols)s))[0] || + PyArray_DIMS(%(x)s)[0] != (PyArray_DIMS(%(p_ptr)s)[0] - 1)) + {PyErr_SetString(PyExc_NotImplementedError, + "The dimension of the pattern and the output must match"); %(fail)s;} + + // Allocate output + if (!%(z_data)s + || (PyArray_DIMS(%(z_data)s)[0] != PyArray_DIMS(%(p_data)s)[0]) + || (PyArray_TYPE(%(z_data)s) != %(typenum_zd)s) + || !(PyArray_ISCONTIGUOUS(%(z_data)s))) + { + {Py_XDECREF(%(z_data)s);} + npy_intp dims[] = {0}; + dims[0] = PyArray_DIMS(%(p_data)s)[0]; + %(z_data)s = (PyArrayObject*) PyArray_SimpleNew(1, dims, + %(typenum_zd)s); + } + if (!%(z_ind)s + || (PyArray_DIMS(%(z_ind)s)[0] != PyArray_DIMS(%(p_ind)s)[0]) + || (PyArray_TYPE(%(z_ind)s) != %(typenum_zi)s) + || !(PyArray_ISCONTIGUOUS(%(z_ind)s))) + { + {Py_XDECREF(%(z_ind)s);} + npy_intp dims[] = {0}; + dims[0] = PyArray_DIMS(%(p_ind)s)[0]; + %(z_ind)s = (PyArrayObject*) PyArray_SimpleNew(1, dims, + %(typenum_zi)s); + } + if (!%(z_ptr)s + || (PyArray_DIMS(%(z_ptr)s)[0] != PyArray_DIMS(%(p_ptr)s)[0]) + || (PyArray_TYPE(%(z_ptr)s) != %(typenum_zp)s) + || !(PyArray_ISCONTIGUOUS(%(z_ptr)s))) + { + {Py_XDECREF(%(z_ptr)s);} + npy_intp dims[] = {0}; + dims[0] = PyArray_DIMS(%(p_ptr)s)[0]; + %(z_ptr)s = (PyArrayObject*) PyArray_SimpleNew(1, dims, + %(typenum_zp)s); + } + + { + // Product of MxK and NxK, output MxN + npy_intp M = PyArray_DIMS(%(x)s)[0]; + npy_intp N = PyArray_DIMS(%(y)s)[0]; + npy_intp K = PyArray_DIMS(%(y)s)[1]; + + // pointers to access actual data in the arrays passed as params. + const dtype_%(x)s* __restrict__ Dx = (dtype_%(x)s*)PyArray_DATA(%(x)s); + const dtype_%(y)s* __restrict__ Dy = (dtype_%(y)s*)PyArray_DATA(%(y)s); + const dtype_%(p_data)s* __restrict__ Dpd = (dtype_%(p_data)s*)PyArray_DATA(%(p_data)s); + const dtype_%(p_ind)s* __restrict__ Dpi = (dtype_%(p_ind)s*)PyArray_DATA(%(p_ind)s); + const dtype_%(p_ptr)s* __restrict__ Dpp = (dtype_%(p_ptr)s*)PyArray_DATA(%(p_ptr)s); + dtype_%(z_data)s* __restrict__ Dzd = (dtype_%(z_data)s*)PyArray_DATA(%(z_data)s); + dtype_%(z_ind)s* __restrict__ Dzi = (dtype_%(z_ind)s*)PyArray_DATA(%(z_ind)s); + dtype_%(z_ptr)s* __restrict__ Dzp = (dtype_%(z_ptr)s*)PyArray_DATA(%(z_ptr)s); + + const npy_intp Sdx = PyArray_STRIDES(%(x)s)[1]/PyArray_DESCR(%(x)s)->elsize; + const npy_intp Sdy = PyArray_STRIDES(%(y)s)[1]/PyArray_DESCR(%(y)s)->elsize; + const npy_intp Sdpd = PyArray_STRIDES(%(p_data)s)[0] / PyArray_DESCR(%(p_data)s)->elsize; + const npy_intp Sdpi = PyArray_STRIDES(%(p_ind)s)[0] / PyArray_DESCR(%(p_ind)s)->elsize; + const npy_intp Sdpp = PyArray_STRIDES(%(p_ptr)s)[0] / PyArray_DESCR(%(p_ptr)s)->elsize; + const npy_intp Sdzd = PyArray_STRIDES(%(z_data)s)[0] / PyArray_DESCR(%(z_data)s)->elsize; + const npy_intp Sdzi = PyArray_STRIDES(%(z_ind)s)[0] / PyArray_DESCR(%(z_ind)s)->elsize; + const npy_intp Sdzp = PyArray_STRIDES(%(z_ptr)s)[0] / PyArray_DESCR(%(z_ptr)s)->elsize; + + memcpy(Dzi, Dpi, PyArray_DIMS(%(p_ind)s)[0]*sizeof(dtype_%(p_ind)s)); + memcpy(Dzp, Dpp, PyArray_DIMS(%(p_ptr)s)[0]*sizeof(dtype_%(p_ptr)s)); + + // blas expects ints; convert here (rather than just making K etc ints) to avoid potential overflow in the negative-stride correction + if ((K > 0x7fffffffL)||(Sdx > 0x7fffffffL)||(Sdy > 0x7fffffffL)||(Sdx < -0x7fffffffL)||(Sdy < -0x7fffffffL)) + {PyErr_SetString(PyExc_NotImplementedError, "array too big for BLAS (overflows int32 index)"); %(fail)s;} + int K32 = K; + int Sdx32 = Sdx; + int Sdy32 = Sdy; + + for (npy_intp m = 0; m < M; ++m) { + for (npy_int32 n_idx = Dpp[m * Sdpp]; n_idx < Dpp[(m+1)*Sdpp]; ++n_idx) { + const npy_int32 n = Dpi[n_idx * Sdpi]; // row index of non-null value for column K + + const dtype_%(x)s* x_row = (dtype_%(x)s*)(PyArray_BYTES(%(x)s) + PyArray_STRIDES(%(x)s)[0] * m); + + const dtype_%(y)s* y_col = (dtype_%(y)s*)(PyArray_BYTES(%(y)s) + PyArray_STRIDES(%(y)s)[0] * n); + // dot expects pointer to the beginning of memory arrays, + // so when the stride is negative, we need to get the + // last element + if (Sdx < 0) + x_row += (K - 1) * Sdx; + if (Sdy < 0) + y_col += (K - 1) * Sdy; + + Dzd[n_idx * Sdzd] = Dpd[n_idx * Sdpd] * %(cdot)s(&K32, (const %(conv_type)s*)x_row, &Sdx32, (const %(conv_type)s*)y_col, &Sdy32); + } + } + } + """ % dict( + locals(), **sub + ) + + return rval + + +sampling_dot_csr = SamplingDotCSR() + + +# register a specialization to replace SamplingDot -> SamplingDotCsr +@node_rewriter([sparse.sampling_dot]) +def local_sampling_dot_csr(fgraph, node): + if not config.blas__ldflags: + # The C implementation of SamplingDotCsr relies on BLAS routines + return + if node.op == sparse.sampling_dot: + x, y, p = node.inputs + if p.type.format == "csr": + p_data, p_ind, p_ptr, p_shape = sparse.csm_properties(p) + + z_data, z_ind, z_ptr = sampling_dot_csr( + x, y, p_data, p_ind, p_ptr, p_shape[1] + ) + # This is a hack that works around some missing `Type`-related + # static shape narrowing. More specifically, + # `TensorType.convert_variable` currently won't combine the static + # shape information from `old_out.type` and `new_out.type`, only + # the broadcast patterns, and, since `CSR.make_node` doesn't do + # that either, we use `specify_shape` to produce an output `Type` + # with the same level of static shape information as the original + # `old_out`. + old_out = node.outputs[0] + new_out = specify_shape( + sparse.CSR(z_data, z_ind, z_ptr, p_shape), shape(old_out) + ) + return [new_out] + return False + + +register_specialize(local_sampling_dot_csr, "cxx_only", name="local_sampling_dot_csr") diff --git a/aesara/sparse/sharedvar.py b/aesara/sparse/sharedvar.py index d0a681ec96..47fc365b86 100644 --- a/aesara/sparse/sharedvar.py +++ b/aesara/sparse/sharedvar.py @@ -15,12 +15,6 @@ class SparseTensorSharedVariable(_sparse_py_operators, SharedVariable): def sparse_constructor( value, name=None, strict=False, allow_downcast=None, borrow=False, format=None ): - """ - SharedVariable Constructor for SparseTensorType. - - writeme - - """ if not isinstance(value, scipy.sparse.spmatrix): raise TypeError( "Expected a sparse matrix in the sparse shared variable constructor. Received: ", diff --git a/aesara/sparse/type.py b/aesara/sparse/type.py index 767e8ea97f..e2ce91d64c 100644 --- a/aesara/sparse/type.py +++ b/aesara/sparse/type.py @@ -1,10 +1,17 @@ +from typing import Iterable, Optional, Union + import numpy as np import scipy.sparse +from typing_extensions import Literal import aesara from aesara import scalar as aes +from aesara.graph.basic import Variable from aesara.graph.type import HasDataType -from aesara.tensor.type import TensorType +from aesara.tensor.type import DenseTensorType, TensorType + + +SparsityTypes = Literal["csr", "csc", "bsr"] def _is_sparse(x): @@ -29,17 +36,6 @@ def _is_sparse(x): class SparseTensorType(TensorType, HasDataType): """A `Type` for sparse tensors. - Parameters - ---------- - dtype : numpy dtype string such as 'int64' or 'float64' (among others) - Type of numbers in the matrix. - format: str - The sparse storage strategy. - - Returns - ------- - An empty SparseVariable instance. - Notes ----- Currently, sparse tensors can only be matrices (i.e. have two dimensions). @@ -68,40 +64,48 @@ class SparseTensorType(TensorType, HasDataType): } ndim = 2 - # Will be set to SparseVariable SparseConstant later. - variable_type = None - Constant = None - - def __init__(self, format, dtype, shape=None, broadcastable=None, name=None): - if shape is None: + def __init__( + self, + format: SparsityTypes, + dtype: Union[str, np.dtype], + shape: Optional[Iterable[Optional[Union[bool, int]]]] = None, + name: Optional[str] = None, + broadcastable: Optional[Iterable[bool]] = None, + ): + if shape is None and broadcastable is None: shape = (None, None) - self.shape = shape - - if not isinstance(format, str): - raise TypeError("The sparse format parameter must be a string") - - if format in self.format_cls: - self.format = format - else: - raise NotImplementedError( + if format not in self.format_cls: + raise ValueError( f'unsupported format "{format}" not in list', ) - if broadcastable is None: - broadcastable = [False, False] - super().__init__(dtype, shape, name=name) + self.format = format + + super().__init__(dtype, shape=shape, name=name, broadcastable=broadcastable) - def clone(self, format=None, dtype=None, shape=None, **kwargs): - if format is None: - format = self.format + def clone( + self, + dtype=None, + shape=None, + broadcastable=None, + **kwargs, + ): + format: Optional[SparsityTypes] = kwargs.pop("format", self.format) if dtype is None: dtype = self.dtype if shape is None: shape = self.shape - return type(self)(format, dtype, shape) + return type(self)(format, dtype, shape=shape, **kwargs) def filter(self, value, strict=False, allow_downcast=None): + if isinstance(value, Variable): + raise TypeError( + "Expected an array-like object, but found a Variable: " + "maybe you are trying to call a function on a (possibly " + "shared) variable instead of a numeric array?" + ) + if ( isinstance(value, self.format_cls[self.format]) and value.dtype == self.dtype @@ -121,13 +125,10 @@ def filter(self, value, strict=False, allow_downcast=None): data = self.format_cls[self.format](value) up_dtype = aes.upcast(self.dtype, data.dtype) if up_dtype != self.dtype: - raise NotImplementedError( - f"Expected {self.dtype} dtype but got {data.dtype}" - ) + raise TypeError(f"Expected {self.dtype} dtype but got {data.dtype}") sp = data.astype(up_dtype) - if sp.format != self.format: - raise NotImplementedError() + assert sp.format == self.format return sp @@ -154,11 +155,25 @@ def may_share_memory(cls, a, b): def convert_variable(self, var): res = super().convert_variable(var) - if res and not isinstance(res.type, type(self)): - # TODO: Convert to this sparse format - raise NotImplementedError() + if res is None: + return res + + if not isinstance(res.type, type(self)): + if isinstance(res.type, DenseTensorType): + if self.format == "csr": + from aesara.sparse.basic import csr_from_dense + + return csr_from_dense(res) + else: + from aesara.sparse.basic import csc_from_dense + + return csc_from_dense(res) + + return None - # TODO: Convert sparse `var`s with different formats to this format? + if res.format != self.format: + # TODO: Convert sparse `var`s with different formats to this format? + return None return res diff --git a/aesara/tensor/__init__.py b/aesara/tensor/__init__.py index ef865e1bec..69c803ae12 100644 --- a/aesara/tensor/__init__.py +++ b/aesara/tensor/__init__.py @@ -103,15 +103,13 @@ def _get_vector_length_Constant(op: Union[Op, Variable], var: Constant) -> int: # adds shared-variable constructors from aesara.tensor import sharedvar # noqa from aesara.tensor import ( # noqa - basic_opt, blas, blas_c, blas_scipy, nnet, - opt_uncanonicalize, - subtensor_opt, xlogx, ) +import aesara.tensor.rewriting # isort: off diff --git a/aesara/tensor/basic.py b/aesara/tensor/basic.py index 1ed3d855da..83a127b3c5 100644 --- a/aesara/tensor/basic.py +++ b/aesara/tensor/basic.py @@ -10,11 +10,14 @@ from collections.abc import Sequence from functools import partial from numbers import Number -from typing import Optional, Tuple, Union +from typing import Optional +from typing import Sequence as TypeSequence +from typing import Tuple, Union from typing import cast as type_cast import numpy as np from numpy.core.multiarray import normalize_axis_index +from numpy.core.numeric import normalize_axis_tuple import aesara import aesara.scalar.sharedvar @@ -24,12 +27,12 @@ from aesara.graph.basic import Apply, Constant, Variable from aesara.graph.fg import FunctionGraph from aesara.graph.op import Op -from aesara.graph.opt_utils import optimize_graph +from aesara.graph.rewriting.utils import rewrite_graph from aesara.graph.type import Type from aesara.link.c.op import COp from aesara.link.c.params_type import ParamsType from aesara.misc.safe_asarray import _asarray -from aesara.printing import min_informative_str, pprint +from aesara.printing import Printer, min_informative_str, pprint, set_precedence from aesara.raise_op import CheckAndRaise, assert_op from aesara.scalar import int32 from aesara.scalar.basic import ScalarConstant, ScalarVariable @@ -528,7 +531,7 @@ def get_scalar_constant_value( raise NotScalarConstantError() -class TensorFromScalar(Op): +class TensorFromScalar(COp): __props__ = () @@ -562,6 +565,25 @@ def grad(self, inp, grads): raise NotImplementedError("grad not implemented for complex dtypes") + def c_code(self, node, name, inputs, outputs, sub): + (x,) = inputs + (z,) = outputs + fail = sub["fail"] + + return ( + """ + %(z)s = (PyArrayObject*)PyArray_FromScalar(py_%(x)s, NULL); + if(py_%(z)s == NULL){ + %(fail)s; + } + Py_XINCREF(%(z)s); + """ + % locals() + ) + + def c_code_cache_version(self): + return (1,) + tensor_from_scalar = TensorFromScalar() @@ -921,9 +943,10 @@ def flatnonzero(a): nonzero_values : Return the non-zero elements of the input array """ - if a.ndim == 0: + _a = as_tensor_variable(a) + if _a.ndim == 0: raise ValueError("Nonzero only supports non-scalar arrays.") - return nonzero(a.flatten(), return_matrix=False)[0] + return nonzero(_a.flatten(), return_matrix=False)[0] def nonzero_values(a): @@ -1305,9 +1328,10 @@ def identity_like(x, dtype: Optional[Union[str, np.generic, np.dtype]] = None): tensor tensor the shape of x with ones on main diagonal and zeroes elsewhere of type of dtype. """ + _x = as_tensor_variable(x) if dtype is None: - dtype = x.dtype - return eye(x.shape[0], x.shape[1], k=0, dtype=dtype) + dtype = _x.dtype + return eye(_x.shape[0], _x.shape[1], k=0, dtype=dtype) def infer_broadcastable(shape): @@ -1316,7 +1340,8 @@ def infer_broadcastable(shape): `shape` will be validated and constant folded in order to determine which dimensions are broadcastable (i.e. equal to ``1``). """ - from aesara.tensor.basic_opt import ShapeFeature, topo_constant_folding + from aesara.tensor.rewriting.basic import topo_constant_folding + from aesara.tensor.rewriting.shape import ShapeFeature def check_type(s): if s.type.dtype in integer_dtypes: @@ -1336,7 +1361,7 @@ def check_type(s): features=[ShapeFeature()], clone=True, ) - folded_shape = optimize_graph(shape_fg, custom_opt=topo_constant_folding).outputs + folded_shape = rewrite_graph(shape_fg, custom_rewrite=topo_constant_folding).outputs bcast = tuple(getattr(s, "data", s) == 1 for s in folded_shape) return sh, bcast @@ -1690,6 +1715,21 @@ def R_op(self, inputs, eval_points): make_vector = MakeVector() +class MakeVectorPrinter(Printer): + def process(self, r, pstate): + if r.owner is None: + raise TypeError("Can only print make_vector.") + elif isinstance(r.owner.op, MakeVector): + with set_precedence(pstate): + s = [pstate.pprinter.process(inp) for inp in r.owner.inputs] + return f"[{', '.join(s)}]" + else: + raise TypeError("Can only print make_vector.") + + +pprint.assign(MakeVector, MakeVectorPrinter()) + + @_get_vector_length.register(MakeVector) def _get_vector_length_MakeVector(op, var): return len(var.owner.inputs) @@ -1855,11 +1895,12 @@ def make_node(self, x, axis, splits): if splits.type.ndim == 1 and splits.type.dtype not in integer_dtypes: raise TypeError("`splits` parameter must be tensors of integer type") - if axis.type.dtype not in integer_dtypes: + if axis.type.dtype not in integer_dtypes or axis.ndim != 0: raise TypeError("`axis` parameter must be an integer scalar") inputs = [x, axis, splits] - outputs = [x.type() for i in range(self.len_splits)] + out_type = TensorType(dtype=x.dtype, shape=[None] * x.type.ndim) + outputs = [out_type() for i in range(self.len_splits)] return Apply(self, inputs, outputs) @@ -2449,30 +2490,33 @@ def roll(x, shift, axis=None): Output tensor, with the same shape as ``x``. """ + _x = as_tensor_variable(x) if axis is None: - if x.ndim > 1: - y = x.flatten() - return roll(y, shift, axis=0).reshape(x.shape) + if _x.ndim > 1: + y = _x.flatten() + return roll(y, shift, axis=0).reshape(_x.shape) else: axis = 0 if axis < 0: - axis += x.ndim + axis += _x.ndim # Shift may be larger than the size of the axis. If so, since the # roll operation is cyclic, we can take the shift modulo the size # of the axis - shift = shift % x.shape[axis] + shift = shift % _x.shape[axis] # A slice of all elements in a dimension ':' allslice = slice(None) # List of slices describing the front half [:, :, shift:, :] front_slice = slice(-shift, None) - front_list = [allslice] * axis + [front_slice] + [allslice] * (x.ndim - axis - 1) + front_list = [allslice] * axis + [front_slice] + [allslice] * (_x.ndim - axis - 1) # List of slices describing the back half [:, :, :shift, :] end_slice = slice(0, -shift) - end_list = [allslice] * axis + [end_slice] + [allslice] * (x.ndim - axis - 1) - return join(axis, x.__getitem__(tuple(front_list)), x.__getitem__(tuple(end_list))) + end_list = [allslice] * axis + [end_slice] + [allslice] * (_x.ndim - axis - 1) + return join( + axis, _x.__getitem__(tuple(front_list)), _x.__getitem__(tuple(end_list)) + ) def stack(*tensors, **kwargs): @@ -2667,7 +2711,7 @@ def is_flat(var, ndim=None, outdim=None): elif outdim is not None and ndim is not None: raise ValueError("You should only specify ndim") elif outdim is not None: - warnings.warn("flatten outdim parameter is deprecated, use ndim instead.") + warnings.warn("outdim` is deprecated; use `ndim` instead.") ndim = outdim return var.ndim == ndim @@ -2735,8 +2779,9 @@ def tile(x, reps, ndim=None): """ from aesara.tensor.math import ge - if ndim is not None and ndim < x.ndim: - raise ValueError("ndim should be equal or larger than x.ndim") + _x = as_tensor_variable(x) + if ndim is not None and ndim < _x.ndim: + raise ValueError("ndim should be equal or larger than _x.ndim") # If reps is a scalar, integer or vector, we convert it to a list. if not isinstance(reps, (list, tuple)): @@ -2761,8 +2806,8 @@ def tile(x, reps, ndim=None): # assert that reps.shape[0] does not exceed ndim offset = assert_op(offset, ge(offset, 0)) - # if reps.ndim is less than x.ndim, we pad the reps with - # "1" so that reps will have the same ndim as x. + # if reps.ndim is less than _x.ndim, we pad the reps with + # "1" so that reps will have the same ndim as _x. reps_ = [switch(i < offset, 1, reps[i - offset]) for i in range(ndim)] reps = reps_ @@ -2779,17 +2824,17 @@ def tile(x, reps, ndim=None): ): raise ValueError("elements of reps must be scalars of integer dtype") - # If reps.ndim is less than x.ndim, we pad the reps with - # "1" so that reps will have the same ndim as x + # If reps.ndim is less than _x.ndim, we pad the reps with + # "1" so that reps will have the same ndim as _x reps = list(reps) if ndim is None: - ndim = builtins.max(len(reps), x.ndim) + ndim = builtins.max(len(reps), _x.ndim) if len(reps) < ndim: reps = [1] * (ndim - len(reps)) + reps - _shape = [1] * (ndim - x.ndim) + [x.shape[i] for i in range(x.ndim)] + _shape = [1] * (ndim - _x.ndim) + [_x.shape[i] for i in range(_x.ndim)] alloc_shape = reps + _shape - y = alloc(x, *alloc_shape) + y = alloc(_x, *alloc_shape) shuffle_ind = np.arange(ndim * 2).reshape(2, ndim) shuffle_ind = shuffle_ind.transpose().flatten() y = y.dimshuffle(*shuffle_ind) @@ -3250,8 +3295,9 @@ def inverse_permutation(perm): Each row of input should contain a permutation of the first integers. """ + _perm = as_tensor_variable(perm) return permute_row_elements( - arange(perm.shape[-1], dtype=perm.dtype), perm, inverse=True + arange(_perm.shape[-1], dtype=_perm.dtype), _perm, inverse=True ) @@ -3537,12 +3583,14 @@ def diag(v, k=0): """ - if v.ndim == 1: - return AllocDiag(k)(v) - elif v.ndim >= 2: - return diagonal(v, offset=k) + _v = as_tensor_variable(v) + + if _v.ndim == 1: + return AllocDiag(k)(_v) + elif _v.ndim >= 2: + return diagonal(_v, offset=k) else: - raise ValueError("Input must has v.ndim >= 1.") + raise ValueError("Number of dimensions of `v` must be greater than one.") def stacklists(arg): @@ -3590,6 +3638,51 @@ def swapaxes(y, axis1, axis2): return y.dimshuffle(li) +def moveaxis( + a: Union[np.ndarray, TensorVariable], + source: Union[int, TypeSequence[int]], + destination: Union[int, TypeSequence[int]], +) -> TensorVariable: + """Move axes of a TensorVariable to new positions. + + Other axes remain in their original order. + + Parameters + ---------- + a + The TensorVariable whose axes should be reordered. + source + Original positions of the axes to move. These must be unique. + destination + Destination positions for each of the original axes. These must also be + unique. + + Returns + ------- + result + TensorVariable with moved axes. + + """ + + a = as_tensor_variable(a) + + source = normalize_axis_tuple(source, a.ndim, "source") + destination = normalize_axis_tuple(destination, a.ndim, "destination") + + if len(source) != len(destination): + raise ValueError( + "`source` and `destination` arguments must have the same number of elements" + ) + + order = [n for n in range(a.ndim) if n not in source] + + for dest, src in sorted(zip(destination, source)): + order.insert(dest, src) + + result = a.dimshuffle(order) + return result + + def choose(a, choices, mode="raise"): """ Construct an array from an index array and a set of arrays to choose from. @@ -3969,6 +4062,7 @@ def take_along_axis(arr, indices, axis=0): "atleast_3d", "choose", "swapaxes", + "moveaxis", "stacklists", "diag", "diagonal", diff --git a/aesara/tensor/basic_opt.py b/aesara/tensor/basic_opt.py index 1760537381..dc287a60ed 100644 --- a/aesara/tensor/basic_opt.py +++ b/aesara/tensor/basic_opt.py @@ -1,3577 +1,13 @@ -""" Tensor optimizations addressing the ops in basic.py.""" +import warnings -import logging -import sys -import time -import traceback -from collections import defaultdict -from io import StringIO -from typing import Optional -import numpy as np - -import aesara -import aesara.scalar.basic as aes -from aesara import compile -from aesara.compile.ops import ViewOp -from aesara.configdefaults import config -from aesara.graph.basic import ( - Apply, - Constant, - Variable, - ancestors, - equal_computations, - io_toposort, -) -from aesara.graph.features import AlreadyThere, Feature, ReplaceValidate -from aesara.graph.fg import FunctionGraph -from aesara.graph.op import compute_test_value, get_test_value -from aesara.graph.opt import ( - GlobalOptimizer, - OpRemove, - check_chain, - copy_stack_trace, - in2out, - local_optimizer, -) -from aesara.graph.optdb import SequenceDB -from aesara.graph.utils import ( - InconsistencyError, - MethodNotDefined, - TestValueError, - get_variable_trace_string, -) -from aesara.printing import Printer, pprint, set_precedence -from aesara.raise_op import Assert, CheckAndRaise, assert_op -from aesara.tensor.basic import ( - Alloc, - AllocEmpty, - Join, - MakeVector, - ScalarFromTensor, - Split, - TensorFromScalar, - alloc, - as_tensor_variable, - cast, - constant, - extract_constant, - fill, - get_scalar_constant_value, - join, - ones_like, - stack, - switch, - tensor_copy, - zeros, - zeros_like, -) -from aesara.tensor.elemwise import DimShuffle, Elemwise -from aesara.tensor.exceptions import NotScalarConstantError, ShapeError -from aesara.tensor.extra_ops import BroadcastTo, Repeat, Unique, broadcast_shape -from aesara.tensor.math import all as at_all -from aesara.tensor.math import eq -from aesara.tensor.shape import ( - Reshape, - Shape, - Shape_i, - SpecifyShape, - Unbroadcast, - shape_i, - shape_padleft, - specify_shape, - unbroadcast, -) -from aesara.tensor.sort import TopKOp -from aesara.tensor.subtensor import Subtensor, get_idx_list -from aesara.tensor.type import ( - DenseTensorType, - TensorType, - discrete_dtypes, - integer_dtypes, -) -from aesara.tensor.type_other import NoneConst -from aesara.tensor.var import TensorConstant -from aesara.utils import NoDuplicateOptWarningFilter - - -_logger = logging.getLogger("aesara.tensor.basic_opt") -_logger.addFilter(NoDuplicateOptWarningFilter()) - - -def encompasses_broadcastable(b1, b2): - """ - - Parameters - ---------- - b1 - The broadcastable attribute of a tensor type. - b2 - The broadcastable attribute of a tensor type. - - Returns - ------- - bool - True if the broadcastable patterns b1 and b2 are such that b2 is - broadcasted to b1's shape and not the opposite. - - """ - if len(b1) < len(b2): - return False - b1 = b1[-len(b2) :] - return not any(v1 and not v2 for v1, v2 in zip(b1, b2)) - - -def merge_broadcastables(broadcastables): - return [all(bcast) for bcast in zip(*broadcastables)] - - -def broadcast_like(value, template, fgraph, dtype=None): - """ - Return a Variable with the same shape and dtype as the template, - filled by broadcasting value through it. `value` will be cast as - necessary. - - """ - value = as_tensor_variable(value) - if value.type.is_super(template.type): - return value - if template not in fgraph.variables: - raise NotImplementedError( - "broadcast_like currently requires the " - "template Variable to be in the fgraph already" - ) - if dtype is None: - dtype = template.dtype - value = cast(value, dtype) - if value.type.is_super(template.type): - return value - if hasattr(fgraph, "shape_feature"): - new_shape = fgraph.shape_feature.shape_of[template] - else: - new_shape = template.shape - rval = alloc(value, *new_shape) - assert rval.type.dtype == dtype - - return rval - - -class InplaceElemwiseOptimizer(GlobalOptimizer): - r""" - This is parameterized so that it works for `Elemwise` `Op`\s. - """ - - def __init__(self, OP): - self.op = OP - - def add_requirements(self, fgraph): - from aesara.graph.destroyhandler import DestroyHandler - - fgraph.attach_feature(DestroyHandler()) - - @staticmethod - def print_profile(stream, prof, level=0): - blanc = " " * level - print(blanc, "InplaceElemwiseOptimizer ", prof["opt"].op, file=stream) - for k in [ - "node_before", - "nb_call_replace", - "nb_call_validate", - "nb_inconsistent", - ]: - print(blanc, k, prof[k], file=stream) - ndim = prof["ndim"] - if ndim: - print(blanc, "ndim", "nb", file=stream) - for n in sorted(ndim.keys()): - print(blanc, n, ndim[n], file=stream) - - def apply(self, fgraph): - """ - Usage: InplaceElemwiseOptimizer(op).optimize(fgraph) - - Attempts to replace all Broadcast ops by versions of them - that operate inplace. It operates greedily: for each Broadcast - Op that is encountered, for each output, tries each input to - see if it can operate inplace on that input. If so, makes the - change and go to the next output or Broadcast Op. - - Examples - -------- - - `x + y + z -> x += y += z` - - `(x + y) * (x * y) -> (x += y) *= (x * y) or (x + y) *= (x *= y)` - - """ - # We should not validate too often as this takes too much time to - # execute! - # It is the _dfs_toposort() fct in aesara/graph/destroyhandler.py - # that takes so much time. - # Should we try to use another lib that does toposort? - # igraph: http://igraph.sourceforge.net/ - # networkx: https://networkx.lanl.gov/ - # Should we try to use cython? - # Compiling only that fct is not enough, should we try to add the - # deque class too? - # And init the deque and other list to an upper bound number of - # elements? - # Maybe Aesara should do online toposort as in - # http://code.google.com/p/acyclic - # - # The next longest optimizer is the canonizer phase. - # Then I think it is the [io_?]toposort (need to validate) so check if - # the solution is also applicable there. - - # We execute `validate` after this number of change. - prof = { - "opt": self, - "node_before": len(fgraph.apply_nodes), - "nb_call_replace": 0, - "nb_call_validate": 0, - "nb_inconsistent": 0, - "ndim": defaultdict(lambda: 0), - } - - check_each_change = config.tensor__insert_inplace_optimizer_validate_nb - if check_each_change == -1: - if len(fgraph.apply_nodes) > 500: - check_each_change = 10 - else: - check_each_change = 1 - - nb_change_no_validate = 0 - chk = fgraph.checkpoint() - - if fgraph.update_mapping: - update_outs = [fgraph.outputs[i] for i in fgraph.update_mapping] - else: - update_outs = [] - - protected_inputs = [ - f.protected - for f in fgraph._features - if isinstance(f, aesara.compile.function.types.Supervisor) - ] - protected_inputs = sum(protected_inputs, []) # flatten the list - protected_inputs.extend(fgraph.outputs) - for node in list(io_toposort(fgraph.inputs, fgraph.outputs)): - op = node.op - if not isinstance(op, self.op): - continue - # If big graph and the outputs are scalar, do not make it - # inplace. - if ( - check_each_change != 1 - and - # If multiple outputs, they must all have the same size, - # so only check the first. - getattr(node.outputs[0].type, "ndim", -1) == 0 - ): - continue - - if op.inplace_pattern: - # Maybe this isn't needed anymore, but I don't want to - # rish regression now. This case only happen if the - # original node add already some inplace patter and we - # still try to add more pattern. - - baseline = op.inplace_pattern - candidate_outputs = [ - i for i in range(len(node.outputs)) if i not in baseline - ] - # node inputs that are Constant, already destroyed, - # or fgraph protected inputs and fgraph outputs can't be used as - # inplace target. - # Remove here as faster. - candidate_inputs = [ - i - for i in range(len(node.inputs)) - if i not in baseline.values() - and not isinstance(node.inputs[i], Constant) - and - # the next line should not be costly most of the time. - not fgraph.has_destroyers([node.inputs[i]]) - and node.inputs[i] not in protected_inputs - ] - else: - baseline = [] - candidate_outputs = list(range(len(node.outputs))) - # node inputs that are Constant, already destroyed, - # fgraph protected inputs and fgraph outputs can't be used as inplace - # target. - # Remove here as faster. - candidate_inputs = [ - i - for i in range(len(node.inputs)) - if not isinstance(node.inputs[i], Constant) - and not fgraph.has_destroyers([node.inputs[i]]) - and node.inputs[i] not in protected_inputs - ] - - verbose = False - - raised_warning = not verbose - - for candidate_output in candidate_outputs: - - # If the output of the node can be established as an update - # output of the fgraph, visit the candidate_inputs in an order - # that will improve the chances of making the node operate - # inplace on the input it's meant to update - candidate_out_var = node.outputs[candidate_output] - sorted_candidate_inputs = candidate_inputs - - if candidate_out_var in update_outs: - - # The candidate output is an update. Sort the - # variables in candidate_inputs in the following order: - # - Vars corresponding to the actual updated input - # (best case scenario is for the node that procudes - # an update to operate inplace on the variable to - # update) - # - Vars computed inplace on the updates input (second - # best scenario if for the node to work inplace on - # a variable obtained by a chain of inplace on the - # variable to update. In some cases, this will be - # equivalent to operating inplace on the variable to - # update) - # - Remaining variables - updated_inputs = [] - for i, f_out in enumerate(fgraph.outputs): - if f_out is candidate_out_var and i in fgraph.update_mapping: - updated_inp_idx = fgraph.update_mapping[i] - updated_inputs.append(fgraph.inputs[updated_inp_idx]) - - updated_vars = [] - vars_from_inplace = [] - other_vars = [] - for inp_idx in candidate_inputs: - inp = node.inputs[inp_idx] - if inp in updated_inputs: - # the candidate input is the actual updated input - updated_vars.append(inp_idx) - elif ( - hasattr(fgraph, "destroy_handler") - and inp.owner - and any( - fgraph.destroy_handler.root_destroyer.get(up_inp, None) - is inp.owner - for up_inp in updated_inputs - ) - ): - - # the candidate input is a variable computed - # inplace on the updated input via a sequence of - # one or more inplace operations - vars_from_inplace.append(inp_idx) - else: - other_vars.append(inp_idx) - - sorted_candidate_inputs = ( - updated_vars + vars_from_inplace + other_vars - ) - - for candidate_input in sorted_candidate_inputs: - # remove inputs that don't have the same dtype as the output - if ( - node.inputs[candidate_input].type - != node.outputs[candidate_output].type - ): - continue - - inplace_pattern = dict(baseline) - inplace_pattern[candidate_output] = candidate_input - try: - if hasattr(op.scalar_op, "make_new_inplace"): - new_scal = op.scalar_op.make_new_inplace( - aes.transfer_type( - *[ - inplace_pattern.get(i, o.dtype) - for i, o in enumerate(node.outputs) - ] - ) - ) - else: - new_scal = op.scalar_op.__class__( - aes.transfer_type( - *[ - inplace_pattern.get(i, None) - for i in range(len(node.outputs)) - ] - ) - ) - new_outputs = self.op(new_scal, inplace_pattern)( - *node.inputs, return_list=True - ) - new_node = new_outputs[0].owner - - for r, new_r in zip(node.outputs, new_outputs): - prof["nb_call_replace"] += 1 - fgraph.replace( - r, new_r, reason="inplace_elemwise_optimizer" - ) - nb_change_no_validate += 1 - prof["ndim"][candidate_out_var.ndim] += 1 - if nb_change_no_validate >= check_each_change: - prof["nb_call_validate"] += 1 - fgraph.validate() - chk = fgraph.checkpoint() - nb_change_no_validate = 0 - except (ValueError, InconsistencyError) as e: - prof["nb_inconsistent"] += 1 - if check_each_change != 1 and not raised_warning: - print( - ( - "Some inplace optimization was not " - "performed due to unexpected error:" - ), - file=sys.stderr, - ) - print(e, file=sys.stderr) - raised_warning = True - fgraph.revert(chk) - continue - candidate_inputs.remove(candidate_input) - node = new_node - baseline = inplace_pattern - break - - if nb_change_no_validate > 0: - try: - fgraph.validate() - except Exception: - if not raised_warning: - print( - ( - "Some inplace optimization was not " - "performed due to unexpected error" - ), - file=sys.stderr, - ) - fgraph.revert(chk) - return prof - - def print_summary(self, stream=sys.stdout, level=0, depth=-1): - print( - f"{' ' * level}{self.__class__.__name__} ({self.op})", - file=stream, - ) - return inplace_elemwise_optimizer - - -inplace_elemwise_optimizer = InplaceElemwiseOptimizer(Elemwise) -compile.optdb.register( - "inplace_elemwise_opt", - inplace_elemwise_optimizer, - "inplace_opt", # for historic reason - "inplace_elemwise_optimizer", - "fast_run", - "inplace", - position=75, -) - - -def register_useless(lopt, *tags, **kwargs): - if isinstance(lopt, str): - - def register(inner_lopt): - return register_useless(inner_lopt, lopt, *tags, **kwargs) - - return register - else: - name = kwargs.pop("name", None) or lopt.__name__ - - compile.mode.local_useless.register( - name, lopt, "fast_run", *tags, position="last", **kwargs - ) - return lopt - - -def register_canonicalize(lopt, *tags, **kwargs): - if isinstance(lopt, str): - - def register(inner_lopt): - return register_canonicalize(inner_lopt, lopt, *tags, **kwargs) - - return register - else: - name = kwargs.pop("name", None) or lopt.__name__ - compile.optdb["canonicalize"].register( - name, lopt, "fast_run", "fast_compile", *tags, **kwargs - ) - return lopt - - -def register_stabilize(lopt, *tags, **kwargs): - if isinstance(lopt, str): - - def register(inner_lopt): - return register_stabilize(inner_lopt, lopt, *tags, **kwargs) - - return register - else: - name = kwargs.pop("name", None) or lopt.__name__ - compile.optdb["stabilize"].register(name, lopt, "fast_run", *tags, **kwargs) - return lopt - - -def register_specialize(lopt, *tags, **kwargs): - if isinstance(lopt, str): - - def register(inner_lopt): - return register_specialize(inner_lopt, lopt, *tags, **kwargs) - - return register - else: - name = kwargs.pop("name", None) or lopt.__name__ - compile.optdb["specialize"].register(name, lopt, "fast_run", *tags, **kwargs) - return lopt - - -def register_uncanonicalize(lopt, *tags, **kwargs): - if isinstance(lopt, str): - - def register(inner_lopt): - return register_uncanonicalize(inner_lopt, lopt, *tags, **kwargs) - - return register - else: - name = (kwargs and kwargs.pop("name", None)) or lopt.__name__ - compile.optdb["uncanonicalize"].register( - name, lopt, "fast_run", *tags, **kwargs - ) - return lopt - - -def register_specialize_device(lopt, *tags, **kwargs): - if isinstance(lopt, str): - - def register(inner_lopt): - return register_specialize_device(inner_lopt, lopt, *tags, **kwargs) - - return register - else: - name = (kwargs and kwargs.pop("name", None)) or lopt.__name__ - compile.optdb["specialize_device"].register( - name, lopt, "fast_run", *tags, **kwargs - ) - return lopt - - -def apply_local_dimshuffle_lift(fgraph, var): - """ - lift recursively - """ - if not var.owner: - return var - new = local_dimshuffle_lift.transform(fgraph, var.owner) - if new: - return new[0] - return var - - -def is_dimshuffle_useless(new_order, input): - """ - Checks for two types of useless dimshuffles: - 1 - dimshuffle all dimensions in order. - 2 - dimshuffle a broadcastable dimension. - """ - is_useless = True - if len(new_order) == input.type.ndim: - all_broadcastable_dims = [ - i - for (i, is_broadcastable) in enumerate(input.type.broadcastable) - if is_broadcastable - ] + ["x"] - for i in range(input.type.ndim): - if new_order[i] == i or ( - i in all_broadcastable_dims and new_order[i] in all_broadcastable_dims - ): - is_useless = True - else: - is_useless = False - break - else: - is_useless = False - return is_useless - - -@register_canonicalize -@register_specialize -@local_optimizer([DimShuffle]) -def local_dimshuffle_lift(fgraph, node): - """ - "Lifts" DimShuffle through Elemwise operations and merges - consecutive DimShuffles. Basically, applies the following - transformations on the whole graph: - - DimShuffle(Elemwise(x, y)) => Elemwise(DimShuffle(x), DimShuffle(y)) - DimShuffle(DimShuffle(x)) => DimShuffle(x) - DimShuffle{0,1,...}(x) => x (when the dimshuffle do nothing) - - After this transform, clusters of Elemwise operations are - void of DimShuffle operations. - - """ - op = node.op - if not isinstance(op, DimShuffle): - return False - - inp = node.inputs[0] - inode = inp.owner - new_order = op.new_order - if inode and isinstance(inode.op, Elemwise) and (len(fgraph.clients[inp]) == 1): - # Don't use make_node to have tag.test_value set. - new_inputs = [] - for inp in inode.inputs: - new_inp = op.__class__(inp.type.broadcastable, op.new_order)(inp) - new_inputs.append(apply_local_dimshuffle_lift(fgraph, new_inp)) - copy_stack_trace(node.outputs[0], new_inputs) - ret = inode.op(*new_inputs, return_list=True) - return ret - if inode and isinstance(inode.op, DimShuffle): - new_order = [x == "x" and "x" or inode.op.new_order[x] for x in new_order] - inp = inode.inputs[0] - - if is_dimshuffle_useless(new_order, inp): - return [inp] - elif inode and isinstance(inode.op, DimShuffle): - ret = op.__class__(inp.type.broadcastable, new_order)(inp) - ret = apply_local_dimshuffle_lift(fgraph, ret) - copy_stack_trace(node.outputs[0], ret) - return [ret] - - -@register_canonicalize -@register_specialize -@local_optimizer([DimShuffle]) -def local_useless_dimshuffle_makevector(fgraph, node): - r"""Remove `DimShuffle`\s that drop one dimensional broadcastable `MakeVector`s. - - This rewrite is needed in order to clean up after - `local_subtensor_remove_broadcastable_index`, which produces a - not-so-intuitive canonical form for `x[0]` when `x.shape == (1,)` - (i.e. one broadcastable dimension): i.e. `x.dimshuffle(())`. - """ - - # The `DimShuffle` should be removing the single broadcastable dimension - if node.op.new_order != (): - return - - makevector_out = node.inputs[0] - - if ( - not makevector_out.owner - or not isinstance(makevector_out.owner.op, MakeVector) - or not makevector_out.broadcastable == (True,) - ): - return - - assert len(makevector_out.owner.inputs) == 1 - - return [makevector_out.owner.inputs[0]] - - -@register_canonicalize -@local_optimizer([Reshape]) -def local_useless_dimshuffle_in_reshape(fgraph, node): - """ - Removes useless DimShuffle operation inside Reshape: - - reshape(vector.dimshuffle('x', 0), shp) => reshape(vector, shp) - reshape(matrix.dimshuffle('x', 0, 'x', 1), shp) => reshape(matrix, shp) - reshape(row.dimshuffle(1, 'x'), shp) => reshape(row, shp) - reshape(col.dimshuffle(0), shp) => reshape(col, shp) - - """ - op = node.op - if not isinstance(op, Reshape): - return False - if not ( - node.inputs[0].owner is not None - and isinstance(node.inputs[0].owner.op, DimShuffle) - ): - return False - - new_order = node.inputs[0].owner.op.new_order - inp = node.inputs[0].owner.inputs[0] - broadcastables = node.inputs[0].broadcastable - new_order_of_nonbroadcast = [] - for i, bd in zip(new_order, broadcastables): - if not bd: - new_order_of_nonbroadcast.append(i) - no_change_in_order = all( - new_order_of_nonbroadcast[i] <= new_order_of_nonbroadcast[i + 1] - for i in range(len(new_order_of_nonbroadcast) - 1) - ) - if no_change_in_order: - shape = node.inputs[1] - ret = op.__class__(node.outputs[0].ndim)(inp, shape) - copy_stack_trace(node.outputs[0], ret) - return [ret] - - -@register_canonicalize -@register_specialize -@local_optimizer([TensorFromScalar]) -def local_tensor_scalar_tensor(fgraph, node): - """tensor_from_scalar(scalar_from_tensor(x)) -> x""" - if isinstance(node.op, TensorFromScalar): - s = node.inputs[0] - if s.owner and isinstance(s.owner.op, ScalarFromTensor): - t = s.owner.inputs[0] - - # We don't need to copy over any stack traces here - return [t] - - -@register_canonicalize -@register_specialize -@local_optimizer([ScalarFromTensor]) -def local_scalar_tensor_scalar(fgraph, node): - """scalar_from_tensor(tensor_from_scalar(x)) -> x""" - if isinstance(node.op, ScalarFromTensor): - t = node.inputs[0] - if t.owner and isinstance(t.owner.op, TensorFromScalar): - s = t.owner.inputs[0] - - # We don't need to copy over any stack traces here - return [s] - - -class MakeVectorPrinter(Printer): - def process(self, r, pstate): - if r.owner is None: - raise TypeError("Can only print make_vector.") - elif isinstance(r.owner.op, MakeVector): - with set_precedence(pstate): - s = [pstate.pprinter.process(inp) for inp in r.owner.inputs] - return f"[{', '.join(s)}]" - else: - raise TypeError("Can only print make_vector.") - - -pprint.assign(MakeVector, MakeVectorPrinter()) - - -class ShapeFeature(Feature): - """Graph optimizer for removing all calls to shape(). - - This optimizer replaces all Shapes and Subtensors of Shapes with - Shape_i and MakeVector Ops. - - This optimizer has several goals: - - 1. to 'lift' Shapes to as close to the inputs as possible. - - 2. to infer the shape of every node in the graph in terms of the - input shapes. - - 3. remove all fills ``(at.second, at.fill)`` from the graph - - Lifting shapes as close to the inputs as possible is important for - canonicalization because it is very bad form to have to compute - something just to know how big it will be. Firstly, it is a waste - of time to compute such outputs. But it is important to get rid - of these outputs as early as possible in the compilation process - because the extra computations make it appear as if many internal - graph nodes have multiple clients. Many optimizations refuse to - work on nodes with multiple clients. - - Lifting is done by using an `.infer_shape` function if one is - present, or else using a conservative default. An Op that - supports shape-lifting should define a infer_shape(self, fgraph, node, - input_shapes) function. The argument input_shapes is a tuple of - tuples... there is an interior tuple for each input to the node. - The tuple has as many elements as dimensions. The element in - position i of tuple j represents the i'th shape component of the - j'th input. The function should return a tuple of tuples. One - output tuple for each node.output. Again, the i'th element of the - j'th output tuple represents the output[j].shape[i] of the - function. If an output is not a TensorType, then None should be - returned instead of a tuple for that output. - - For example the infer_shape for a matrix-matrix product would accept - input_shapes=((x0,x1), (y0,y1)) and return ((x0, y1),). - - Inferring the shape of internal nodes in the graph is important - for doing size-driven optimizations. If we know how big various - intermediate results will be, we can estimate the cost of many Ops - accurately, and generate c-code that is specific [e.g. unrolled] - to particular sizes. - - In cases where you cannot figure out the shape, raise a ShapeError. - - Notes - ----- - Right now there is only the ConvOp that could really take - advantage of this shape inference, but it is worth it even - just for the ConvOp. All that's necessary to do shape - inference is 1) to mark shared inputs as having a particular - shape, either via a .tag or some similar hacking; and 2) to - add an optional In() argument to promise that inputs will - have a certain shape (or even to have certain shapes in - certain dimensions). We can't automatically infer the shape of - shared variables as they can change of shape during the - execution by default. (NOT IMPLEMENTED YET, BUT IS IN TRAC) - - - **Using Shape information in Optimizations** - - To use this shape information in OPTIMIZATIONS, use the - ``shape_of`` dictionary. - - For example: - - .. code-block:: python - - try: - shape_of = fgraph.shape_feature.shape_of - except AttributeError: - # This can happen when the mode doesn't include the ShapeFeature. - return - - shape_of_output_zero = shape_of[node.output[0]] - - The ``shape_of_output_zero`` symbol will contain a tuple, whose - elements are either integers or symbolic integers. - - TODO: check to see if the symbols are necessarily - non-constant... or are integer literals sometimes Aesara - constants?? That would be confusing. - - """ - - def get_node_infer_shape(self, node): - try: - shape_infer = node.op.infer_shape - except AttributeError: - shape_infer = self.default_infer_shape - - try: - o_shapes = shape_infer( - self.fgraph, node, [self.shape_of[r] for r in node.inputs] - ) - except ShapeError: - o_shapes = self.default_infer_shape( - self.fgraph, node, [self.shape_of[r] for r in node.inputs] - ) - except NotImplementedError as e: - raise NotImplementedError( - "Code called by infer_shape failed raising a " - "NotImplementedError. Raising NotImplementedError to " - "indicate that a shape cannot be computed is no longer " - "supported, and one should now use ShapeError " - f"instead. The original exception message is: {e}" - ).with_traceback(e.__traceback__) - except Exception as e: - msg = ( - f"Failed to infer_shape from Op {node.op}.\nInput shapes: " - f"{[self.shape_of[r] for r in node.inputs]}\nException encountered during infer_shape: " - f"{type(e)}\nException message: {str(e)}\nTraceback: {traceback.format_exc()}" - ) - if config.on_shape_error == "raise": - raise Exception(msg).with_traceback(e.__traceback__) - else: - _logger.warning(msg) - o_shapes = self.default_infer_shape( - self.fgraph, node, [self.shape_of[r] for r in node.inputs] - ) - - return o_shapes - - def get_shape(self, var, idx): - """Optimization can call this to get the current shape_i - - It is better to call this then use directly shape_of[var][idx] - as this method should update shape_of if needed. - - TODO: Up to now, we don't update it in all cases. Update in all cases. - """ - r = self.shape_of[var][idx] - if ( - r.owner - and isinstance(r.owner.op, Shape_i) - and r.owner.inputs[0] not in self.fgraph.variables - ): - assert var.owner - node = var.owner - # recur on inputs - for i in node.inputs: - if getattr(i.type, "ndim", None) > 0: - self.get_shape(i, 0) - o_shapes = self.get_node_infer_shape(node) - assert len(o_shapes) == len(node.outputs) - - # Only change the variables and dimensions that would introduce - # extra computation - for new_shps, out in zip(o_shapes, node.outputs): - if not hasattr(out.type, "ndim"): - continue - - merged_shps = list(self.shape_of[out]) - changed = False - for i in range(out.type.ndim): - n_r = merged_shps[i] - if ( - n_r.owner - and isinstance(n_r.owner.op, Shape_i) - and n_r.owner.inputs[0] not in self.fgraph.variables - ): - changed = True - merged_shps[i] = new_shps[i] - if changed: - self.set_shape(out, merged_shps, override=True) - r = self.shape_of[var][idx] - return r - - def shape_ir(self, i, r): - """Return symbolic r.shape[i] for tensor variable r, int i.""" - if hasattr(r.type, "shape") and r.type.shape[i] is not None: - return constant(r.type.shape[i], dtype="int64") - else: - # Do not call make_node for test_value - s = Shape_i(i)(r) - try: - s = get_scalar_constant_value(s) - except NotScalarConstantError: - pass - return s - - def shape_tuple(self, r): - """Return a tuple of symbolic shape vars for tensor variable r.""" - if not hasattr(r.type, "ndim"): - # This happen for NoneConst. - return None - return tuple(self.shape_ir(i, r) for i in range(r.type.ndim)) - - def default_infer_shape(self, fgraph, node, i_shapes): - """Return a list of shape tuple or None for the outputs of node. - - This function is used for Ops that don't implement infer_shape. - Ops that do implement infer_shape should use the i_shapes parameter, - but this default implementation ignores it. - - """ - rval = [] - for r in node.outputs: - try: - rval.append(self.shape_tuple(r)) - except AttributeError: - rval.append(None) - return rval - - def unpack(self, s_i, var): - """Return a symbolic integer scalar for the shape element s_i. - - The s_i argument was produced by the infer_shape() of an Op subclass. - - var: the variable that correspond to s_i. This is just for - error reporting. - - """ - # unpack the s_i that the Op returned - assert s_i is not None - if s_i == 1: - # don't make the optimizer merge a zillion ones together - # by always returning the same object to represent 1 - return self.lscalar_one - if isinstance(s_i, float) and int(s_i) == s_i: - s_i = int(s_i) - if isinstance(s_i, (np.integer, int)) or ( - isinstance(s_i, np.ndarray) and s_i.ndim == 0 - ): - # this shape is a constant - if s_i < 0: - msg = "There is a negative shape in the graph!" - msg += get_variable_trace_string(var) - # The rest of the pipeline don't handle correctly this - # case. So we have 2 choices, stop compilation or - # consider the shape as unknown. As we have more - # chance to give the stack trace here then later, I - # choose that options as it would give better error - # message. - raise AssertionError(msg) - return constant(s_i, dtype="int64") - if isinstance(s_i, (tuple, list)): - # this dimension is the same as many of the inputs - # which tells us that if one of the inputs is known, - # the others all become known. - # TODO: should be implemented in Elemwise, and Dot - # - # worst case, we loop over shape_of and replace things - raise NotImplementedError(s_i) - - # s_i is x.shape[i] for some x, we change it to shape_of[x][i] - if ( - s_i.owner - and isinstance(s_i.owner.op, Subtensor) - and s_i.owner.inputs[0].owner - and isinstance(s_i.owner.inputs[0].owner.op, Shape) - ): - assert s_i.type.ndim == 0 - assert len(s_i.owner.op.idx_list) == 1 - - # The current Subtensor always put constant index in the graph. - # This was not True in the past. So call the Subtensor function - # that will return the right index. - idx = get_idx_list(s_i.owner.inputs, s_i.owner.op.idx_list) - assert len(idx) == 1 - idx = idx[0] - try: - i = get_scalar_constant_value(idx) - except NotScalarConstantError: - pass - else: - # Executed only if no exception was raised - x = s_i.owner.inputs[0].owner.inputs[0] - # x should already have been imported, and should be in shape_of. - s_i = self.shape_of[x][i] - - if s_i.type.dtype in integer_dtypes: - if getattr(s_i.type, "ndim", 0): - raise TypeError("Shape element must be scalar", s_i) - return s_i - else: - raise TypeError( - "Unsupported shape element", s_i, type(s_i), getattr(s_i, "type", None) - ) - - def set_shape(self, r, s, override=False): - """Assign the shape `s` to previously un-shaped variable `r`. - - Parameters - ---------- - r : a variable - s : None or a tuple of symbolic integers - override : If False, it mean r is a new object in the fgraph. - If True, it mean r is already in the fgraph and we want to - override its shape. - - """ - if not override: - assert r not in self.shape_of, "r already in shape_of" - if s is None: - self.shape_of[r] = s - else: - if not isinstance(s, (tuple, list)): - raise TypeError("shapes must be tuple/list", (r, s)) - - if r.type.ndim != len(s): - sio = StringIO() - aesara.printing.debugprint(r, file=sio, print_type=True) - raise AssertionError( - f"Something inferred a shape with {len(s)} dimensions " - f"for a variable with {int(r.type.ndim)} dimensions" - f" for the variable:\n{sio.getvalue()}" - ) - - shape_vars = [] - for i in range(r.type.ndim): - if hasattr(r.type, "shape") and r.type.shape[i] is not None: - shape_vars.append(constant(r.type.shape[i], dtype="int64")) - else: - shape_vars.append(self.unpack(s[i], r)) - assert all( - not hasattr(r.type, "broadcastable") or not r.type.broadcastable[i] or - # The two following comparison are a speed optimization - # But we never timed this speed optimization! - self.lscalar_one.equals(shape_vars[i]) - or self.lscalar_one.equals(extract_constant(shape_vars[i])) - for i in range(r.type.ndim) - ) - self.shape_of[r] = tuple(shape_vars) - for sv in shape_vars: - self.shape_of_reverse_index.setdefault(sv, set()).add(r) - - def update_shape(self, r, other_r): - """Replace shape of r by shape of other_r. - - If, on some dimensions, the shape of other_r is not informative, - keep the shape of r on those dimensions. - - """ - # other_r should already have a shape - assert other_r in self.shape_of, ("other_r not in shape_of", other_r) - other_shape = self.shape_of[other_r] - - # If other_shape has no information, call is pointless. - if other_shape is None: - return - - if r in self.shape_of: - r_shape = self.shape_of[r] - else: - # If no info is known on r's shape, use other_shape - self.set_shape(r, other_shape) - return - if ( - other_r.owner - and r.owner - and other_r.owner.inputs == r.owner.inputs - and other_r.owner.op == r.owner.op - ): - # We are doing a merge. So the 2 shapes graph will be the - # same. This is only a speed optimization to call - # ancestors() less frequently. - return - - # Merge other_shape with r_shape, giving the priority to other_shape - merged_shape = [] - for i, ps in enumerate(other_shape): - if r_shape is None and other_shape: - merged_shape.append(other_shape[i]) - elif ( - ps.owner - and isinstance(getattr(ps.owner, "op", None), Shape_i) - and ps.owner.op.i == i - and ps.owner.inputs[0] in (r, other_r) - ): - # If other_shape[i] is uninformative, use r_shape[i]. - # For now, we consider 2 cases of uninformative other_shape[i]: - # - Shape_i(i)(other_r); - # - Shape_i(i)(r). - merged_shape.append(r_shape[i]) - elif isinstance(r_shape[i], (Constant, int)): - # We do this to call less often ancestors and make - # sure we have the simplest shape possible. - merged_shape.append(r_shape[i]) - elif isinstance(other_shape[i], (Constant, int)): - # We do this to call less often ancestors and make - # sure we have the simplest shape possible. - merged_shape.append(other_shape[i]) - elif other_shape[i] == r_shape[i]: - # This mean the shape is equivalent - # We do not want to do the ancestor check in those cases - merged_shape.append(r_shape[i]) - elif r_shape[i] in ancestors([other_shape[i]]): - # Another case where we want to use r_shape[i] is when - # other_shape[i] actually depends on r_shape[i]. In that case, - # we do not want to substitute an expression with another that - # is strictly more complex. Such a substitution could also lead - # to cycles: if (in the future) r_shape[i] gets replaced by an - # expression of other_shape[i], other_shape[i] may end up - # depending on itself. - merged_shape.append(r_shape[i]) - else: - merged_shape.append(other_shape[i]) - assert all( - ( - not hasattr(r.type, "broadcastable") - or not r.type.broadcastable[i] - and not other_r.type.broadcastable[i] - ) - or - # The two following comparison are a speed optimization - # But we never timed this speed optimization! - self.lscalar_one.equals(merged_shape[i]) - or self.lscalar_one.equals( - extract_constant(merged_shape[i], only_process_constants=True) - ) - for i in range(r.type.ndim) - ) - self.shape_of[r] = tuple(merged_shape) - for sv in self.shape_of[r]: - self.shape_of_reverse_index.setdefault(sv, set()).add(r) - - def set_shape_i(self, r, i, s_i): - """Replace element i of shape_of[r] by s_i""" - assert r in self.shape_of - prev_shape = self.shape_of[r] - # prev_shape is a tuple, so we cannot change it inplace, - # so we build another one. - new_shape = [] - for j, s_j in enumerate(prev_shape): - if j == i: - new_shape.append(self.unpack(s_i, r)) - else: - new_shape.append(s_j) - assert all( - not hasattr(r.type, "broadcastable") or not r.type.broadcastable[idx] or - # The two following comparison are a speed optimization - # But we never timed this speed optimization! - self.lscalar_one.equals(new_shape[idx]) - or self.lscalar_one.equals(extract_constant(new_shape[idx])) - for idx in range(r.type.ndim) - ) - self.shape_of[r] = tuple(new_shape) - for sv in self.shape_of[r]: - self.shape_of_reverse_index.setdefault(sv, set()).add(r) - - def init_r(self, r): - """Register r's shape in the shape_of dictionary.""" - if r not in self.shape_of: - self.set_shape(r, self.shape_tuple(r)) - - def make_vector_shape(self, r): - return as_tensor_variable(self.shape_of[r], ndim=1, dtype="int64") - - def on_attach(self, fgraph): - - if hasattr(fgraph, "shape_feature"): - raise AlreadyThere("This FunctionGraph already has a ShapeFeature") - - if hasattr(self, "fgraph") and self.fgraph != fgraph: - raise Exception("This ShapeFeature is already attached to a graph") - - self.fgraph = fgraph - - fgraph.shape_feature = self - # Must be local to the object as otherwise we reuse the same - # variable for multiple fgraph! - self.lscalar_one = constant(1, dtype="int64") - assert self.lscalar_one.type.dtype == "int64" - - self.fgraph = fgraph - # Variable -> tuple(scalars) or None (All tensor vars map to tuple) - self.shape_of = {} - # Variable -> - self.scheduled = {} - # shape var -> graph v - self.shape_of_reverse_index = {} - - for node in fgraph.toposort(): - self.on_import(fgraph, node, reason="on_attach") - - def on_detach(self, fgraph): - self.shape_of = {} - self.scheduled = {} - self.shape_of_reverse_index = {} - self.fgraph = None - del fgraph.shape_feature - - def on_import(self, fgraph, node, reason): - if node.outputs[0] in self.shape_of: - # this is a revert, not really an import - for r in node.outputs + node.inputs: - assert r in self.shape_of - return - - for i, r in enumerate(node.inputs): - # make sure we have shapes for the inputs - self.init_r(r) - - o_shapes = self.get_node_infer_shape(node) - - # this is packed information - # an element of o_shapes is either None or a tuple - # elements of the tuple can be either strings, or ints - if len(o_shapes) != len(node.outputs): - raise Exception( - ( - f'The infer_shape method for the Op "{node.op}" returned a list ' - f"with the wrong number of element: len(o_shapes) = {len(o_shapes)} " - f" != len(node.outputs) = {len(node.outputs)}" - ) - ) - - # Ensure shapes are in 'int64'. This is to make sure the assert - # found in the `local_useless_subtensor` optimization does not fail. - for sh_idx, sh in enumerate(o_shapes): - if sh is None: - continue - if not isinstance(sh, (list, tuple)): - raise ValueError( - f"infer_shape of {node} didn't return a list of" - f" list. It returned '{o_shapes}'" - ) - new_shape = [] - for i, d in enumerate(sh): - # Note: we ignore any shape element that is not typed (i.e., - # does not have a 'dtype' attribute). This means there may - # still remain int elements that are int32 on 32-bit platforms, - # but this works with `local_useless_subtensor`, so for now we - # keep it this way. See #266 for a better long-term fix. - if getattr(d, "dtype", "int64") != "int64": - assert d.dtype in discrete_dtypes, (node, d.dtype) - assert str(d.dtype) != "uint64", node - new_shape += sh[len(new_shape) : i + 1] - if isinstance(d, Constant): - casted_d = constant(d.data, dtype="int64") - else: - casted_d = cast(d, "int64") - new_shape[i] = casted_d - if new_shape: - # We replace the shape with wrong dtype by the one with - # 'int64'. - new_shape += sh[len(new_shape) :] - o_shapes[sh_idx] = tuple(new_shape) - - for r, s in zip(node.outputs, o_shapes): - self.set_shape(r, s) - - def on_change_input(self, fgraph, node, i, r, new_r, reason): - if new_r not in self.shape_of: - # It happen that the fgraph didn't called on_import for some - # new_r. This happen when new_r don't have an - # owner(i.e. it is a constant or an input of the graph) - # update_shape suppose that r and new_r are in shape_of. - self.init_r(new_r) - - # This tells us that r and new_r must have the same shape if - # we didn't know that the shapes are related, now we do. - self.update_shape(new_r, r) - - # change_input happens in two cases: - # 1) we are trying to get rid of r, or - # 2) we are putting things back after a failed transaction. - - # In case 1, if r has a shape_i client, we will want to - # replace the shape_i of r with the shape of new_r. Say that - # r is *scheduled*. - # At that point, node is no longer a client of r, but of new_r - for (shpnode, idx) in fgraph.clients[r] + [(node, i)]: - if isinstance(getattr(shpnode, "op", None), Shape_i): - idx = shpnode.op.i - repl = self.shape_of[new_r][idx] - if repl.owner is shpnode: - # This mean the replacement shape object is - # exactly the same as the current shape object. So - # no need for replacement. - continue - if ( - repl.owner - and repl.owner.inputs[0] is shpnode.inputs[0] - and isinstance(repl.owner.op, Shape_i) - and repl.owner.op.i == shpnode.op.i - ): - # The replacement is a shape_i of the same - # input. So no need to do this equivalent - # replacement. - continue - - if shpnode.outputs[0] in ancestors([repl]): - raise InconsistencyError( - "This substitution would insert a cycle in the graph:" - f"node: {node}, i: {i}, r: {r}, new_r: {new_r}" - ) - - self.scheduled[shpnode] = new_r - # In case 2, if r is a variable that we've scheduled for shape update, - # then we should cancel it. - unscheduled = [k for k, v in self.scheduled.items() if v == r] - for k in unscheduled: - del self.scheduled[k] - - # In either case, r could be in shape_of.values(), that is, r itself - # is the shape of something. In that case, we want to update - # the value in shape_of, to keep it up-to-date. - for v in self.shape_of_reverse_index.get(r, []): - # The reverse index is only approximate. It is not updated on - # deletion of variables, or on change_input so it might be the - # case that there are a few extra `v`'s in it that no longer have - # a shape of r or possibly have been deleted from shape_of - # entirely. The important thing is that it permits to recall - # all variables with r in their shape. - for ii, svi in enumerate(self.shape_of.get(v, [])): - if svi == r: - self.set_shape_i(v, ii, new_r) - self.shape_of_reverse_index[r] = set() - - def same_shape( - self, - x: Variable, - y: Variable, - dim_x: Optional[int] = None, - dim_y: Optional[int] = None, - ) -> bool: - """Return ``True`` if `x` and `y` have the same shape. - - Parameters - ========== - x - The `Variable` for which its shape is to be compared with `y`'s shape. - y - The `Variable` for which its shape is to be compared with `x`'s shape. - dim_x - If non ``None``, compare only the dimension of `x` equal to - `dim_x`. - dim_y - If non ``None``, compare only the dimension of `y` equal to - `dim_y`. - - """ - sx = self.shape_of[x] - sy = self.shape_of[y] - - if sx is None or sy is None: - return False - - if dim_x is not None: - sx = [sx[dim_x]] - - if dim_y is not None: - sy = [sy[dim_y]] - - if len(sx) != len(sy): - return False - - # Canonicalize the graphs so that comparisons are reasonable - # TODO FIXME: This should *not* need to be performed manually here. - # Instead, the shape information in `self.shape_of` should be operated - # upon alongside all the other elements in a `FunctionGraph` (e.g. as - # if `self.shape_of.values()` were additional outputs). - shapes_fg = FunctionGraph( - outputs=sx + sy, - # features=[self], - clone=True, - # copy_inputs=False, - ) - from aesara.graph.opt_utils import optimize_graph - - canon_shapes = optimize_graph( - shapes_fg, custom_opt=topo_constant_folding - ).outputs - - sx = canon_shapes[: len(sx)] - sy = canon_shapes[len(sx) :] - - for dx, dy in zip(sx, sy): - if not equal_computations([dx], [dy]): - return False - - return True - - def clone(self): - return type(self)() - - -class ShapeOptimizer(GlobalOptimizer): - """Optimizer that adds `ShapeFeature` as a feature.""" - - def add_requirements(self, fgraph): - fgraph.attach_feature(ShapeFeature()) - - def apply(self, fgraph): - pass - - -class UnShapeOptimizer(GlobalOptimizer): - """Optimizer that removes `ShapeFeature` as a feature.""" - - def apply(self, fgraph): - for feature in fgraph._features: - if isinstance(feature, ShapeFeature): - fgraph.remove_feature(feature) - - -# Register it after merge1 optimization at 0. We don't want to track -# the shape of merged node. -aesara.compile.mode.optdb.register( - "ShapeOpt", ShapeOptimizer(), "fast_run", "fast_compile", position=0.1 -) -# Not enabled by default for now. Some crossentropy opt use the -# shape_feature. They are at step 2.01. uncanonicalize is at step -# 3. After it goes to 48.5 that move to the gpu. So 10 seems reasonable. -aesara.compile.mode.optdb.register("UnShapeOpt", UnShapeOptimizer(), position=10) - - -@register_specialize("local_alloc_elemwise") -@local_optimizer([Elemwise]) -def local_elemwise_alloc(fgraph, node): - r"""Remove unnecessary `Alloc`\s that occur as inputs of `Elemwise` `Op`\s. - - `Alloc`\s are effectively a type of `Elemwise` operation - (e.g. ``Elemwise{second}(y, x)`` is the same as ``Alloc(x, *y.shape)``), so - this rewrite uses that fact to reduce `Elemwise`\s on `Alloc`\s to - `Elemwise`\s of the `Alloc`\s first/value input (i.e. the value it - broadcasts). - - In other words, this rewrite causes `Elemwise` `Op`\s to "absorb" redundant - `Alloc`\s. - - The rewrite essentially performs the following replacement: - ``Elemwise{op}(..., Alloc(x, s), ..., y, ...) -> Elemwise{op}(..., x, ..., y, ...)``, - when ``y.shape`` for some input ``y`` (or the combined shapes of the - non-`Alloc`\s) is sufficient to maintain the same/correct output shape. - - In it's current form, it also explicitly accounts for `DimShuffle`\s of - `Alloc`\s. This is largely due to `local_alloc_sink_dimshuffle`, which - introduces them as a canonicalization of `Alloc`'s with leading - broadcastable dimensions. - """ - if not isinstance(node.op, Elemwise): - return False - - # Rewrite is only applicable when there are at least two inputs - if len(node.inputs) == 1: - return None - - if len(node.outputs) > 1: - # Ensure all outputs have the same broadcast pattern - # This is a supposition that I'm not sure is always true. - assert all( - o.type.broadcastable == node.outputs[0].type.broadcastable - for o in node.outputs[1:] - ) - - # The broadcast pattern of the output must match the broadcast - # pattern of at least one of the inputs. - if not any( - i.type.broadcastable == node.outputs[0].type.broadcastable for i in node.inputs - ): - return False - - def dimshuffled_alloc(i): - return ( - isinstance(i.owner.op, DimShuffle) - and i.owner.inputs[0].owner - and isinstance(i.owner.inputs[0].owner.op, Alloc) - ) - - # At least one input must have an owner that is either a `Alloc` or a - # `DimShuffle` with an owner that is a `Alloc` -- otherwise there is - # nothing to optimize. - if not any( - i.owner and (isinstance(i.owner.op, Alloc) or dimshuffled_alloc(i)) - for i in node.inputs - ): - return False - - # Search for a non `Alloc` or `DimShuffle` of `Alloc` input that we can use as a - # baseline for the dimensions. - assert_op_idx = None - for idx, i in enumerate(node.inputs): - if i.type.broadcastable == node.outputs[0].type.broadcastable: - # Prefer an input that is not a `Alloc` nor a `DimShuffle` of a - # `Alloc` so that all `Alloc`s can be optimized. - if not ( - i.owner and (isinstance(i.owner.op, Alloc) or dimshuffled_alloc(i)) - ): - assert_op_idx = idx - break - - # If only `Alloc` and `DimShuffle` of `Alloc` exist, we pick the first suitable one - if assert_op_idx is None: - for idx, i in enumerate(node.inputs): - if (i.type.broadcastable == node.outputs[0].type.broadcastable) and ( - i.owner and (isinstance(i.owner.op, Alloc) or dimshuffled_alloc(i)) - ): - assert_op_idx = idx - break - - assert_op_in = node.inputs[assert_op_idx] - cmp_op = assert_op_in - new_i = [] - same_shape = fgraph.shape_feature.same_shape - for i in node.inputs: - # Remove `Alloc` - if i.owner and isinstance(i.owner.op, Alloc): - assert i.type.ndim == cmp_op.ndim - if config.experimental__local_alloc_elemwise_assert: - get_shape = fgraph.shape_feature.get_shape - cond = [] - for idx in range(i.type.ndim): - if not i.type.broadcastable[idx] and not same_shape( - i, cmp_op, idx, idx - ): - i_shp = get_shape(i, idx) - cmp_shp = get_shape(cmp_op, idx) - cond.append(eq(i_shp, cmp_shp)) - if cond: - assert_op_in = assert_op(assert_op_in, *cond) - alloc_input = i.owner.inputs[0] - if alloc_input.ndim != i.ndim: - # The `Alloc` can add dimensions to the value. - # We replace those cases with a `DimShuffle` here. - nb_dim_to_add = i.ndim - alloc_input.ndim - alloc_input = alloc_input.dimshuffle( - ["x"] * nb_dim_to_add + list(range(alloc_input.ndim)) - ) - copy_stack_trace(i, alloc_input) - new_i.append(alloc_input) - - # Remove `Alloc` in `DimShuffle` - elif i.owner and dimshuffled_alloc(i): - assert i.type.ndim == cmp_op.type.ndim - if config.experimental__local_alloc_elemwise_assert: - assert_cond = [ - eq(i.shape[idx], cmp_op.shape[idx]) - for idx in range(i.type.ndim) - if not i.type.broadcastable[idx] - and not same_shape(i, cmp_op, idx, idx) - ] - if assert_cond: - assert_op_in = assert_op(assert_op_in, *assert_cond) - alloc_input = i.owner.inputs[0].owner.inputs[0] - if alloc_input.ndim != i.owner.inputs[0].ndim: - # The `Alloc` can add dimensions to the value. - # We replace those cases with a `DimShuffle` here. - # We let later optimizations merge the nested `DimShuffle`s - nb_dim_to_add = i.owner.inputs[0].ndim - alloc_input.ndim - alloc_input = alloc_input.dimshuffle( - ["x"] * nb_dim_to_add + list(range(alloc_input.ndim)) - ) - - # We need to keep the old `DimShuffle`. It could swap axes or - # add dimensions anywhere. - r_i = i.owner.op(alloc_input) - copy_stack_trace(i, r_i) - new_i.append(r_i) - - else: - new_i.append(i) - new_i[assert_op_idx] = assert_op_in - - # If this assert is triggered, it means we are recreating an equivalent graph - # which would result in a cyclical merge optimization. - if all(new is old for new, old in zip(new_i, node.inputs)): - return - - ret = node.op(*new_i, return_list=True) - copy_stack_trace(node.outputs, ret) - return ret - - -@register_canonicalize -@local_optimizer([Elemwise]) -def local_fill_sink(fgraph, node): - """ - f(fill(a, b), fill(c, d), e) -> fill(c, fill(a, f(b, d, e))) - f need to be an elemwise that isn't a fill. - """ - if not hasattr(node, "op") or not isinstance(node.op, Elemwise) or node.op == fill: - return False - models = [] - inputs = [] - for inp in node.inputs: - if inp.owner and inp.owner.op == fill: - models.append(inp.owner.inputs[0]) - inputs.append(inp.owner.inputs[1]) - else: - inputs.append(inp) - if not models: - return False - c = node.op(*inputs) - for model in models: - if ( - model.type.dtype != c.type.dtype - or model.type.broadcastable != c.type.broadcastable - ): - c = fill(model, c) - - # The newly created node c doesn't has 'clients', - # so this iteration is took place with node.outputs[0] - replacements = {node.outputs[0]: c} - for client, cl_idx in fgraph.clients[node.outputs[0]]: - if ( - hasattr(client, "op") - and isinstance(client.op, Elemwise) - and client.op != fill - ): - client_inputs = client.inputs[:] - client_inputs[cl_idx] = c - new_client = client.op(*client_inputs) - - # Add clients to new_client - fgraph.clients[new_client.owner.outputs[0]] = fgraph.clients[ - client.outputs[0] - ] - r = local_fill_sink.transform(fgraph, new_client.owner) - if not r: - continue - replacements.update(r) - return replacements - - -@register_specialize -@register_stabilize -@local_optimizer([fill]) -def local_fill_to_alloc(fgraph, node): - r"""Remove `fill`\s or replace them with `Alloc`\s. - - `Alloc`\s are preferable because they replace explicit tensor dependencies - with their dependencies on those tensors' shapes, and sometimes those - shapes can be computed without needing to compute the tensors themselves. - - XXX: This rewrite can produce inconsistent results, so do *not* consider - making it a canonicalization until those inconsistencies are - resolved/justified. - """ - shape_ref, values_ref = node.inputs - out_type = node.outputs[0].type - - if values_ref.type.broadcastable == out_type.broadcastable: - # The assumption here is that `values_ref` already has the same shape - # as `shape_ref`, so a `fill`/`Alloc` is unnecessary. - - # XXX FIXME TODO: The only way this can be determined is if one - # absolutely knows that the shapes of `shape_ref` and `values_ref` are - # equal. - # This is an old rewrite, and it's only a - # "specialization/stabilization", so we're going to leave it be for - # now. - return [values_ref] - - if shape_ref.type.broadcastable == out_type.broadcastable: - # In this case, we assume that some broadcasting is needed (otherwise - # the condition above would've been true), so we replace the `fill` - # with an `Alloc`. - o = broadcast_like(values_ref, shape_ref, fgraph, dtype=values_ref.dtype) - copy_stack_trace(node.outputs[0], o) - return [o] - - return - - -# Register this after stabilize at 1.5 to make sure stabilize don't -# get affected by less canonicalized graph due to alloc. -compile.optdb.register( - "local_fill_to_alloc", in2out(local_fill_to_alloc), "fast_run", position=1.51 -) -# Needed to clean some extra alloc added by local_fill_to_alloc -compile.optdb.register( - "local_elemwise_alloc", in2out(local_elemwise_alloc), "fast_run", position=1.52 -) - - -@register_canonicalize("fast_compile") -@register_useless -@local_optimizer([fill]) -def local_useless_fill(fgraph, node): - """fill(s,v) -> v - - This optimization is only needed in FAST_COMPILE to make the code - more readable. Normally, it is done by the local_fill_to_alloc - opt. - - """ - r, v = node.inputs - out_type = node.outputs[0].type - - if ( - v.type.dtype == out_type.dtype - and v.type.broadcastable == out_type.broadcastable - ): - return [v] - - -@register_specialize -@register_stabilize -@register_canonicalize -@register_useless -@local_optimizer([Alloc]) -def local_useless_alloc(fgraph, node): - """ - If the input type is the same as the output type (dtype and broadcast) - there is no change in the shape of the input. So this is just a simple copy - of the input. This is not needed. - """ - if not isinstance(node.op, Alloc): - return False - - inp = node.inputs[0] - output = node.outputs[0] - - if ( - inp.type.dtype == output.type.dtype - and inp.type.broadcastable == output.type.broadcastable - ): - if inp.ndim == 0: - return [inp] - else: - return [ - Assert("Shapes must be equal")( - inp, at_all(eq(inp.shape, node.inputs[1:])) - ) - ] - - -@register_specialize -@register_stabilize -@register_canonicalize -@local_optimizer([Alloc]) -def local_alloc_sink_dimshuffle(fgraph, node): - r"""Convert broadcastable leading dimensions in an `Alloc` to `DimShuffle`\s.""" - op = node.op - if not isinstance(op, Alloc): - return False - - inp = node.inputs[0] - output = node.outputs[0] - - # Check if alloc adds a broadcastable dimension with shape 1. - output_shape = node.inputs[1:] - num_dims_with_size_1_added_to_left = 0 - for i in range(len(output_shape) - inp.ndim): - if extract_constant(output_shape[i], only_process_constants=True) == 1: - num_dims_with_size_1_added_to_left += 1 - else: - break - - new_output_shape = output_shape[num_dims_with_size_1_added_to_left:] - if num_dims_with_size_1_added_to_left > 0 and len(new_output_shape) >= inp.ndim: - if ( - output.broadcastable[num_dims_with_size_1_added_to_left:] - == inp.broadcastable - ): - inner = inp - else: - inner = op(*([inp] + new_output_shape)) - dimshuffle_new_order = ["x"] * num_dims_with_size_1_added_to_left + list( - range(len(new_output_shape)) - ) - return [DimShuffle(inner.type.broadcastable, dimshuffle_new_order)(inner)] - - -@local_optimizer([AllocEmpty]) -def local_alloc_empty_to_zeros(fgraph, node): - """This convert AllocEmpty to Alloc of 0. - - This help investigate NaN with NanGuardMode. Not registered by - default. To activate it, use the Aesara flag - optimizer_including=alloc_empty_to_zeros. - """ - if isinstance(node.op, AllocEmpty): - return [zeros(node.inputs, dtype=node.outputs[0].dtype)] - - -compile.optdb.register( - "local_alloc_empty_to_zeros", - in2out(local_alloc_empty_to_zeros), - # After move to gpu and merge2, before inplace. - "alloc_empty_to_zeros", - position=49.3, -) - - -@register_specialize -@register_canonicalize -@local_optimizer([Shape]) -def local_shape_to_shape_i(fgraph, node): - if isinstance(node.op, Shape): - # This optimization needs ShapeOpt and fgraph.shape_feature - if not hasattr(fgraph, "shape_feature"): - return - shape_feature = fgraph.shape_feature - ret = shape_feature.make_vector_shape(node.inputs[0]) - - # We need to copy over stack trace from input to output - copy_stack_trace(node.outputs[0], ret) - return [ret] - - -@register_specialize -@register_canonicalize -@local_optimizer([Shape_i]) -def local_track_shape_i(fgraph, node): - if not isinstance(node.op, Shape_i): - return False - - try: - shape_feature = fgraph.shape_feature - except AttributeError: - return False - - if node not in shape_feature.scheduled: - return False - - # Don't unschedule node as it could be reinserted in the - # fgraph as we don't change it in the shapefeature internal - # structure. - replacement = shape_feature.scheduled[node] - return [shape_feature.shape_of[replacement][node.op.i]] - - -@register_useless -@register_canonicalize("fast_compile") -@register_specialize -@local_optimizer([Elemwise]) -def local_useless_elemwise(fgraph, node): - """ - eq(x, x) -> 1 - neq(x, x) -> 0 - mul(x) -> x - add(x) -> x - identity(x) -> x - and(x, 1) -> x (if x.dtype == 'bool') - and(x, 0) -> zeros_like(x) - or(x, 0) -> x - or(x, 1) -> ones_like(x) (if x.dtype == 'bool') - xor(x, x) -> zeros_like(x) - - """ - if isinstance(node.op, Elemwise): - # We call zeros_like and one_like with opt=True to generate a - # cleaner graph. - dtype = node.outputs[0].dtype - - if node.op.scalar_op == aes.eq and len(node.inputs) == 2: - if node.inputs[0] == node.inputs[1]: - # it is the same var in the graph. That will always be true - ret = ones_like(node.inputs[0], dtype=dtype, opt=True) - - # Copy stack trace from input to constant output - copy_stack_trace(node.outputs[0], ret) - return [ret] - elif node.op.scalar_op == aes.neq and len(node.inputs) == 2: - if node.inputs[0] == node.inputs[1]: - # it is the same var in the graph. That will always be false - ret = zeros_like(node.inputs[0], dtype=dtype, opt=True) - - # Copy stack trace from input to constant output - copy_stack_trace(node.outputs[0], ret) - return [ret] - - elif node.op.scalar_op == aes.mul and len(node.inputs) == 1: - # No need to copy over any stack trace - return [node.inputs[0]] - - elif node.op.scalar_op == aes.add and len(node.inputs) == 1: - # No need to copy over any stack trace - return [node.inputs[0]] - elif node.op.scalar_op == aes.identity and len(node.inputs) == 1: - return [node.inputs[0]] - - elif isinstance(node.op.scalar_op, aes.AND) and len(node.inputs) == 2: - - if isinstance(node.inputs[0], TensorConstant): - const_val = extract_constant( - node.inputs[0], only_process_constants=True - ) - if not isinstance(const_val, Variable): - if const_val == 0: - return [zeros_like(node.inputs[1], dtype=dtype, opt=True)] - elif node.outputs[0].dtype == "bool": - # If the output is not Boolean, it is the bitwise AND, - # and this optimization would be wrong - return [node.inputs[1].astype(node.outputs[0].dtype)] - - if isinstance(node.inputs[1], TensorConstant): - const_val = extract_constant( - node.inputs[1], only_process_constants=True - ) - if not isinstance(const_val, Variable): - if const_val == 0: - return [zeros_like(node.inputs[0], dtype=dtype, opt=True)] - elif node.outputs[0].dtype == "bool": - # If the output is not Boolean, it is the bitwise AND, - # and this optimization would be wrong - return [node.inputs[0].astype(node.outputs[0].dtype)] - - elif isinstance(node.op.scalar_op, aes.OR) and len(node.inputs) == 2: - - if isinstance(node.inputs[0], TensorConstant): - const_val = extract_constant( - node.inputs[0], only_process_constants=True - ) - if not isinstance(const_val, Variable): - if const_val == 0: - return [node.inputs[1].astype(node.outputs[0].dtype)] - elif node.outputs[0].dtype == "bool": - # If the output is not Boolean, it is the bitwise OR, - # and this optimization would be wrong - return [ones_like(node.inputs[1], dtype=dtype, opt=True)] - - if isinstance(node.inputs[1], TensorConstant): - const_val = extract_constant( - node.inputs[1], only_process_constants=True - ) - if not isinstance(const_val, Variable): - if const_val == 0: - return [node.inputs[0].astype(node.outputs[0].dtype)] - elif node.outputs[0].dtype == "bool": - # If the output is not Boolean, it is the bitwise OR, - # and this optimization would be wrong - return [ones_like(node.inputs[0], dtype=dtype, opt=True)] - - elif isinstance(node.op.scalar_op, aes.XOR) and len(node.inputs) == 2: - if node.inputs[0] is node.inputs[1]: - return [zeros_like(node.inputs[0], dtype=dtype, opt=True)] - - -@register_specialize -@local_optimizer([Elemwise]) -def local_alloc_unary(fgraph, node): - """unary(alloc(x, shp)) -> alloc(unary(x), shp)""" - if isinstance(node.op, Elemwise) and len(node.inputs) == 1: - a = node.inputs[0] - if a.owner and isinstance(a.owner.op, Alloc): - x = a.owner.inputs[0] - shp = a.owner.inputs[1:] - v = node.op(x) - # at.alloc does not preserve the stacktrace of v, - # so we need to copy it over from x. - copy_stack_trace(node.outputs[0], v) - ret = alloc(cast(v, node.outputs[0].dtype), *shp) - - # at.cast does not preserve the stacktrace of x, - # so we need to copy it over to the output. - copy_stack_trace([node.outputs[0], a], ret) - return [ret] - - -@register_canonicalize -@register_specialize -@local_optimizer([Elemwise]) -def local_cast_cast(fgraph, node): - """cast(cast(x, dtype1), dtype2) - - when those contrain: - dtype1 == dtype2 - OR the base dtype is the same (int, uint, float, complex) - and the first cast cause an upcast. - - """ - if not isinstance(node.op, Elemwise) or not isinstance(node.op.scalar_op, aes.Cast): - return - x = node.inputs[0] - if ( - not x.owner - or not isinstance(x.owner.op, Elemwise) - or not isinstance(x.owner.op.scalar_op, aes.Cast) - ): - return - - type1 = x.owner.op.scalar_op.o_type - type2 = node.op.scalar_op.o_type - base = x.owner.inputs[0] - - if type1 == type2: - # We don't need to copy over any stack traces here - return [x] - - if is_an_upcast(base.dtype, type1.dtype): - # Checking for further redundancy. Eg: int8 -> int32 -> int8 - if type2.dtype == base.dtype: - return x.owner.inputs - else: - # Apply the second cast only - v = node.op(base) - # Copy stack trace from the output of the original cast - copy_stack_trace(node.outputs[0], v) - return [v] - - -def is_an_upcast(type1, type2): - """Given two data types (as strings), check if converting to - type2 from type1 constitutes an upcast. - Differs from aesara.scalar.upcast - - """ - category = { - # The first number in the pair is the dtype (bool, uint, int, float, - # complex). Conversion from higher to lower is never an upcast. - # The second number roughly indicates the precision. Again, conversion - # from higher to lower is never an upcast. - "bool": (0, 0), - "uint8": (1, 1), - "uint16": (1, 2), - "uint32": (1, 3), - "uint64": (1, 4), - "int8": (2, 1), - "int16": (2, 2), - "int32": (2, 3), - "int64": (2, 4), - "float16": (3, 1.5), - "float32": (3, 2.5), - "float64": (3, 3.5), - "complex64": (4, 3), - "complex128": (4, 4), - } - - cat1 = category[type1] - cat2 = category[type2] - - if cat2[0] >= cat1[0] and cat2[1] > cat1[1]: - return True - else: - return False - - -@register_useless -@register_specialize -@local_optimizer(None) -def local_remove_useless_assert(fgraph, node): - if not isinstance(node.op, CheckAndRaise): - return False - - new_conds = [] - n_conds = len(node.inputs[1:]) - for c in node.inputs[1:]: - try: - const = get_scalar_constant_value(c) - - if 0 != const.ndim or const == 0: - # Should we raise an error here? How to be sure it - # is not caught? - new_conds.append(c) - except NotScalarConstantError: - new_conds.append(c) - - if len(new_conds) == 0: - return [node.inputs[0]] - - if len(new_conds) < n_conds: - new_var = node.op(*(node.inputs[:1] + new_conds)) - copy_stack_trace(node.outputs[0], new_var) - return [new_var] - - -@local_optimizer([Assert]) -def local_remove_all_assert(fgraph, node): - """An optimization disabled by default that removes all asserts from - the graph. - - Notes - ----- - See the :ref:`unsafe` section to know how to enable it. - - """ - if not isinstance(node.op, Assert): - return - - return [node.inputs[0]] - - -compile.optdb["canonicalize"].register( - "local_remove_all_assert", - local_remove_all_assert, - "unsafe", - use_db_name_as_tag=False, -) -compile.optdb["stabilize"].register( - "local_remove_all_assert", - local_remove_all_assert, - "unsafe", - use_db_name_as_tag=False, -) -compile.optdb["specialize"].register( - "local_remove_all_assert", - local_remove_all_assert, - "unsafe", - use_db_name_as_tag=False, -) -compile.optdb["useless"].register( - "local_remove_all_assert", - local_remove_all_assert, - "unsafe", - use_db_name_as_tag=False, -) - - -@register_canonicalize -@local_optimizer([Elemwise]) -def local_upcast_elemwise_constant_inputs(fgraph, node): - """This explicitly upcasts constant inputs to elemwise Ops, when - those Ops do implicit upcasting anyway. - - Rationale: it helps merge things like (1-x) and (1.0 - x). - - """ - if len(node.outputs) > 1: - return - try: - shape_i = fgraph.shape_feature.shape_i - except AttributeError: - shape_i = None - if isinstance(node.op, Elemwise): - scalar_op = node.op.scalar_op - # print "aa", scalar_op.output_types_preference - if getattr(scalar_op, "output_types_preference", None) in ( - aes.upgrade_to_float, - aes.upcast_out, - ): - # this is the kind of op that we can screw with the input - # dtypes by upcasting explicitly - output_dtype = node.outputs[0].type.dtype - new_inputs = [] - for i in node.inputs: - if i.type.dtype == output_dtype: - new_inputs.append(i) - else: - try: - # works only for scalars - cval_i = get_scalar_constant_value( - i, only_process_constants=True - ) - if all(i.broadcastable): - new_inputs.append( - shape_padleft(cast(cval_i, output_dtype), i.ndim) - ) - else: - if shape_i is None: - return - new_inputs.append( - alloc( - cast(cval_i, output_dtype), - *[shape_i(d)(i) for d in range(i.ndim)], - ) - ) - # print >> sys.stderr, "AAA", - # *[Shape_i(d)(i) for d in range(i.ndim)] - except NotScalarConstantError: - # for the case of a non-scalar - if isinstance(i, TensorConstant): - new_inputs.append(cast(i, output_dtype)) - else: - new_inputs.append(i) - - if new_inputs != node.inputs: - rval = [node.op(*new_inputs)] - if not node.outputs[0].type.is_super(rval[0].type): - # This can happen for example when floatX=float32 - # and we do the true division between and int64 - # and a constant that will get typed as int8. - - # As this is just to allow merging more case, if - # the upcast don't work, we can just skip it. - return - - # Copy over output stacktrace from before upcasting - copy_stack_trace(node.outputs[0], rval) - return rval - - -@register_useless -@register_canonicalize -@register_specialize -@local_optimizer([Unbroadcast]) -def local_useless_unbroadcast(fgraph, node): - """Remove `Unbroadcast` if it does not actually change the broadcasting pattern. - - TODO: Implement equivalent rewrite for SpecifyShape - """ - if isinstance(node.op, Unbroadcast): - x = node.inputs[0] - if x.broadcastable == node.outputs[0].broadcastable: - # No broadcastable flag was modified - # No need to copy over stack trace, - # because x should already have a stack trace. - return [x] - else: - # Keep the flags that modify something - new_axes = tuple(ax for ax in node.op.axes if x.type.shape[ax] == 1) - if new_axes == node.op.axes: - # All flags are useful - return None - else: - r = unbroadcast(x, *new_axes) - # Copy over stacktrace from previous output - copy_stack_trace(node.outputs, r) - return [r] - - -@register_canonicalize -@register_specialize -@local_optimizer([Unbroadcast]) -def local_unbroadcast_lift(fgraph, node): - """ - Lifts `Unbroadcast` through unary Elemwise operations, - and merges consecutive `Unbroadcast`s. - - Unbroadcast(Elemwise(x)) => Elemwise(Unbroadcast(x)) - Unbroadcast(Unbroadcast(x)) => Unbroadcast(x) - - TODO: Implement equivalent Elemwise lift for SpecifyShape - """ - op = node.op - if not isinstance(op, Unbroadcast): - return False - - inp = node.inputs[0] - inode = inp.owner - if inode and isinstance(inode.op, Elemwise) and len(inode.inputs) == 1: - if len(fgraph.clients.get(inp, ())) == 1: - unbroadcasted = unbroadcast(inode.inputs[0], *op.axes) - copy_stack_trace(node.outputs, unbroadcasted) - - rval = inode.op.make_node(unbroadcasted).outputs - - # Copy over stacktrace from previous output (after unbroadcasting) - # and input (after elemwise operation) to new output, because an - # error in the new graph could have been caused by either of the - # two ops. - copy_stack_trace(node.outputs + node.inputs, rval) - return rval - - if inode and isinstance(inode.op, Unbroadcast): - # Merge axis of each unbroadcast - axis = tuple(set(inode.op.axes).union(set(op.axes))) - iinput = inode.inputs[0] - rval = [unbroadcast(iinput, *axis)] - # Copy over stacktrace from previous output (after second unbroadcasting) - # and from previous input (after first unbroadcasting) because an error in - # the new graph could have been caused by either of the two Unbroadcast ops. - copy_stack_trace(node.outputs + node.inputs, rval) - return rval - - -@register_specialize -@register_canonicalize -@register_useless -@local_optimizer([Join]) -def local_join_1(fgraph, node): - """Join(i, x) => x - - Remove Join() when only one element is joined. - - """ - if not isinstance(node.op, Join): - return - tensors = node.inputs[1:] - if len(tensors) == 1: - # We don't need to copy over any stacktrace here, because the - # input variable should already have its own stacktrace. - return [tensors[0]] - - -# TODO: merge in local_useless_join -@register_useless -@register_specialize -@register_canonicalize -@local_optimizer([Join]) -def local_join_empty(fgraph, node): - """Join(i, x, y, empty) => Join(i, x, y) - - Remove empty inputs to joins. The empty inputs can be anywhere. - - """ - if not isinstance(node.op, Join): - return - new_inputs = [] - try: - join_idx = get_scalar_constant_value( - node.inputs[0], only_process_constants=True - ) - except NotScalarConstantError: - return - for idx in range(1, len(node.inputs)): - inp = node.inputs[idx] - # We can not use size == 0,, as this can change shape from 3,0 - # to 2,0. This trigger DebugMode error. This happen with - # stack(...,[]) as this add a dimshuffle on [], that add a - # dimensions with shape 1. - if isinstance(inp, Constant) and inp.data.shape[join_idx] == 0: - continue - new_inputs.append(inp) - if len(new_inputs) < len(node.inputs) - 1: - if len(new_inputs) == 0: - # at.join do not work in that case. - # constant folding will take care of this case. - return - ret = join(node.inputs[0], *new_inputs) - o = node.outputs[0] - if ret.dtype != o.dtype: - # Join can upcast some inputs - return - - # Copy over stacktrace from previous output (after join op) - # to new output, because an error in the new op must be caused - # by an error in the old join op. - copy_stack_trace(node.outputs, ret) - - return [ret] - - -@register_specialize -@register_canonicalize -@register_useless -@local_optimizer([Join]) -def local_join_make_vector(fgraph, node): - r"""Merge `MakeVector` inputs within a `Join`. - - For example: - - Join(0, make_vector1, make_vector2, ...) => Join(0, make_vector12, ...) - - This in combination with the `local_join_1` optimization can make `Join`\s - completely disappear. - """ - if not isinstance(node.op, Join) or node.outputs[0].ndim != 1: - return - new_inputs = [node.inputs[1]] - for idx in range(2, len(node.inputs)): - inp = node.inputs[idx] - if ( - inp.owner - and isinstance(inp.owner.op, MakeVector) - and new_inputs[-1].owner - and isinstance(new_inputs[-1].owner.op, MakeVector) - and - # MakeVector have a dtype parameter - inp.owner.op == new_inputs[-1].owner.op - ): - inps = new_inputs[-1].owner.inputs + inp.owner.inputs - new_inputs[-1] = inp.owner.op(*inps) - - # Copy over stacktrace from previous output (after join op) - # to new intermediate output, because an error in the intermediate - # op must be caused by an error in the old join op. - copy_stack_trace(node.outputs, new_inputs[-1]) - else: - new_inputs.append(inp) - if len(new_inputs) < len(node.inputs) - 1: - ret = join(node.inputs[0], *new_inputs) - - # Copy over stacktrace from previous output (after join op) - # to new output, because an error in the new op must be caused - # by an error in the old join op. - copy_stack_trace(node.outputs, ret) - return [ret] - - -@register_useless("local_remove_switch_const_cond") -@register_canonicalize("fast_compile", "local_remove_switch_const_cond") -@register_specialize -@local_optimizer([Elemwise]) -def local_useless_switch(fgraph, node): - """ - This optimization makes the following changes in the graph: - - ``at.switch(cond, left, right)`` -> - ``if cond is constant and cond == 0``: right - ``if cond is constant and cond != 0``: left - ``if left is right`` -> ``left`` - - and - - ``at.switch(le(shape_i{id}(X), 0), 0, shape_i{id}(X))`` -> ``shape_i{id}(X)`` - - """ - if isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, aes.Switch): - - cond = extract_constant(node.inputs[0], only_process_constants=True) - - if (isinstance(cond, np.ndarray) and cond.ndim == 0) or isinstance( - cond, (np.number, np.bool_) - ): - if cond == 0: - correct_out = node.inputs[2] - else: - correct_out = node.inputs[1] - - if correct_out.dtype != node.outputs[0].dtype: - out = cast(correct_out, node.outputs[0].dtype) - else: - out = correct_out - - out_shape = broadcast_shape(*node.inputs) - out = alloc(out, *out_shape) - - # Copy over stacktrace from selected output to new output - copy_stack_trace(node.outputs + correct_out, out) - return [out] - - # if left is right -> left - if node.inputs[1] is node.inputs[2]: - # Note: No need to copy over stacktrace, because the input node - # already has its own stacktrace - if cond.type.is_super(node.inputs[1].type): - return [node.inputs[1]] - - ret = fill(cond, node.inputs[1]) - - # Copy over stacktrace from switch output and correct branch - copy_stack_trace(node.outputs + node.inputs[1], ret) - return [ret] - - # This case happens with scan. - # Elemwise{switch}(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X) - left = node.inputs[1] - right = node.inputs[2] - cond_var = node.inputs[0] - if ( - cond_var.owner - and isinstance(cond_var.owner.op, Elemwise) - and isinstance(cond_var.owner.op.scalar_op, aes.LE) - and cond_var.owner.inputs[0].owner - and isinstance(cond_var.owner.inputs[0].owner.op, Shape_i) - and extract_constant(cond_var.owner.inputs[1], only_process_constants=True) - == 0 - and extract_constant(left, only_process_constants=True) == 0 - and right is cond_var.owner.inputs[0] - ): - assert node.outputs[0].type.is_super(right.type) - # No need to copy over stacktrace, because the right input node - # already has its own stacktrace - return [right] - return False - return False - - -@register_canonicalize -@local_optimizer([Elemwise]) -def local_merge_switch_same_cond(fgraph, node): - """ - Merge add/sub/mul/div/minimum/maximum/... of switches sharing the same - condition, to enable further simplification of their branches - Example: switch(c, a, b) + switch(c, x, y) -> switch(c, a+x, b+y) - """ - # node must be binary elemwise or add or mul - if not isinstance(node.op, Elemwise) or not isinstance( - node.op.scalar_op, (aes.BinaryScalarOp, aes.Add, aes.Mul) - ): - return - # all inputs must be switch - if not all( - s.owner - and isinstance(s.owner.op, Elemwise) - and isinstance(s.owner.op.scalar_op, aes.Switch) - for s in node.inputs - ): - return - # all switch conditions must be the same - cond = node.inputs[0].owner.inputs[0] - if not all(s.owner.inputs[0] is cond for s in node.inputs[1:]): - return - # pull out switch - return [ - switch( - cond, - node.op(*[s.owner.inputs[1] for s in node.inputs]), - node.op(*[s.owner.inputs[2] for s in node.inputs]), - ) - ] - - -@register_useless -@register_canonicalize -@register_specialize -@local_optimizer([Split]) -def local_useless_split(fgraph, node): - """Split{n_splits=1}(x, y) -> x - - Remove Split with only 1 split. - - """ - if isinstance(node.op, Split): - if node.op.len_splits == 1: - x, axis, splits = node.inputs - out = assert_op(x, eq(splits.shape[0], 1)) - # Copy over stacktrace from previous output node. - copy_stack_trace(node.outputs, out) - out2 = assert_op(out, eq(x.shape[axis], splits[0])) - # Copy over stacktrace from previous output node. - copy_stack_trace(out, out2) - - return [out2] - - -def local_reshape_chain(op): - @local_optimizer([op]) - def f(fgraph, node): - """ - Reshape(Reshape(shape1),shape2) -> Reshape(shape2) - - """ - if not check_chain(node, op, op): - return False - - # TODO: this can permit a failing program to run by eliminating - # the lower reshape - rval = node.op(node.inputs[0].owner.inputs[0], node.inputs[1]) - - # Copy over stacktrace from previous output node, as any error - # in new computational graph would have been caused by last op - # in the old computational graph. - copy_stack_trace(node.outputs, rval) - - # It might happen that the desired output of this node has a - # broadcastable pattern that does not match that of 'rval'. This is - # when originally, we were able to figure out that one of the - # dimensions of the reshape is one, but some other transformation - # replaced the shape by one for which this cannot be guessed. - # We should try to figure out why we lost the information about this - # constant value... but in the meantime, better not apply this - # optimization. - if rval.broadcastable == node.outputs[0].broadcastable: - return [rval] - else: - return False - - return f - - -register_canonicalize(local_reshape_chain(Reshape), name="local_reshape_chain") - - -@register_useless -@register_canonicalize -@register_stabilize -@local_optimizer([Reshape]) -def local_useless_reshape(fgraph, node): - """ - Remove two kinds of useless reshape. - - Remove Reshape when both the input and output have a single dimension. - Remove Reshape when reshaping to the shape of the input. - - """ - op = node.op - if not isinstance(op, Reshape): - return False - - inp = node.inputs[0] - output = node.outputs[0] - output_shape = node.inputs[1] - - if inp.ndim != output.ndim: - return False - - # Simple case: both input and output have a single dimension. - # This could hide errors if the user provides inconsistent shapes. - if inp.ndim == 1 and output.ndim == 1 and inp.broadcastable == output.broadcastable: - return [inp] - - # Second case: all the shapes match the input shape - # Match Reshape(x, x.shape) - if output_shape.owner and isinstance(output_shape.owner.op, Shape): - shape_input = output_shape.owner.inputs[0] - if shape_input == inp: - return [inp] - - # Match Reshape(x, [x.shape[0], ..., x.shape[-1]]), accounting for - # broadcastable and constant dimensions - if output_shape.owner and isinstance(output_shape.owner.op, MakeVector): - output_shape_is = output_shape.owner.inputs - - shape_feature = getattr(fgraph, "shape_feature", None) - - nb_m1 = 0 - shape_match = [False] * inp.ndim - for dim in range(inp.ndim): - outshp_i = output_shape_is[dim] - # Match Shape_i{dim}(input) - if ( - outshp_i.owner - and isinstance(outshp_i.owner.op, Shape_i) - and outshp_i.owner.op.i == dim - and outshp_i.owner.inputs[0] == inp - ): - shape_match[dim] = True - continue - - # Match Shape(input)[dim] - if ( - outshp_i.owner - and isinstance(outshp_i.owner.op, Subtensor) - and len(outshp_i.owner.inputs) == 2 - and extract_constant(outshp_i.owner.inputs[1]) == dim - ): - subtensor_inp = outshp_i.owner.inputs[0] - if subtensor_inp.owner and isinstance(subtensor_inp.owner.op, Shape): - shape_input_i = subtensor_inp.owner.inputs[0] - if shape_input_i == inp: - shape_match[dim] = True - continue - - # Match 1 if input.broadcastable[dim] is True - cst_outshp_i = extract_constant(outshp_i, only_process_constants=1) - if inp.broadcastable[dim] and cst_outshp_i == 1: - shape_match[dim] = True - continue - - # Match -1 - if cst_outshp_i == -1: - shape_match[dim] = True - nb_m1 += 1 - continue - - # Match shape_of[input][dim] or its constant equivalent - if shape_feature: - inpshp_i = shape_feature.get_shape(inp, dim) - if inpshp_i == outshp_i or ( - extract_constant(inpshp_i, only_process_constants=1) - == extract_constant(outshp_i, only_process_constants=1) - ): - shape_match[dim] = True - continue - - if all(shape_match) and nb_m1 <= 1: - return [inp] - - # TODO later: if all the shapes except one match, we may want to - # consider it useless as well, like we do in the 1-dim case. - return False - - -@register_canonicalize -@local_optimizer([Reshape]) -def local_reshape_to_dimshuffle(fgraph, node): - """ - Broadcastable dimensions in Reshape are replaced with dimshuffle. - - The goal is to avoid using reshape to add or remove broadcastable - dimensions, but use dimshuffle instead, so dimshuffles can cancel out - or be removed later on. - - For example: - - reshape(x, (1, n)) --> dimshuffle{x,0}(reshape(x, (n,)) - - reshape(x, (1, m, 1, n, 1, 1)) - --> dimshuffle{x,0,x,1,x,x}(reshape(x, (m, n))) - """ - op = node.op - if not isinstance(op, Reshape): - return False - - inp = node.inputs[0] - output = node.outputs[0] - output_shape = node.inputs[1] - - dimshuffle_new_order = [] - new_output_shape = [] - index = 0 # index over the output of the new reshape - for i in range(output.ndim): - # Since output_shape is a symbolic vector, we trust extract_constant - # to go through however it is formed to see if its i-th element is 1. - # We need only_process_constants=False for that. - dim = extract_constant( - output_shape[i], only_process_constants=False, elemwise=False - ) - if dim == 1: - dimshuffle_new_order.append("x") - else: - dimshuffle_new_order.append(index) - new_output_shape.append(dim) - index = index + 1 - if index != output.ndim: - inner = op.__class__(len(new_output_shape))(inp, new_output_shape) - copy_stack_trace(output, inner) - new_node = [DimShuffle(inner.type.broadcastable, dimshuffle_new_order)(inner)] - copy_stack_trace(output, new_node) - return new_node - - -@register_canonicalize -@register_stabilize -@local_optimizer([Reshape]) -def local_reshape_lift(fgraph, node): - """ - Reshape(UnaryElemwise(x)) -> UnaryElemwise(Reshape(x)) - - This optimization is needed by optimization - log1msigm_to_softplus to get applied when there is a reshape. - - """ - if ( - isinstance(node.op, Reshape) - and node.inputs[0].owner - and isinstance(node.inputs[0].owner.op, Elemwise) - and len(node.inputs[0].owner.inputs) == 1 - ): - r = node.op(node.inputs[0].owner.inputs[0], node.inputs[1]) - # Copy stacktrace from previous Reshape op, as an error in new - # Reshape op could only have been caused by old one. - copy_stack_trace(node.outputs, r) - - e = node.inputs[0].owner.op(r) - # Copy stacktrace from both previous Reshape and UnaryElemwise op - # because an error in new cg could have been caused by either ops. - copy_stack_trace(node.outputs + node.inputs, e) - return [e] - - -register_canonicalize(OpRemove(tensor_copy), name="remove_tensor_copy") - - -@local_optimizer(None) -def constant_folding(fgraph, node): - - if not node.op.do_constant_folding(fgraph, node): - return False - - if not all(isinstance(inp, Constant) for inp in node.inputs): - return False - - storage_map = {i: [i.data] for i in node.inputs} - compute_map = {i: [True] for i in node.inputs} - for o in node.outputs: - storage_map[o] = [None] - compute_map[o] = [False] - - thunk = node.op.make_thunk(node, storage_map, compute_map, no_recycling=[]) - required = thunk() - - # A node whose inputs are all provided should always return successfully - assert not required - - rval = [] - for output in node.outputs: - data = storage_map[output][0] - assert compute_map[output][0], (output, data) - - # TODO: `Type` itself should provide an interface for constructing - # instances appropriate for a given constant. - # TODO: Add handling for sparse types. - if isinstance(output.type, DenseTensorType): - output_type = TensorType( - output.type.dtype, - tuple(s == 1 for s in data.shape), - name=output.type.name, - ) - else: - output_type = output.type - - v = output_type.make_constant(data) - - # We need to "narrow" types when we have additional information, - # and not "broaden" them. This is a case in which types are - # unnecessarily "broadened" - # assert not hasattr(output.type, "broadcastable") or output.type.broadcastable == tuple(s == 1 for s in data.shape) - - copy_stack_trace(output, v) - - rval.append(v) - - return rval - - -topo_constant_folding = in2out( - constant_folding, ignore_newtrees=True, name="topo_constant_folding" +warnings.warn( + "The module `aesara.tensor.basic_opt` is deprecated; use `aesara.tensor.rewriting.basic` instead.", + DeprecationWarning, + stacklevel=2, ) -register_canonicalize(topo_constant_folding, "fast_compile", final_opt=True) -register_uncanonicalize(topo_constant_folding, "fast_compile", final_opt=True) -register_stabilize(topo_constant_folding, "fast_compile", final_opt=True) -register_specialize(topo_constant_folding, "fast_compile", final_opt=True) - - -def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None): - r"""Create a recursive function that fuses `Elemwise` `Op`\s. - - The basic idea is that we loop through an `Elemwise` node's inputs, find - other `Elemwise` nodes, determine the scalars input types for all of the - `Elemwise` `Op`\s, construct a new scalar `Op` using the scalar input types - and each `Elemwise`'s scalar `Op`, and use the composite scalar `Op` in a - new "fused" `Elemwise`. - - It's parameterized in order to work for `Elemwise` `Op`\s. - - Parameters - ---------- - op_class : type - `Elemwise` class (the one that we want to fuse) - max_input_fct : callable - A function that returns the maximum number of inputs that this `Elemwise` - can take. - On the CPU we limit to 32 input variables since that is the maximum - NumPy support. - - maker: callable - A function with the signature ``(node, *args)`` that constructs an - `op_class` instance (e.g. ``op_class(*args)``). - - """ - if maker is None: - - def maker(node, scalar_op): - return op_class(scalar_op) - - def local_fuse(fgraph, node): - r"""Fuse `Elemwise` `Op`\s in a node. - - As part of specialization, we fuse two consecutive `Elemwise` `Op`\s of the - same shape. - - For mixed dtype, we let the `Composite` `Op` do the cast. It lets the C - compiler do the cast. - - The number of dimensions is validated at call time by Aesara itself. - - """ - # META TODO: PUT THESE THINGS IN TRAC, NOT TODO NOTES!! - # TODO: use broadcast flag? - - # TODO: don't do this optimization as a localOptimizer. - # Analyze the graph in terms of elemwise subgraphs, and then - # replace each subgraph with a Composite version. - - # TODO: use malloc and copy to transfer arguments that don't - # fit within the parameter space of 256 bytes - # - # TODO: Merge with multiple output to merge when an inputs - # have multiple clients. This can't be done with a local - # optimiser. - - # TODO: Related: Support composites with multiple outputs - - # TODO: Use Composite to combine Elemwise and Reduce - # operations. We have to loop over the data anyway... might - # as well sum it up while we're at it (this can be trickier - # than i'm making it seound here. The data-traversal should be - # done contiguously, and the summing-up might not be easy or - # worthwhile if the summation axis doesn't line up with a - # contiguous dimension) - - if type(node.op) is not op_class: - return False - - if len(node.outputs) > 1: - # We don't support fusion for nodes with multiple outputs. - return - - inputs = [] # inputs of the new Elemwise op. - s_inputs = [] # inputs of the new scalar op used by the Composite. - # Inputs of the new scalar op that represents the current node. - s_g = [] - - # There is a hard limit of 256 bytes for the formal argument list to a - # GPU kernel function. - max_nb_input = max_input_fct(node) - # The number of inputs to the new fused op if we do not fuse more - # inputs. - new_nb_input = len(node.inputs) - # Did we fuse something? - # Needed as we can fuse unary op that don't change the number of - # inputs. - # And there is a case where the inputs are the same as the current - # node. That won't change the number of inputs of the new op. - fused = False - - for i in node.inputs: - scalar_node: Optional[Apply] = None - # Will store inputs of the fused node that are not currently inputs - # of the node we want to create (to avoid duplicating inputs). - tmp_input = [] - # Same as tmp_input, but for scalars. - tmp_scalar = [] - - # We should not check the number of inputs here - # As fusing op don't always change the number of input. - # If a variable is used as multiple into to the same node, - # we still want to fusion. So we take the set. - if ( - i.owner - and isinstance(i.owner.op, op_class) - and len({n for n, idx in fgraph.clients[i]}) == 1 - and - # Do not merge elemwise that don't have the same - # broadcastable pattern to don't redo duplicate - # computation due to broadcast. - i.owner.outputs[0].broadcastable == node.outputs[0].broadcastable - ): - try: - tmp_s_input = [] - # we should not put duplicate input into s_inputs and inputs - for ii in i.owner.inputs: - if ii in inputs: - tmp_s_input.append(s_inputs[inputs.index(ii)]) - elif ii in tmp_input: - tmp_s_input.append(tmp_scalar[tmp_input.index(ii)]) - else: - tmp = aes.get_scalar_type(ii.type.dtype).make_variable() - - try: - tv = get_test_value(ii) - # Sometimes the original inputs have - # zero-valued shapes in some dimensions, which - # implies that this whole scalar thing doesn't - # make sense (i.e. we're asking for the scalar - # value of an entry in a zero-dimensional - # array). - # This will eventually lead to an error in the - # `compute_test_value` call below when/if - # `config.compute_test_value_opt` is enabled - # (for debugging, more or less) - tmp.tag.test_value = tv.item() - except (TestValueError, ValueError): - pass - - tmp_s_input.append(tmp) - tmp_input.append(ii) - tmp_scalar.append(tmp_s_input[-1]) - - # Use the `Op.make_node` interface in case `Op.__call__` - # has been customized - scalar_node = i.owner.op.scalar_op.make_node(*tmp_s_input) - - if config.compute_test_value_opt != "off": - # This is required because `Op.make_node` won't do it - compute_test_value(scalar_node) - - # If the scalar_op doesn't have a C implementation, we skip - # its fusion to allow fusion of the other ops - i.owner.op.scalar_op.c_code( - scalar_node, - "test_presence_of_c_code", - ["x" for x in i.owner.inputs], - ["z" for z in i.owner.outputs], - {"fail": "%(fail)s"}, - ) - - except (NotImplementedError, MethodNotDefined): - _logger.warning( - ( - "Optimization Warning: " - f"The Op {i.owner.op.scalar_op} does not provide a C implementation." - " As well as being potentially slow, this also disables " - "loop fusion." - ) - ) - scalar_node = None - - # Compute the number of inputs in case we fuse this input. - # We subtract 1 because we replace the existing input with the new - # inputs from `tmp_input`. - new_nb_input_ = new_nb_input + len(tmp_input) - 1 - - # If the new input is already an input of the current node, it was - # already counted when `new_nb_input` was initialized to - # len(node.inputs). - # This can happen when a variable is used both by the Elemwise to - # fuse and the current node. - for x in tmp_input: - if x in node.inputs: - new_nb_input_ -= 1 - - if scalar_node and (new_nb_input_ <= max_nb_input): - fused = True - new_nb_input = new_nb_input_ - inputs.extend(tmp_input) - s_inputs.extend(tmp_scalar) - s_g.extend(scalar_node.outputs) - else: - # We must support the case where the same variable appears many - # times within the inputs - if inputs.count(i) == node.inputs.count(i): - s = s_inputs[inputs.index(i)] - else: - s = aes.get_scalar_type(i.type.dtype).make_variable() - if config.compute_test_value_opt != "off": - try: - v = get_test_value(i) - # See the zero-dimensional test value situation - # described above. - s.tag.test_value = v.item() - except (TestValueError, ValueError): - pass - - inputs.append(i) - s_inputs.append(s) - s_g.append(s) - - if not fused: - return False - - if new_nb_input != len(inputs) or len(s_inputs) != len(inputs): - raise Exception( - """Something has gone wrong with the elemwise -fusion optimization. We skip this optimization. You can ignore this message, -your code will run correctly, but may be slower.""" - ) - - s_new_out = node.op.scalar_op(*s_g, return_list=True) - try: - s_new_out[0].owner.op.c_code( - s_new_out[0].owner, - "test_presence_of_c_code", - ["x" for x in s_g], - ["z" for x in s_new_out], - {"fail": "%(fail)s"}, - ) - except (NotImplementedError, MethodNotDefined): - name = str(s_new_out[0].owner.op) - _logger.warning( - ( - "Optimization Warning: " - f"The Op {name} does not provide a C implementation." - " As well as being potentially slow, this also disables " - "loop fusion." - ) - ) - return False - - # create the composite op. - composite_op = aes.Composite(s_inputs, s_new_out) - - # create the new node. - # Do not call make_node to have test_value - new_node = maker(node, composite_op)(*inputs).owner - - assert len(new_node.outputs) == 1 - assert node.outputs[0].type.dtype == new_node.outputs[0].type.dtype - - if len(new_node.inputs) > max_nb_input: - _logger.warning( - "Loop fusion failed because the resulting node " - "would exceed the kernel argument limit." - ) - return False - - # we fuse as many that we can at the same time to make debug mode faster - # debug mode will be faster as it won't test all intermediate step. - while True: - ret = local_fuse(fgraph, new_node) - if ret is not False and ret is not None: - assert len(ret) == len(new_node.outputs) - assert len(ret) == 1 - new_node = ret[0].owner - else: - break - - return new_node.outputs - - return local_fuse - - -def elemwise_max_input_fct(node): - # `Elemwise.perform` uses NumPy ufuncs and they are limited to 31 inputs. - if not config.cxx: - return 31 - return 1024 - - -local_elemwise_fusion = local_elemwise_fusion_op(Elemwise, elemwise_max_input_fct) - - -class FusionOptimizer(GlobalOptimizer): - """Graph optimizer that simply runs local fusion operations. - - TODO: This is basically a `EquilibriumOptimizer`; we should just use that. - - """ - - def __init__(self, local_optimizer): - super().__init__() - self.optimizer = local_optimizer - - def add_requirements(self, fgraph): - fgraph.attach_feature(ReplaceValidate()) - - def apply(self, fgraph): - did_something = True - nb_iter = 0 - nb_replacement = 0 - nb_inconsistency_replace = 0 - time_toposort = 0 - if fgraph.profile: - validate_before = fgraph.profile.validate_time - callbacks_before = fgraph.execute_callbacks_times.copy() - callback_before = fgraph.execute_callbacks_time - while did_something: - t0 = time.time() - nodelist = list(fgraph.toposort()) - time_toposort += time.time() - t0 - nodelist.reverse() - did_something = False - for node in nodelist: - # Don't try to fuse node that have already been fused. - if node in fgraph.apply_nodes: - new_outputs = self.optimizer(fgraph, node) - if new_outputs: - assert len(new_outputs) == len(node.outputs) - try: - fgraph.replace_all_validate( - list(zip(node.outputs, new_outputs)), - reason=self.__class__.__name__, - ) - did_something = True - nb_replacement += 1 - except InconsistencyError: - nb_inconsistency_replace += 1 - nb_iter += 1 - - if fgraph.profile: - validate_time = fgraph.profile.validate_time - validate_before - callback_time = fgraph.execute_callbacks_time - callback_before - callbacks_time = {} - for k, v in fgraph.execute_callbacks_times.items(): - if k in callbacks_before: - callbacks_time[k] = v - callbacks_before[k] - else: - callbacks_time[k] = v - else: - validate_time = None - callback_time = None - callbacks_time = {} - return ( - self, - nb_iter, - nb_replacement, - nb_inconsistency_replace, - validate_time, - callback_time, - callbacks_time, - time_toposort, - ) - - @staticmethod - def print_profile(stream, prof, level=0): - blanc = " " * level - print(blanc, "FusionOptimizer", file=stream) - print(blanc, " nb_iter", prof[1], file=stream) - print(blanc, " nb_replacement", prof[2], file=stream) - print(blanc, " nb_inconsistency_replace", prof[3], file=stream) - print(blanc, " validate_time", prof[4], file=stream) - print(blanc, " callback_time", prof[5], file=stream) - if prof[5] is not None and prof[5] > 1: - print(blanc, " callbacks_time", file=stream) - for i in sorted(prof[6].items(), key=lambda a: a[1])[::-1]: - if i[1] > 0: - print(blanc, " ", i) - print(blanc, " time_toposort", prof[7], file=stream) - - -if config.tensor__local_elemwise_fusion: - _logger.debug("Enabling Elemwise fusion optimizations in fast_run") - # Must be after gpu(48.5) and before AddDestroyHandler(49.5) - fuse_seqopt = SequenceDB() - fuse_seqopt.register( - "composite_elemwise_fusion", - FusionOptimizer(local_elemwise_fusion), - "fast_run", - "fusion", - position=1, - ) - compile.optdb.register( - "elemwise_fusion", - fuse_seqopt, - "fast_run", - "fusion", - "local_elemwise_fusion", - "FusionOptimizer", - position=49, - ) -else: - _logger.debug("not enabling optimization fusion elemwise in fast_run") - compile.optdb.register( - "elemwise_fusion", - FusionOptimizer(local_elemwise_fusion), - "fusion", - "local_elemwise_fusion", - "FusionOptimizer", - position=49, - ) - - -@register_canonicalize -@local_optimizer([Elemwise]) -def local_useless_composite(fgraph, node): - """For elemwise Composite that have multiple outputs, remove the - outputs that are not used. - - """ - if not isinstance(node.op, Elemwise) or not isinstance( - node.op.scalar_op, aes.Composite - ): - return - comp = node.op.scalar_op - idx = [i for i, o_extern in enumerate(node.outputs) if fgraph.clients[o_extern]] - if len(idx) < len(node.outputs): - new_outputs = [comp.outputs[i] for i in idx] - c = aes.Composite(inputs=comp.inputs, outputs=new_outputs) - e = Elemwise(scalar_op=c)(*node.inputs, return_list=True) - return dict(zip([node.outputs[i] for i in idx], e)) - - -@register_canonicalize("fast_compile") -@register_useless("fast_compile") -@local_optimizer(None) -def local_view_op(fgraph, node): - if isinstance(node.op, ViewOp): - return node.inputs - - -@register_useless -@register_canonicalize -@register_stabilize -@register_specialize -@local_optimizer([Alloc]) -def local_merge_alloc(fgraph, node): - # This opt takes care of several cases: - # Alloc(Alloc(m, x, 1, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w) - # Alloc(Alloc(m, y, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w) - # Alloc(Alloc(m, y1, 1, 1), x, y2, z, w) -> Alloc(m, x, assert(y1, y1==y2), z, w) - if not isinstance(node.op, Alloc): - return False - if not node.inputs[0].owner or not isinstance(node.inputs[0].owner.op, Alloc): - return False - inputs_outer = node.inputs - inputs_inner = node.inputs[0].owner.inputs - dims_outer = inputs_outer[1:] - dims_inner = inputs_inner[1:] - dims_outer_rev = dims_outer[::-1] - dims_inner_rev = dims_inner[::-1] - # check if the pattern of broadcasting is matched, in the reversed ordering. - # The reverse ordering is needed when an Alloc add an implicit new - # broadcasted dimensions to its inputs[0]. Eg: - # Alloc(Alloc(m, y, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w) - i = 0 - for dim_inner, dim_outer in zip(dims_inner_rev, dims_outer_rev): - if dim_inner != dim_outer: - if isinstance(dim_inner, Constant) and dim_inner.data == 1: - pass - else: - dims_outer[-1 - i] = Assert( - "You have a shape error in your graph. To see a better" - " error message and a stack trace of where in your code" - " the error is created, use the Aesara flags" - " optimizer=None or optimizer=fast_compile." - )(dim_outer, eq(dim_outer, dim_inner)) - i += 1 - return [alloc(inputs_inner[0], *dims_outer)] - - -@register_useless("fast_compile") -@local_optimizer([TopKOp]) -def local_useless_topk(fgraph, node): - """ - TopKOp generates two outputs by default - This opt removes the useless ones - - """ - op = node.op - if not isinstance(op, TopKOp): - return - if not (op.return_values and op.return_indices): - return False - - x, k = node.inputs - ret_val = bool(fgraph.clients[node.outputs[0]]) - ret_idx = bool(fgraph.clients[node.outputs[1]]) - - if not (ret_val ^ ret_idx): - # both true -> nothing to remove - # both false -> let pruner handle - return False - - old_output = node.outputs[ret_idx] - new_output = TopKOp( - axis=op.axis, - sorted=op.sorted, - idx_dtype=op.idx_dtype, - return_values=ret_val, - return_indices=ret_idx, - )(x, k) - copy_stack_trace(node.outputs[0], new_output) - return {old_output: new_output} - - -@register_useless -@register_canonicalize -@local_optimizer([SpecifyShape]) -def local_merge_consecutive_specify_shape(fgraph, node): - """Replace ``specify_shape(specify_shape(x, s1), s2)`` with ``specify_shape(x, s3)``, - where s3 is the union of specified dimensions in s1 and s2, with preference given to s2. - """ - - if not isinstance(node.op, SpecifyShape): - return False - - obj = node.inputs[0] - if not (obj.owner and isinstance(obj.owner.op, SpecifyShape)): - return False - - inner_obj, *shape = obj.owner.inputs - for dim, sh in enumerate(node.inputs[1:]): - if not NoneConst.equals(sh): - shape[dim] = sh - - # TODO: We could make sure that the overlapping shapes of the two `SpecifyShape`s are - # the same. - - return [specify_shape(inner_obj, shape)] - - -@register_useless -@register_canonicalize -@local_optimizer([Shape]) -def local_Shape_of_SpecifyShape(fgraph, node): - """Replace ``specify_shape(x, s).shape`` with ``s``.""" - - if not isinstance(node.op, Shape): - return False - - specified_shape = node.inputs[0] - - if not isinstance(getattr(specified_shape.owner, "op", None), SpecifyShape): - return False - - x, *shape = specified_shape.owner.inputs - - # Replace `NoneConst` by `shape_i` - for i, sh in enumerate(shape): - if NoneConst.equals(sh): - shape[i] = shape_i(x, i, fgraph) - - return [stack(shape).astype(np.int64)] - - -@register_useless -@register_canonicalize -@local_optimizer([Shape_i]) -def local_Shape_i_of_broadcastable(fgraph, node): - """Replace ``shape_i(x, i)`` with ``1`` when ``x.broadcastable[i]`` is ``True``.""" - - if not isinstance(node.op, Shape_i): - return False - - shape_arg = node.inputs[0] - - if not isinstance(shape_arg.type, TensorType): - return False - - if shape_arg.broadcastable[node.op.i]: - return [as_tensor_variable(1, dtype=np.int64)] - - -@register_useless -@register_canonicalize -@local_optimizer([Unique]) -def local_Unique_scalar(fgraph, node): - """Convert ``unique(x)`` to ``x`` when ``x`` is a scalar.""" - if not isinstance(node.op, Unique): - return False - - if node.op.return_index or node.op.return_inverse or node.op.return_counts: - return False - - uniqued_var = node.inputs[0] - - if uniqued_var.ndim != 0: - return False - - old_out = node.outputs[0] - res = as_tensor_variable(uniqued_var, ndim=old_out.ndim, dtype=old_out.dtype) - return [res] - - -@register_useless -@register_canonicalize -@local_optimizer([Unique]) -def local_Unique_Alloc_lift(fgraph, node): - """Convert ``unique(alloc(x, ...), axis=None)`` to ``unique(x, axis=None)``. - - This isn't really so much a lift as a "reduction/consumption". - """ - if not isinstance(node.op, Unique): - return False - - if ( - node.op.return_index - or node.op.return_inverse - or node.op.return_counts - or node.op.axis is not None - ): - return False - - alloc_var = node.inputs[0] - - if not (alloc_var.owner and isinstance(alloc_var.owner.op, Alloc)): - return False - - alloced_var, *alloc_shape = alloc_var.owner.inputs - - new_unique, *_ = node.op.make_node(alloced_var).outputs - - old_out = node.outputs[0] - new_x = as_tensor_variable(new_unique, ndim=old_out.ndim, dtype=old_out.dtype) - return [new_x] - - -@register_useless -@register_canonicalize -@local_optimizer([Unique]) -def local_Unique_BroadcastTo_lift(fgraph, node): - """Convert ``unique(broadcast_to(x, ...), axis=None)`` to ``unique(x, axis=None)``. - - This isn't really so much a lift as a "reduction/consumption". - """ - if not isinstance(node.op, Unique): - return False - - if ( - node.op.return_index - or node.op.return_inverse - or node.op.return_counts - or node.op.axis is not None - ): - return False - - bcast_var = node.inputs[0] - - if not (bcast_var.owner and isinstance(bcast_var.owner.op, BroadcastTo)): - return False - - bcasted_var, *bcast_shape = bcast_var.owner.inputs - - new_unique, *_ = node.op.make_node(bcasted_var).outputs - - old_out = node.outputs[0] - new_x = as_tensor_variable(new_unique, ndim=old_out.ndim, dtype=old_out.dtype) - return [new_x] - - -@register_useless -@register_canonicalize -@local_optimizer([Unique]) -def local_Unique_Repeat_lift(fgraph, node): - """Convert ``unique(repeat(x, ...), axis=None)`` to ``unique(x, axis=None)``. - - This isn't really so much a lift as a "reduction/consumption". - """ - if not isinstance(node.op, Unique): - return False - - if ( - node.op.return_index - or node.op.return_inverse - or node.op.return_counts - or node.op.axis is not None - ): - return False - - repeat_var = node.inputs[0] - - if not (repeat_var.owner and isinstance(repeat_var.owner.op, Repeat)): - return False - - repeated_var, *repeat_shape = repeat_var.owner.inputs - - new_unique, *_ = node.op.make_node(repeated_var).outputs - - old_out = node.outputs[0] - new_x = as_tensor_variable(new_unique, ndim=old_out.ndim, dtype=old_out.dtype) - return [new_x] - - -@register_useless -@register_canonicalize -@local_optimizer([Unique]) -def local_Unique_second(fgraph, node): - """Convert ``unique(second(x, ...), axis=None)`` to ``second(x, axis=None)``. - - This isn't really so much a lift as a "reduction/consumption". - """ - if not isinstance(node.op, Unique): - return False - - if ( - node.op.return_index - or node.op.return_inverse - or node.op.return_counts - or node.op.axis is not None - ): - return False - - second_var = node.inputs[0] - - if not ( - second_var.owner - and isinstance(second_var.owner.op, Elemwise) - and isinstance(second_var.owner.op.scalar_op, aes.Second) - ): - return False - - shape_var, seconded_var = second_var.owner.inputs - - new_unique, *_ = node.op.make_node(seconded_var).outputs - - old_out = node.outputs[0] - new_x = as_tensor_variable(new_unique, ndim=old_out.ndim, dtype=old_out.dtype) - return [new_x] - - -@register_useless -@register_canonicalize -@local_optimizer([BroadcastTo]) -def local_remove_scalar_BroadcastTo(fgraph, node): - - bcast_shape = node.inputs[1:] - if not bcast_shape: - bcasted_var = node.inputs[0] - # If this isn't true, the graph is invalid - assert bcasted_var.ndim == 0 - return [bcasted_var] +from aesara.tensor.rewriting.basic import * # noqa: F401 E402 F403 +from aesara.tensor.rewriting.elemwise import * # noqa: F401 E402 F403 +from aesara.tensor.rewriting.extra_ops import * # noqa: F401 E402 F403 +from aesara.tensor.rewriting.shape import * # noqa: F401 E402 F403 diff --git a/aesara/tensor/blas.py b/aesara/tensor/blas.py index 60ac8d8314..1a3f48defa 100644 --- a/aesara/tensor/blas.py +++ b/aesara/tensor/blas.py @@ -145,25 +145,25 @@ from aesara.graph.basic import Apply, view_roots from aesara.graph.features import ReplacementDidNotRemoveError, ReplaceValidate from aesara.graph.op import Op -from aesara.graph.opt import ( - EquilibriumOptimizer, - GlobalOptimizer, +from aesara.graph.rewriting.basic import ( + EquilibriumGraphRewriter, + GraphRewriter, copy_stack_trace, in2out, - local_optimizer, + node_rewriter, ) -from aesara.graph.optdb import SequenceDB +from aesara.graph.rewriting.db import SequenceDB from aesara.graph.utils import InconsistencyError, MethodNotDefined, TestValueError from aesara.link.c.op import COp from aesara.link.c.params_type import ParamsType from aesara.printing import FunctionPrinter, debugprint, pprint from aesara.scalar import bool as bool_t from aesara.tensor import basic as at -from aesara.tensor.basic_opt import local_dimshuffle_lift from aesara.tensor.blas_headers import blas_header_text, blas_header_version from aesara.tensor.elemwise import DimShuffle, Elemwise from aesara.tensor.exceptions import NotScalarConstantError from aesara.tensor.math import Dot, add, mul, neg, sub +from aesara.tensor.rewriting.elemwise import local_dimshuffle_lift from aesara.tensor.shape import specify_broadcastable from aesara.tensor.type import ( DenseTensorType, @@ -1496,7 +1496,7 @@ def _gemm_from_node2(fgraph, node): return None, t1 - t0, 0, 0 -class GemmOptimizer(GlobalOptimizer): +class GemmOptimizer(GraphRewriter): """Graph optimizer for inserting Gemm operations.""" def __init__(self): @@ -1526,7 +1526,9 @@ def on_import(new_node): if new_node is not node: nodelist.append(new_node) - u = aesara.graph.opt.Updater(on_import, None, None, name="GemmOptimizer") + u = aesara.graph.rewriting.basic.DispatchingFeature( + on_import, None, None, name="GemmOptimizer" + ) fgraph.attach_feature(u) while did_something: nb_iter += 1 @@ -1616,10 +1618,10 @@ def on_import(new_node): callbacks_time, ) - @staticmethod - def print_profile(stream, prof, level=0): + @classmethod + def print_profile(cls, stream, prof, level=0): blanc = " " * level - print(blanc, "GemmOptimizer", file=stream) + print(blanc, cls.__name__, file=stream) print(blanc, " nb_iter", prof[1], file=stream) print(blanc, " nb_replacement", prof[2], file=stream) print(blanc, " nb_replacement_didn_t_remove", prof[3], file=stream) @@ -1733,7 +1735,7 @@ def c_code_cache_version(self): _dot22 = Dot22() -@local_optimizer([Dot]) +@node_rewriter([Dot]) def local_dot_to_dot22(fgraph, node): # This works for tensor.outer too because basic.outer is a macro that # produces a dot(dimshuffle,dimshuffle) of form 4 below @@ -1766,7 +1768,7 @@ def local_dot_to_dot22(fgraph, node): _logger.info(f"Not optimizing dot with inputs {x} {y} {x.type} {y.type}") -@local_optimizer([gemm_no_inplace], inplace=True) +@node_rewriter([gemm_no_inplace], inplace=True) def local_inplace_gemm(fgraph, node): if node.op == gemm_no_inplace: new_out = [gemm_inplace(*node.inputs)] @@ -1774,7 +1776,7 @@ def local_inplace_gemm(fgraph, node): return new_out -@local_optimizer([gemv_no_inplace], inplace=True) +@node_rewriter([gemv_no_inplace], inplace=True) def local_inplace_gemv(fgraph, node): if node.op == gemv_no_inplace: new_out = [gemv_inplace(*node.inputs)] @@ -1782,7 +1784,7 @@ def local_inplace_gemv(fgraph, node): return new_out -@local_optimizer([ger], inplace=True) +@node_rewriter([ger], inplace=True) def local_inplace_ger(fgraph, node): if node.op == ger: new_out = [ger_destructive(*node.inputs)] @@ -1790,7 +1792,7 @@ def local_inplace_ger(fgraph, node): return new_out -@local_optimizer([gemm_no_inplace]) +@node_rewriter([gemm_no_inplace]) def local_gemm_to_gemv(fgraph, node): """GEMM acting on row or column matrices -> GEMV.""" if node.op == gemm_no_inplace: @@ -1807,7 +1809,7 @@ def local_gemm_to_gemv(fgraph, node): return new_out -@local_optimizer([gemm_no_inplace]) +@node_rewriter([gemm_no_inplace]) def local_gemm_to_ger(fgraph, node): """GEMM computing an outer-product -> GER.""" if node.op == gemm_no_inplace: @@ -1839,7 +1841,7 @@ def local_gemm_to_ger(fgraph, node): # TODO: delete this optimization when we have the proper dot->gemm->ger pipeline # working -@local_optimizer([_dot22]) +@node_rewriter([_dot22]) def local_dot22_to_ger_or_gemv(fgraph, node): """dot22 computing an outer-product -> GER.""" if node.op == _dot22: @@ -1904,7 +1906,7 @@ def local_dot22_to_ger_or_gemv(fgraph, node): blas_optdb.register("gemm_optimizer", GemmOptimizer(), "fast_run", position=10) blas_optdb.register( "local_gemm_to_gemv", - EquilibriumOptimizer( + EquilibriumGraphRewriter( [ local_gemm_to_gemv, local_gemm_to_ger, @@ -2033,7 +2035,7 @@ def c_code_cache_version(self): _dot22scalar = Dot22Scalar() -@local_optimizer([mul]) +@node_rewriter([mul]) def local_dot22_to_dot22scalar(fgraph, node): """ Notes @@ -2651,7 +2653,7 @@ def infer_shape(self, fgraph, node, shapes): # from opt import register_specialize, register_canonicalize # @register_specialize -@local_optimizer([sub, add]) +@node_rewriter([sub, add]) def local_print_as_we_go_along(fgraph, node): if node.op in (sub, add): debugprint(node) diff --git a/aesara/tensor/blas_c.py b/aesara/tensor/blas_c.py index bb710fc9ec..c808528b97 100644 --- a/aesara/tensor/blas_c.py +++ b/aesara/tensor/blas_c.py @@ -1,5 +1,5 @@ from aesara.configdefaults import config -from aesara.graph.opt import in2out +from aesara.graph.rewriting.basic import in2out from aesara.link.c.op import COp from aesara.link.c.params_type import ParamsType from aesara.scalar import bool as bool_t @@ -15,7 +15,7 @@ ger, ger_destructive, ldflags, - local_optimizer, + node_rewriter, optdb, ) @@ -344,7 +344,7 @@ def c_code_cache_version(self): cger_no_inplace = CGer(False) -@local_optimizer([ger, ger_destructive]) +@node_rewriter([ger, ger_destructive]) def use_c_ger(fgraph, node): if not config.blas__ldflags: return @@ -355,7 +355,7 @@ def use_c_ger(fgraph, node): return [CGer(True)(*node.inputs)] -@local_optimizer([CGer(False)]) +@node_rewriter([CGer(False)]) def make_c_ger_destructive(fgraph, node): if isinstance(node.op, CGer) and not node.op.destructive: return [cger_inplace(*node.inputs)] @@ -699,7 +699,7 @@ def check_force_gemv_init(): check_force_gemv_init._force_init_beta = None -@local_optimizer([gemv_inplace, gemv_no_inplace]) +@node_rewriter([gemv_inplace, gemv_no_inplace]) def use_c_gemv(fgraph, node): if not config.blas__ldflags: return @@ -710,7 +710,7 @@ def use_c_gemv(fgraph, node): return [cgemv_inplace(*node.inputs)] -@local_optimizer([CGemv(inplace=False)]) +@node_rewriter([CGemv(inplace=False)]) def make_c_gemv_destructive(fgraph, node): if isinstance(node.op, CGemv) and not node.op.inplace: inputs = list(node.inputs) diff --git a/aesara/tensor/blas_scipy.py b/aesara/tensor/blas_scipy.py index dfee0dd6ce..6dc9ff36e5 100644 --- a/aesara/tensor/blas_scipy.py +++ b/aesara/tensor/blas_scipy.py @@ -4,14 +4,14 @@ import numpy as np -from aesara.graph.opt import in2out +from aesara.graph.rewriting.basic import in2out from aesara.tensor.blas import ( Ger, blas_optdb, ger, ger_destructive, have_fblas, - local_optimizer, + node_rewriter, optdb, ) @@ -58,13 +58,13 @@ def perform(self, node, inputs, output_storage): scipy_ger_inplace = ScipyGer(True) -@local_optimizer([ger, ger_destructive]) +@node_rewriter([ger, ger_destructive]) def use_scipy_ger(fgraph, node): if node.op == ger: return [scipy_ger_no_inplace(*node.inputs)] -@local_optimizer([scipy_ger_no_inplace]) +@node_rewriter([scipy_ger_no_inplace]) def make_ger_destructive(fgraph, node): if node.op == scipy_ger_no_inplace: return [scipy_ger_inplace(*node.inputs)] diff --git a/aesara/tensor/elemwise.py b/aesara/tensor/elemwise.py index 0515ddaec7..34f9ea5459 100644 --- a/aesara/tensor/elemwise.py +++ b/aesara/tensor/elemwise.py @@ -418,24 +418,45 @@ def get_output_info(self, dim_shuffle, *inputs): # of all inputs in parallel... the all() gives us each output # broadcastable bit in turn. + def get_most_specialized_shape(shapes): + shapes = set(shapes) + # All shapes are the same + if len(shapes) == 1: + return tuple(shapes)[0] + + # Only valid indeterminate case + if shapes == {None, 1}: + return None + + shapes.discard(1) + shapes.discard(None) + if len(shapes) > 1: + raise ValueError + return tuple(shapes)[0] + # it is multiplied by nout because Elemwise supports multiple outputs # (nout of them) - out_broadcastables = [ - [ - all(bcast) - for bcast in zip(*[input.type.broadcastable for input in inputs]) - ] - ] * shadow.nout + try: + out_shapes = [ + [ + get_most_specialized_shape(shape) + for shape in zip(*[inp.type.shape for inp in inputs]) + ] + ] * shadow.nout + except ValueError: + raise ValueError( + f"Incompatible Elemwise input shapes {[inp.type.shape for inp in inputs]}" + ) # inplace_pattern maps output idx -> input idx inplace_pattern = self.inplace_pattern if inplace_pattern: for overwriter, overwritten in inplace_pattern.items(): for ob, ib in zip( - out_broadcastables[overwriter], + out_shapes[overwriter], inputs[overwritten].type.broadcastable, ): - if ib and not ob: + if ib and not ob == 1: raise ValueError( "Operation cannot be done inplace on an input " "with broadcasted dimensions." @@ -451,8 +472,8 @@ def get_output_info(self, dim_shuffle, *inputs): ([i.type.dtype for i in inputs], out_dtypes, inplace_pattern), ) ) - assert len(out_dtypes) == len(out_broadcastables) - return out_dtypes, out_broadcastables, inputs + assert len(out_dtypes) == len(out_shapes) + return out_dtypes, out_shapes, inputs def make_node(self, *inputs): """ @@ -461,12 +482,10 @@ def make_node(self, *inputs): using DimShuffle. """ inputs = [as_tensor_variable(i) for i in inputs] - out_dtypes, out_broadcastables, inputs = self.get_output_info( - DimShuffle, *inputs - ) + out_dtypes, out_shapes, inputs = self.get_output_info(DimShuffle, *inputs) outputs = [ - TensorType(dtype=dtype, shape=broadcastable)() - for dtype, broadcastable in zip(out_dtypes, out_broadcastables) + TensorType(dtype=dtype, shape=shape)() + for dtype, shape in zip(out_dtypes, out_shapes) ] return Apply(self, inputs, outputs) @@ -806,7 +825,7 @@ def perform(self, node, inputs, output_storage): def infer_shape(self, fgraph, node, i_shapes) -> List[Tuple[TensorVariable, ...]]: if len(node.outputs) > 1: - from aesara.tensor.basic_opt import ShapeError + from aesara.tensor.exceptions import ShapeError raise ShapeError( "Multiple outputs are not supported by the default `Elemwise.infer_shape`" diff --git a/aesara/tensor/extra_ops.py b/aesara/tensor/extra_ops.py index 623d10b1df..54c2339888 100644 --- a/aesara/tensor/extra_ops.py +++ b/aesara/tensor/extra_ops.py @@ -1,6 +1,6 @@ from collections.abc import Collection from functools import reduce -from typing import Iterable, Tuple, Union +from typing import Iterable, Set, Tuple, Union import numpy as np import numpy.core.numeric @@ -14,7 +14,7 @@ disconnected_type, grad_undefined, ) -from aesara.graph.basic import Apply, Variable, equal_computations +from aesara.graph.basic import Apply, Constant, Variable, equal_computations from aesara.graph.op import Op from aesara.link.c.op import COp from aesara.link.c.params_type import ParamsType @@ -1491,7 +1491,12 @@ def broadcast_shape_iter( array_shapes = [ (one_at,) * (max_dims - len(a)) - + tuple(one_at if getattr(sh, "value", sh) == 1 else sh for sh in a) + + tuple( + one_at + if getattr(sh, "value", sh) == 1 + else (aes.as_scalar(sh) if not isinstance(sh, Variable) else sh) + for sh in a + ) for a in arrays ] else: @@ -1523,41 +1528,83 @@ def broadcast_shape_iter( else: # More than one shape might not be broadcastable in this dimension - all_dims_equal = all( - # TODO FIXME: This is a largely deficient means of comparing graphs - # (and especially shapes) - equal_computations([maybe_non_bcast_shapes[0]], [dim]) - for dim in maybe_non_bcast_shapes[1:] - ) + nonconst_nb_shapes: Set[int] = set() + const_nb_shapes: Set[Variable] = set() + for shape in maybe_non_bcast_shapes: + if isinstance(shape, Constant): + const_nb_shapes.add(shape.value.item()) + else: + nonconst_nb_shapes.add(shape) - if all_dims_equal: - result_dims.append(maybe_non_bcast_shapes[0]) - continue + if len(const_nb_shapes) > 1: + raise ValueError("Could not broadcast dimensions") + elif len(const_nb_shapes) == 1: + (const_nb_shape,) = const_nb_shapes - non_bcast_vec = [ - aes.switch(aes.eq(nbv, 1), -one_at, nbv) - for nbv in maybe_non_bcast_shapes - ] - dim_max = aes.abs(reduce(aes.scalar_maximum, non_bcast_vec)) + assert const_nb_shape != 1 - assert_dim = Assert("Could not broadcast dimensions") - assert_cond = reduce( - aes.and_, - ( - aes.or_(aes.eq(nbv, -one_at), aes.eq(nbv, dim_max)) - for nbv in non_bcast_vec - ), - ) - bcast_dim = assert_dim(dim_max, assert_cond) + const_nt_shape_var = aesara.scalar.ScalarConstant( + aesara.scalar.int64, const_nb_shape + ) + + if len(nonconst_nb_shapes) > 0: + # All the potential non-broadcast shapes need to either + # be broadcastable or equal to the one non-broadcastable + # constant `const_nt_shape_var`. + assert_dim = Assert("Could not broadcast dimensions") + assert_cond = reduce( + aes.and_, + ( + aes.or_( + aes.eq(nbv, one_at), aes.eq(nbv, const_nt_shape_var) + ) + for nbv in nonconst_nb_shapes + ), + ) + bcast_dim = assert_dim(const_nt_shape_var, assert_cond) + else: + bcast_dim = const_nt_shape_var + else: + # There are no constant, non-broadcastable shapes in this + # dimension. + + all_dims_equal = all( + # TODO FIXME: This is a largely deficient, and expensive, means + # of comparing graphs (and especially shapes) + equal_computations([maybe_non_bcast_shapes[0]], [dim]) + for dim in maybe_non_bcast_shapes[1:] + ) + + if all_dims_equal: + result_dims.append(maybe_non_bcast_shapes[0]) + continue + + non_bcast_vec = [ + aes.switch(aes.eq(nbv, 1), -one_at, nbv) + for nbv in maybe_non_bcast_shapes + ] + dim_max = aes.abs(reduce(aes.scalar_maximum, non_bcast_vec)) + + assert_dim = Assert("Could not broadcast dimensions") + assert_cond = reduce( + aes.and_, + ( + aes.or_(aes.eq(nbv, -one_at), aes.eq(nbv, dim_max)) + for nbv in non_bcast_vec + ), + ) + bcast_dim = assert_dim(dim_max, assert_cond) result_dims.append(bcast_dim) return tuple(result_dims) -class BroadcastTo(Op): +class BroadcastTo(COp): """An `Op` for `numpy.broadcast_to`.""" + __props__ = () + view_map = {0: [0]} def __call__(self, a, shape, **kwargs): @@ -1607,6 +1654,56 @@ def grad(self, inputs, outputs_gradients): def infer_shape(self, fgraph, node, ins_shapes): return [node.inputs[1:]] + def c_code(self, node, name, inputs, outputs, sub): + (x, *shape) = inputs + (out,) = outputs + ndims = len(shape) + fail = sub["fail"] + + # TODO: Could just use `PyArray_Return`, no? + dims_array = ", ".join( + [ + f"((dtype_{shape}*)(PyArray_DATA({shape})))[0]" + for i, shape in enumerate(shape) + ] + ) + + src = ( + """ + npy_intp itershape[%(ndims)s] = {%(dims_array)s}; + + PyArrayObject *ops[1] = {%(x)s}; + npy_uint32 flags = NPY_ITER_MULTI_INDEX | NPY_ITER_REFS_OK | NPY_ITER_ZEROSIZE_OK; + npy_uint32 op_flags[1] = {NPY_ITER_READONLY}; + PyArray_Descr *op_dtypes[1] = {NULL}; + int oa_ndim = %(ndims)s; + int* op_axes[1] = {NULL}; + npy_intp buffersize = 0; + + NpyIter *iter = NpyIter_AdvancedNew( + 1, ops, flags, NPY_CORDER, NPY_NO_CASTING, op_flags, op_dtypes, oa_ndim, op_axes, itershape, buffersize + ); + + %(out)s = NpyIter_GetIterView(iter, 0); + + if(%(out)s == NULL){ + NpyIter_Deallocate(iter); + %(fail)s; + } + + if (NpyIter_Deallocate(iter) != NPY_SUCCEED) { + %(fail)s; + } + + """ + % locals() + ) + + return src + + def c_code_cache_version(self): + return (1,) + broadcast_to_ = BroadcastTo() diff --git a/aesara/tensor/inplace.py b/aesara/tensor/inplace.py index b172d5b110..807099486d 100644 --- a/aesara/tensor/inplace.py +++ b/aesara/tensor/inplace.py @@ -233,6 +233,11 @@ def erfcx_inplace(a): """scaled complementary error function""" +@scalar_elemwise +def owens_t_inplace(h, a): + """owens t function""" + + @scalar_elemwise def gamma_inplace(a): """gamma function""" diff --git a/aesara/tensor/math.py b/aesara/tensor/math.py index 323f612de6..ef48ddeba1 100644 --- a/aesara/tensor/math.py +++ b/aesara/tensor/math.py @@ -1048,10 +1048,6 @@ def abs(a): """|`a`|""" -# These are deprecated and will be removed -abs_ = abs - - pprint.assign(abs, printing.PatternPrinter(("|%(0)s|", -1000))) @@ -1080,10 +1076,6 @@ def reciprocal(a): """1.0/a""" -# This is deprecated and will be removed -inv = reciprocal - - @scalar_elemwise def log(a): """base e logarithm of a""" @@ -1339,6 +1331,11 @@ def erfcinv(a): """inverse complementary error function""" +@scalar_elemwise +def owens_t(h, a): + """owens t function""" + + @scalar_elemwise def gamma(a): """gamma function""" @@ -3021,13 +3018,11 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None "invert", "bitwise_not", "abs", - "abs_", "exp", "exp2", "expm1", "neg", "reciprocal", - "inv", "log", "log2", "log10", @@ -3064,6 +3059,7 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None "erfcx", "erfinv", "erfcinv", + "owens_t", "gamma", "gammaln", "psi", @@ -3123,3 +3119,28 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None "logaddexp", "logsumexp", ] + +DEPRECATED_NAMES = [ + ("abs_", "`abs_` is deprecated; use `abs` instead.", abs), + ("inv", "`inv` is deprecated; use `reciprocal` instead.", reciprocal), +] + + +def __getattr__(name): + """Intercept module-level attribute access of deprecated symbols. + + Adapted from https://stackoverflow.com/a/55139609/3006474. + + """ + from warnings import warn + + for old_name, msg, old_object in DEPRECATED_NAMES: + if name == old_name: + warn(msg, DeprecationWarning, stacklevel=2) + return old_object + + raise AttributeError(f"module {__name__} has no attribute {name}") + + +def __dir__(): + return sorted(__all__ + [names[0] for names in DEPRECATED_NAMES]) diff --git a/aesara/tensor/math_opt.py b/aesara/tensor/math_opt.py index 87bea7a1c5..60cf6a8152 100644 --- a/aesara/tensor/math_opt.py +++ b/aesara/tensor/math_opt.py @@ -1,3566 +1,10 @@ -""" Tensor optimizations addressing the ops in math.py.""" +import warnings -import itertools -import operator -from functools import partial, reduce -import numpy as np - -import aesara.scalar.basic as aes -import aesara.scalar.math as aes_math -from aesara.graph.basic import Constant, Variable -from aesara.graph.opt import ( - LocalOptGroup, - LocalOptimizer, - PatternSub, - copy_stack_trace, - in2out, - local_optimizer, -) -from aesara.graph.opt_utils import get_clients_at_depth -from aesara.misc.safe_asarray import _asarray -from aesara.raise_op import assert_op -from aesara.tensor.basic import ( - Alloc, - Join, - MakeVector, - alloc, - as_tensor_variable, - cast, - constant, - extract_constant, - fill, - get_scalar_constant_value, - ones_like, - switch, - zeros_like, -) -from aesara.tensor.basic_opt import ( - FusionOptimizer, - broadcast_like, - encompasses_broadcastable, - fuse_seqopt, - local_fill_sink, - register_canonicalize, - register_specialize, - register_specialize_device, - register_stabilize, - register_uncanonicalize, - register_useless, -) -from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise -from aesara.tensor.exceptions import NotScalarConstantError -from aesara.tensor.math import ( - All, - Any, - Dot, - NonZeroCAReduce, - Prod, - ProdWithoutZeros, - Sum, - _conj, -) -from aesara.tensor.math import abs as at_abs -from aesara.tensor.math import ( - add, - dot, - eq, - erf, - erfc, - exp, - expm1, - ge, - int_div, - isinf, - le, - log, - log1mexp, - log1p, - makeKeepDims, -) -from aesara.tensor.math import max as at_max -from aesara.tensor.math import maximum, mul, neg -from aesara.tensor.math import pow as at_pow -from aesara.tensor.math import prod, reciprocal, sgn, sigmoid, softplus, sqr, sqrt, sub -from aesara.tensor.math import sum as at_sum -from aesara.tensor.math import true_div -from aesara.tensor.shape import Shape, Shape_i -from aesara.tensor.subtensor import Subtensor -from aesara.tensor.type import ( - complex_dtypes, - uint_dtypes, - values_eq_approx_remove_inf, - values_eq_approx_remove_inf_nan, - values_eq_approx_remove_nan, -) -from aesara.tensor.var import TensorConstant, get_unique_value - - -def scalarconsts_rest(inputs, elemwise=True, only_process_constants=False): - """Partition a list of variables into two kinds: - scalar constants, and the rest.""" - consts = [] - origconsts = [] - nonconsts = [] - for i in inputs: - try: - v = get_scalar_constant_value( - i, elemwise=elemwise, only_process_constants=only_process_constants - ) - consts.append(v) - origconsts.append(i) - except NotScalarConstantError: - nonconsts.append(i) - return consts, origconsts, nonconsts - - -def get_constant(v): - """ - - Returns - ------- - object - A numeric constant if v is a Constant or, well, a - numeric constant. If v is a plain Variable, returns None. - - """ - if isinstance(v, Constant): - unique_value = get_unique_value(v) - if unique_value is not None: - data = unique_value - else: - data = v.data - if data.ndim == 0: - return data - else: - return None - elif isinstance(v, Variable): - return None - else: - return v - - -def fill_chain(new_out, orig_inputs): - for i in orig_inputs: - new_out = fill(i, new_out) - return [new_out] - - -@register_canonicalize -@register_stabilize -@local_optimizer([Dot]) -def local_0_dot_x(fgraph, node): - if not isinstance(node.op, Dot): - return False - - x = node.inputs[0] - y = node.inputs[1] - replace = False - try: - if get_scalar_constant_value(x, only_process_constants=True) == 0: - replace = True - except NotScalarConstantError: - pass - - try: - if get_scalar_constant_value(y, only_process_constants=True) == 0: - replace = True - except NotScalarConstantError: - pass - - if replace: - constant_zero = constant(0, dtype=node.outputs[0].type.dtype) - if x.ndim == 2 and y.ndim == 2: - constant_zero = assert_op(constant_zero, eq(x.shape[1], y.shape[0])) - return [alloc(constant_zero, x.shape[0], y.shape[1])] - elif x.ndim == 1 and y.ndim == 2: - constant_zero = assert_op(constant_zero, eq(x.shape[0], y.shape[0])) - return [alloc(constant_zero, y.shape[1])] - elif x.ndim == 2 and y.ndim == 1: - constant_zero = assert_op(constant_zero, eq(x.shape[1], y.shape[0])) - return [alloc(constant_zero, x.shape[0])] - elif x.ndim == 1 and y.ndim == 1: - constant_zero = assert_op(constant_zero, eq(x.shape[0], y.shape[0])) - return [constant_zero] - - -@register_canonicalize -@local_optimizer([DimShuffle]) -def local_lift_transpose_through_dot(fgraph, node): - """Perform the rewrite ``dot(x,y).T -> dot(y.T, x.T)`` - - These optimizations "lift" (propagate towards the inputs) DimShuffle - through dot product. It allows to put the graph in a more standard shape, - and to later merge consecutive DimShuffles. - - The transformation should be apply whether or not the transpose is - inplace. The newly-introduced transpositions are not inplace, this will - be taken care of in a later optimization phase. - - """ - if not (isinstance(node.op, DimShuffle) and node.op.new_order == (1, 0)): - return False - if not (node.inputs[0].owner and isinstance(node.inputs[0].owner.op, Dot)): - return False - x, y = node.inputs[0].owner.inputs - - if x.ndim == y.ndim == 2: - # Output is dot product of transposed inputs in reverse order - ret = [dot(y.T, x.T)] - - # Copy over stack trace to output from result of dot-product - copy_stack_trace(node.inputs[0], ret) - return ret - - -def is_inverse_pair(node_op, prev_op, inv_pair): - """ - Given two consecutive operations, check if they are the - provided pair of inverse functions. - - """ - node_is_op0 = isinstance(node_op, inv_pair[0]) - node_is_op1 = isinstance(node_op, inv_pair[1]) - prev_is_op0 = isinstance(prev_op, inv_pair[0]) - prev_is_op1 = isinstance(prev_op, inv_pair[1]) - - return (node_is_op0 and prev_is_op1) or (node_is_op1 and prev_is_op0) - - -@register_canonicalize -@register_specialize -@local_optimizer([Elemwise]) -def local_func_inv(fgraph, node): - """ - Check for two consecutive operations that are functional inverses - and remove them from the function graph. - - """ - inv_pairs = ( - (aes.Deg2Rad, aes.Rad2Deg), - (aes.Cosh, aes.ArcCosh), - (aes.Tanh, aes.ArcTanh), - (aes.Sinh, aes.ArcSinh), - (aes.Conj, aes.Conj), - (aes.Neg, aes.Neg), - (aes.Reciprocal, aes.Reciprocal), - ) - x = node.inputs[0] - - if not isinstance(node.op, Elemwise): - return - if not x.owner or not isinstance(x.owner.op, Elemwise): - return - - prev_op = x.owner.op.scalar_op - node_op = node.op.scalar_op - - for inv_pair in inv_pairs: - if is_inverse_pair(node_op, prev_op, inv_pair): - # We don't need to copy stack trace, because the optimization - # is trivial and maintains the earlier stack trace - ottype = node.out.dtype - inp = x.owner.inputs[0] - # Functions may have casted integer input to float - if inp.dtype != ottype: - inp = cast(inp, ottype) - return [inp] - - return - - -@register_canonicalize -@register_specialize -@local_optimizer([Elemwise]) -def local_exp_log(fgraph, node): - x = node.inputs[0] - - if not isinstance(node.op, Elemwise): - return - if not x.owner or not isinstance(x.owner.op, Elemwise): - return - - prev_op = x.owner.op.scalar_op - node_op = node.op.scalar_op - - # Case for log(exp(x)) -> x - if isinstance(prev_op, aes.Exp) and isinstance(node_op, aes.Log): - new_out = x.owner.inputs[0] - old_out = node.outputs[0] - # Exp may have cast integer input to float - if new_out.dtype != old_out.dtype: - new_out = cast(new_out, old_out.dtype) - return [new_out] - - # Case for log1p(expm1(x)) -> x - if isinstance(prev_op, aes.Expm1) and isinstance(node_op, aes.Log1p): - new_out = x.owner.inputs[0] - old_out = node.outputs[0] - # Expm1 may have cast integer input to float - if new_out.dtype != old_out.dtype: - new_out = cast(new_out, old_out.dtype) - return [new_out] - - # Case for exp(softplus(x)) aka exp(log1pexp) -> 1 + exp(x) - if isinstance(prev_op, aes_math.Softplus) and isinstance(node_op, aes.Exp): - x = x.owner.inputs[0] - return [add(1, exp(x))] - - # Case for expm1(softplus(x)) aka expm1(log1pexp) -> exp(x) - if isinstance(prev_op, aes_math.Softplus) and isinstance(node_op, aes.Expm1): - x = x.owner.inputs[0] - return [exp(x)] - - -@register_specialize -@local_optimizer([Elemwise]) -def local_exp_log_nan_switch(fgraph, node): - # Rewrites of the kind exp(log...(x)) that require a `nan` switch - x = node.inputs[0] - - if not isinstance(node.op, Elemwise): - return - if not x.owner or not isinstance(x.owner.op, Elemwise): - return - - prev_op = x.owner.op.scalar_op - node_op = node.op.scalar_op - - # Case for exp(log(x)) -> x - if isinstance(prev_op, aes.Log) and isinstance(node_op, aes.Exp): - x = x.owner.inputs[0] - old_out = node.outputs[0] - new_out = switch(ge(x, 0), x, np.asarray(np.nan, old_out.dtype)) - return [new_out] - - # Case for exp(log1p(x)) -> x + 1 - if isinstance(prev_op, aes.Log1p) and isinstance(node_op, aes.Exp): - x = x.owner.inputs[0] - old_out = node.outputs[0] - new_out = switch(ge(x, -1), add(1, x), np.asarray(np.nan, old_out.dtype)) - return [new_out] - - # Case for expm1(log(x)) -> x - 1 - if isinstance(prev_op, aes.Log) and isinstance(node_op, aes.Expm1): - x = x.owner.inputs[0] - old_out = node.outputs[0] - new_out = switch(ge(x, 0), sub(x, 1), np.asarray(np.nan, old_out.dtype)) - return [new_out] - - # Case for expm1(log1p(x)) -> x - if isinstance(prev_op, aes.Log1p) and isinstance(node_op, aes.Expm1): - x = x.owner.inputs[0] - old_out = node.outputs[0] - new_out = switch(ge(x, -1), x, np.asarray(np.nan, old_out.dtype)) - return [new_out] - - # Case for exp(log1mexp(x)) -> 1 - exp(x) - if isinstance(prev_op, aes_math.Log1mexp) and isinstance(node_op, aes.Exp): - x = x.owner.inputs[0] - old_out = node.outputs[0] - new_out = switch(le(x, 0), sub(1, exp(x)), np.asarray(np.nan, old_out.dtype)) - return [new_out] - - # Case for expm1(log1mexp(x)) -> -exp(x) - if isinstance(prev_op, aes_math.Log1mexp) and isinstance(node_op, aes.Expm1): - x = x.owner.inputs[0] - old_out = node.outputs[0] - new_out = switch(le(x, 0), neg(exp(x)), np.asarray(np.nan, old_out.dtype)) - return [new_out] - - -@register_canonicalize -@register_specialize -@local_optimizer([Sum]) -def local_sumsqr2dot(fgraph, node): - """ - This optimization detects - ``at.sqr(W.dimshuffle("x", 0, 1) * G.dimshuffle(0, "x", 1) ).sum(axis=(1, 2))`` - and converts it to ``at.dot(at.sqr(G), at.sqr(W).sum(axis=0))``. - """ - if ( - isinstance(node.op, Sum) - and isinstance(node.op.scalar_op, aes.Add) - and node.op.axis == (1, 2) - ): - in1 = node.inputs[0] - out = node.outputs[0] - - if ( - in1.owner - and isinstance(in1.owner.op, Elemwise) - and isinstance(in1.owner.op.scalar_op, aes.Sqr) - ): - in_sqr = in1.owner.inputs[0] - if ( - in_sqr.owner - and isinstance(in_sqr.owner.op, Elemwise) - and isinstance(in_sqr.owner.op.scalar_op, aes.Mul) - and len(in_sqr.owner.inputs) == 2 - ): - in_mul1, in_mul2 = in_sqr.owner.inputs - - if ( - isinstance(in_mul1.owner.op, DimShuffle) - and in_mul1.owner.op.new_order == ("x", 0, 1) - and isinstance(in_mul2.owner.op, DimShuffle) - and in_mul2.owner.op.new_order == (0, "x", 1) - ): - W = in_mul1.owner.inputs[0] - G = in_mul2.owner.inputs[0] - - new_out = dot(sqr(G), sqr(W).sum(axis=0)) - if new_out.dtype != out.dtype: - new_out = cast(new_out, dtype=out.dtype) - return [new_out] - - -@register_stabilize -@register_specialize -@register_canonicalize -@local_optimizer([Elemwise]) -def local_expm1(fgraph, node): - """ - This optimization detects exp(a)-1 and converts this to expm1(a). - """ - if isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, aes.Sub): - in1, in2 = node.inputs - out = node.outputs[0] - - if ( - in1.owner - and isinstance(in1.owner.op, Elemwise) - and isinstance(in1.owner.op.scalar_op, aes.Exp) - and extract_constant(in2, only_process_constants=False) == 1 - ): - in11 = in1.owner.inputs[0] - new_out = expm1(in11) - - if new_out.dtype != out.dtype: - new_out = cast(new_out, dtype=out.dtype) - - if not out.type.is_super(new_out.type): - return - return [new_out] - - -@register_specialize -@register_canonicalize -@local_optimizer([mul]) -def local_mul_switch_sink(fgraph, node): - """ - This optimization makes the following changes in the graph: - ``at.mul(A, at.switch(cond, 0, iff), B)`` -> ``at.switch(cond, 0, at.mul(A, B, iff))`` - ``at.mul(A, at.switch(cond, ift, 0), B)`` -> ``at.switch(cond, at.mul(A, B, ift), 0)`` - ``A`` and ``B`` being several (or none) symbolic variables. - This is useful because ``A`` and ``B`` may not be numerically stable and give - NaN or inf values for cases where the switch returns 0. - With this optimization ``at.grad(at.switch(...))`` has the right behavior. - - Examples - -------- - - x -> f(x) - x -> g(x) - y = at.switch(cond, f(x), g(x)) - - without the optimization: - - at.grad(y, x) -> grad(f(x), x) * grad(y, f(x)) + grad(g(x), x) * grad(y, g(x)) - - with the optimization - - at.grad(y, x) -> switch(cond, grad(f(x), x), 0) + switch(cond, 0, grad(g(x), x)) - - This will be particularly useful for the lazy ``if`` because we skip an entire - part of the graph. - - """ - if node.op != mul: - return False - for idx, i in enumerate(node.inputs): - if i.owner and i.owner.op == switch: - switch_node = i.owner - try: - if ( - get_scalar_constant_value( - switch_node.inputs[1], only_process_constants=True - ) - == 0.0 - ): - listmul = node.inputs[:idx] + node.inputs[idx + 1 :] - fmul = mul(*(listmul + [switch_node.inputs[2]])) - - # Copy over stacktrace for elementwise multiplication op - # from previous elementwise multiplication op. - # An error in the multiplication (e.g. errors due to - # inconsistent shapes), will point to the - # multiplication op. - copy_stack_trace(node.outputs, fmul) - - fct = [switch(switch_node.inputs[0], 0, fmul)] - fct[0].tag.values_eq_approx = values_eq_approx_remove_nan - - # Copy over stacktrace for switch op from both previous - # elementwise multiplication op and previous switch op, - # because an error in this part can be caused by either - # of the two previous ops. - copy_stack_trace(node.outputs + switch_node.outputs, fct) - return fct - except NotScalarConstantError: - pass - try: - if ( - get_scalar_constant_value( - switch_node.inputs[2], only_process_constants=True - ) - == 0.0 - ): - listmul = node.inputs[:idx] + node.inputs[idx + 1 :] - fmul = mul(*(listmul + [switch_node.inputs[1]])) - # Copy over stacktrace for elementwise multiplication op - # from previous elementwise multiplication op. - # An error in the multiplication (e.g. errors due to - # inconsistent shapes), will point to the - # multiplication op. - copy_stack_trace(node.outputs, fmul) - - fct = [switch(switch_node.inputs[0], fmul, 0)] - fct[0].tag.values_eq_approx = values_eq_approx_remove_nan - - # Copy over stacktrace for switch op from both previous - # elementwise multiplication op and previous switch op, - # because an error in this part can be caused by either - # of the two previous ops. - copy_stack_trace(node.outputs + switch_node.outputs, fct) - return fct - except NotScalarConstantError: - pass - return False - - -@register_canonicalize -@local_optimizer([true_div, int_div]) -def local_div_switch_sink(fgraph, node): - """ - This optimization makes the following changes in the graph: - - ``at.div(at.switch(cond, 0, iff), A)`` -> ``at.switch(cond, 0, at.div(iff, A))`` - ``at.div(at.switch(cond, ift, 0), A)`` -> ``at.switch(cond, at.div(ift, A), 0)`` - - where ``A`` is a symbolic variable. - - This is useful because ``A`` may not be numerically stable and give - ``nan`` or ``inf`` values for cases where the switch returns 0. - - See `local_mul_switch_sink` for more details. - - """ - if node.op != true_div and node.op != int_div: - return False - op = node.op - if node.inputs[0].owner and node.inputs[0].owner.op == switch: - switch_node = node.inputs[0].owner - try: - if ( - get_scalar_constant_value( - switch_node.inputs[1], only_process_constants=True - ) - == 0.0 - ): - fdiv = op(switch_node.inputs[2], node.inputs[1]) - # Copy over stacktrace for elementwise division op - # from previous elementwise multiplication op. - # An error in the division (e.g. errors due to - # inconsistent shapes or division by zero), - # will point to the new division op. - copy_stack_trace(node.outputs, fdiv) - - fct = [switch(switch_node.inputs[0], 0, fdiv)] - fct[0].tag.values_eq_approx = values_eq_approx_remove_nan - - # Copy over stacktrace for switch op from both previous - # elementwise division op and previous switch op, - # because an error in this part can be caused by either - # of the two previous ops. - copy_stack_trace(node.outputs + switch_node.outputs, fct) - return fct - except NotScalarConstantError: - pass - try: - if ( - get_scalar_constant_value( - switch_node.inputs[2], only_process_constants=True - ) - == 0.0 - ): - fdiv = op(switch_node.inputs[1], node.inputs[1]) - # Copy over stacktrace for elementwise division op - # from previous elementwise multiplication op. - # An error in the division (e.g. errors due to - # inconsistent shapes or division by zero), - # will point to the new division op. - copy_stack_trace(node.outputs, fdiv) - - fct = [switch(switch_node.inputs[0], fdiv, 0)] - fct[0].tag.values_eq_approx = values_eq_approx_remove_nan - - # Copy over stacktrace for switch op from both previous - # elementwise division op and previous switch op, - # because an error in this part can be caused by either - # of the two previous ops. - copy_stack_trace(node.outputs + switch_node.outputs, fct) - return fct - except NotScalarConstantError: - pass - return False - - -class AlgebraicCanonizer(LocalOptimizer): - r"""Simplification tool. - - The variable is a ``local_optimizer``. It is best used - with a ``TopoOptimizer`` in ``in_to_out`` order. - - Usage: ``AlgebraicCanonizer(main, inverse, reciprocal, calculate)`` - - Parameters - ---------- - main - A suitable ``Op`` class that is commutative, associative and - takes one to an arbitrary number of inputs, e.g. add or - mul - inverse - An ``Op`` class such that ``inverse(main(x, y), y) == x`` - e.g. ``sub`` or true_div - reciprocal - A function such that ``main(x, reciprocal(y)) == inverse(x, y)`` - e.g. ``neg`` or ``reciprocal`` - calculate - Function that takes a list of numpy.ndarray instances - for the numerator, another list for the denumerator, - and calculates ``inverse(main(\*num), main(\*denum))``. It - takes a keyword argument, aslist. If True, the value - should be returned as a list of one element, unless - the value is such that value = main(). In that case, - the return value should be an empty list. - - Examples - -------- - >>> import aesara.tensor as at - >>> from aesara.tensor.math_opt import AlgebraicCanonizer - >>> add_canonizer = AlgebraicCanonizer(add, sub, neg, \\ - ... lambda n, d: sum(n) - sum(d)) - >>> mul_canonizer = AlgebraicCanonizer(mul, true_div, inv, \\ - ... lambda n, d: prod(n) / prod(d)) - - Examples of optimizations ``mul_canonizer`` can perform: - - | x / x -> 1 - | (x * y) / x -> y - | x / y / x -> 1 / y - | x / y / z -> x / (y * z) - | x / (y / z) -> (x * z) / y - | (a / b) * (b / c) * (c / d) -> a / d - | (2.0 * x) / (4.0 * y) -> (0.5 * x) / y - | 2 * x / 2 -> x - | x * y * z -> Elemwise(mul){x,y,z} #only one pass over the memory. - | !-> Elemwise(mul){x,Elemwise(mul){y,z}} - - """ - - def __init__(self, main, inverse_fn, reciprocal_fn, calculate, use_reciprocal=True): - self.main = main - self.inverse = inverse_fn - self.reciprocal = reciprocal_fn - self.calculate = calculate - self.use_reciprocal = use_reciprocal - - self.external_simplifiers = [] - - def add_simplifier(self, simplifier, reason): - self.external_simplifiers.append((reason, simplifier)) - - def tracks(self): - return [self.main, self.inverse, self.reciprocal] - - def get_num_denum(self, inp): - r""" - This extract two lists, ``num`` and ``denum``, such that the input is: - ``self.inverse(self.main(\*num), self.main(\*denum))``. It returns - the two lists in a ``(num, denum)`` pair. - - For example, for main, inverse and ``reciprocal = \*, / and inv()``, - - | input -> returned value (num, denum) - - | x*y -> ([x, y], []) - | inv(x) -> ([], [x]) - | inv(x) * inv(y) -> ([], [x, y]) - | x*y/z -> ([x, y], [z]) - | log(x) / y * (z + x) / y -> ([log(x), z + x], [y, y]) - | (((a / b) * c) / d) -> ([a, c], [b, d]) - | a / (b / c) -> ([a, c], [b]) - | log(x) -> ([log(x)], []) - | x**y -> ([x**y], []) - | x * y * z -> ([x, y, z], []) - - """ - # This function is recursive. The idea is that there is a - # get_num_denum recursion in which the internal ops are all - # one of (main, inverse, reciprocal, DimShuffle) and the - # internal data nodes all have the dtype of the 'input' - # argument. The leaf-Variables of the graph covered by the - # recursion may be of any Variable type. - - if inp.owner is None or inp.owner.op not in [ - self.main, - self.inverse, - self.reciprocal, - ]: - if inp.owner and isinstance(inp.owner.op, DimShuffle): - # If input is a DimShuffle of some input which does - # something like this: - - # * change a vector of length N into a 1xN row matrix - # * change a scalar into a 1x1x1 tensor - # * in general, complete the shape of a tensor - # with broadcastable 1s to the *left* - # Then we will simply discard the DimShuffle and return - # the num/denum of its input - dsn = inp.owner # dimshuffle node - dsop = dsn.op # dimshuffle op - - # the first input of the dimshuffle i.e. the ndarray to redim - dsi0 = dsn.inputs[0] - - # The compatible order is a DimShuffle "new_order" of the form: - # ('x', ..., 'x', 0, 1, 2, ..., dimshuffle_input.type.ndim) - - # That kind of DimShuffle only adds broadcastable - # dimensions on the left, without discarding any - # existing broadcastable dimension and is inserted - # automatically by Elemwise when the inputs have - # different numbers of dimensions (hence why we can - # discard its information - we know we can retrieve it - # later on). - compatible_order = ("x",) * (inp.type.ndim - dsi0.type.ndim) + tuple( - range(dsi0.type.ndim) - ) - if dsop.new_order == compatible_order: - # If the "new_order" is the one we recognize, - # we return the num_denum of the dimshuffled input. - return self.get_num_denum(inp.owner.inputs[0]) - else: - # This is when the input isn't produced by main, - # inverse or reciprocal. - return [inp], [] - else: - return [inp], [] - num = [] - denum = [] - parent = inp.owner - - # We get the (num, denum) pairs for each input - # pairs = [self.get_num_denum(input2) if input2.type.dtype == - # input.type.dtype else ([input2], []) for input2 in - # parent.inputs] - pairs = [self.get_num_denum(input2) for input2 in parent.inputs] - - if parent.op == self.main: - # If we have main(x, y, ...), numx, denumx, numy, denumy, ... - # then num is concat(numx, numy, num...) and denum is - # concat(denumx, denumy, denum...) note that main() can have any - # number of arguments >= 0 concat is list concatenation - num = reduce(list.__iadd__, map(operator.itemgetter(0), pairs)) - denum = reduce(list.__iadd__, map(operator.itemgetter(1), pairs)) - elif parent.op == self.inverse: - # If we have inverse(x, y), numx, denumx, numy and denumy - # then num is concat(numx, denumy) and denum is - # concat(denumx, numy) note that inverse() is binary - num = pairs[0][0] + pairs[1][1] - denum = pairs[0][1] + pairs[1][0] - elif parent.op == self.reciprocal: - # If we have reciprocal(x), numx, denumx - # then num is denumx and denum is numx - # note that reciprocal() is unary - num = pairs[0][1] - denum = pairs[0][0] - return num, denum - - def merge_num_denum(self, num, denum): - r""" - Utility function which takes two lists, num and denum, and - returns something which is equivalent to inverse(main(\*num), - main(\*denum)), but depends on the length of num and the length - of denum (in order to minimize the number of operations). - - Let n = len(num) and d = len(denum): - - | n=0, d=0: neutral element (given by self.calculate([], [])) - | (for example, this would be 0 if main is addition - | and 1 if main is multiplication) - | n=1, d=0: num[0] - | n=0, d=1: reciprocal(denum[0]) - | n=1, d=1: inverse(num[0], denum[0]) - | n=0, d>1: reciprocal(main(\*denum)) - | n>1, d=0: main(\*num) - | n=1, d>1: inverse(num[0], main(\*denum)) - | n>1, d=1: inverse(main(\*num), denum[0]) - | n>1, d>1: inverse(main(\*num), main(\*denum)) - - Given the values of n and d to which they are associated, all - of the above are equivalent to: - inverse(main(\*num), main(\*denum)) - - """ - - ln, ld = len(num), len(denum) - if not ln and not ld: - return as_tensor_variable(self.calculate([], [])) - if not ln: - if self.use_reciprocal: - return self.reciprocal(self.merge_num_denum(denum, [])) - else: - ln = [self.calculate([], [], aslist=False)] - if not ld: - if ln == 1: - # num[0] should always be a variable - assert isinstance(num[0], Variable) - return num[0] - else: - return self.main(*num) - return self.inverse( - self.merge_num_denum(num, []), self.merge_num_denum(denum, []) - ) - - def simplify(self, num, denum, out_type): - """ - Shorthand for: - - .. code-block:: python - - self.simplify_constants(*self.simplify_factors(num, denum)) - - """ - rval = self.simplify_constants( - *self.simplify_factors(num, denum), out_type=out_type - ) - for reason, simplifier in self.external_simplifiers: - # TODO: document that 'reason' is associated with this - # simplification to help auditing when things go - # wrong - rval = simplifier(*rval) - return rval - - def simplify_factors(self, num, denum): - """ - For any Variable r which is both in num and denum, removes it - from both lists. Modifies the lists inplace. Returns the - modified lists. For example: - - | [x], [x] -> [], [] - | [x, y], [x] -> [y], [] - | [a, b], [c, d] -> [a, b], [c, d] - - """ - ln = len(num) - ld = len(denum) - if ld > 2 and ln > 2: - # Faster version for "big" inputs. - while True: - s = set(num) - # Inputs can appear multiple times - redo = len(s) != len(num) - inter = s.intersection(denum) - for v in inter: - num.remove(v) - denum.remove(v) - if not redo or not inter: - break - else: - for v in list(num): - if v in denum: - num.remove(v) - denum.remove(v) - return num, denum - - def simplify_constants(self, orig_num, orig_denum, out_type=None): - """ - Find all constants and put them together into a single constant. - - Finds all constants in orig_num and orig_denum (using - get_constant) and puts them together into a single - constant. The constant is inserted as the first element of the - numerator. If the constant is the neutral element, it is - removed from the numerator. - - Examples - -------- - Let main be multiplication: - - | [2, 3, x], [] -> [6, x], [] - | [x, y, 2], [4, z] -> [0.5, x, y], [z] - | [x, 2, y], [z, 2] -> [x, y], [z] - - """ - # Lists representing the numerator and denumerator - num, denum = [], [] - - # Lists representing the *constant* elements of num and denum - numct, denumct = [], [] - - for v in orig_num: - ct = get_constant(v) - if ct is not None: - # We found a constant in the numerator! - # We add it to numct - numct.append(ct) - else: - num.append(v) - for v in orig_denum: - ct = get_constant(v) - if ct is not None: - denumct.append(ct) - else: - denum.append(v) - - if self.use_reciprocal or num: - # This will calculate either: - # [inverse(main(*numct), main(*denumct))] - # [] - if inverse(main(*numct), main(*denumct)) is the - # neutral element - ct = self.calculate(numct, denumct, aslist=True, out_type=out_type) - else: - # This happens if we don't allow the reciprocal and the - # numerator is empty. That means we will need to represent - # reciprocal(x) like inverse(neutral_element, x) so - # we can't allow ct == [] - # TODO: why is this branch needed when merge_num_denum - # does it for us? - ct = [self.calculate(numct, denumct, aslist=False, out_type=out_type)] - - # Wrapping ct in a Constant with the right dtype - ct = [constant(c, dtype=out_type.dtype) for c in ct] - - if orig_num and len(numct) == 1 and len(denumct) == 0 and ct: - # In that case we should only have one constant in `ct`. - assert len(ct) == 1 - first_num_ct = get_constant(orig_num[0]) - if first_num_ct is not None and ct[0].type.values_eq( - ct[0].data, first_num_ct - ): - # This is an important trick :( if it so happens that: - # * there's exactly one constant on the numerator and none on - # the denominator - # * it's not the neutral element (ct is an empty list in that - # case) - # * the constant is the same as the first argument in the - # numerator (we only check the first argument because the - # canonizer puts the computed constants first) - # -> then we return very exactly the original num/denum. - # If we don't do that the optimizer will just loop - # infinitely because it will not catch on that there are - # no changes to be made and every time it will want to - # replace something by the same thing... - # Note that it is important to use `values_eq` instead of - # the == operator, to handle NaN values correctly. - return orig_num, orig_denum - - return ct + num, denum - - def transform(self, fgraph, node): - op = node.op - if op not in [self.main, self.inverse, self.reciprocal]: - return False - - assert len(node.outputs) == 1 - out = node.outputs[0] - - out_clients = fgraph.clients.get(out) - - if not out_clients: - return False - - # check if any of the clients of this node would be part of - # this canonized graph... if so, we do nothing and wait for - # them to be transformed. - for c, c_idx in out_clients: - if c == "output": - continue - while ( - isinstance(getattr(c, "op", None), DimShuffle) - and len(fgraph.clients[c.outputs[0]]) <= 1 - ): - c = fgraph.clients[c.outputs[0]][0][0] - if getattr(c, "op", "") in [self.main, self.inverse, self.reciprocal]: - return False - - # Here we make the canonical version of the graph around this node - # See the documentation of get_num_denum and simplify - orig_num, orig_denum = self.get_num_denum(node.outputs[0]) - num, denum = self.simplify(list(orig_num), list(orig_denum), out.type) - - def same(x, y): - return len(x) == len(y) and all(np.all(xe == ye) for xe, ye in zip(x, y)) - - if ( - same(orig_num, num) - and same(orig_denum, denum) - and - # Check to see if we've collapsed some nested ops. - not ( - len(orig_denum) == 0 - and - # Make sure this change would increase the number of vector - # arguments--decreasing the number of unnecessary `self.main` - # nodes. - len(node.inputs) < len(orig_num) - ) - and - # Do a similar check for the reciprocal op. - not ( - self.use_reciprocal - and node.op == self.reciprocal - and len(orig_num) == 0 - and node.inputs[0].owner - and len(node.inputs[0].owner.inputs) < len(orig_denum) - ) - ): - return False - - new = self.merge_num_denum(num, denum) - if new.type.dtype != out.type.dtype: - new = cast(new, out.type.dtype) - - if new.type != out.type: - new = fill_chain(new, node.inputs)[0] - - if new.type == out.type: - new.tag.values_eq_approx = values_eq_approx_remove_inf_nan - copy_stack_trace(out, new) - return [new] - else: - return False - - def __str__(self): - return getattr( - self, - "name", - f"AlgebraicCanonizer({self.main}, {self.inverse}, {self.reciprocal})", - ) - - -def mul_calculate(num, denum, aslist=False, out_type=None): - if not num and not denum: - # Smallest 1 possible. - if aslist: - return [] - else: - return np.int8(1) - - # Make sure we do not accidentally upcast data types. - if out_type is None: - out_dtype = aes.upcast(*[v.dtype for v in (num + denum)]) - else: - out_dtype = out_type.dtype - one = _asarray(1, dtype=out_dtype) - - v = reduce(np.multiply, num, one) / reduce(np.multiply, denum, one) - if aslist: - if np.all(v == 1): - return [] - else: - return [v] - return v - - -local_mul_canonizer = AlgebraicCanonizer( - mul, true_div, reciprocal, mul_calculate, False -) -register_canonicalize(local_mul_canonizer, name="local_mul_canonizer") - - -@register_canonicalize -@local_optimizer([neg]) -def local_neg_to_mul(fgraph, node): - if node.op == neg: - return [mul(np.array(-1, dtype=node.inputs[0].dtype), node.inputs[0])] - - -@register_specialize -@local_optimizer([Sum, Prod]) -def local_sum_prod_mul_by_scalar(fgraph, node): - """ - sum(scalar * smth) -> scalar * sum(smth) - sum(-smth) -> -sum(smth) - - or - - prod(scalar * smth) -> scalar ** size(smth) * prod(smth) - prod(-smth) -> -1 ** size(smth) * prod(smth) - - """ - # TODO: if the the thing inside the Sum is a division, - # we should get at the numerator.... - if isinstance(node.op, (Sum, Prod)): - (node_inps,) = node.inputs - if node_inps.owner and node_inps.owner.op == mul: - terms = node_inps.owner.inputs - scalars = [t.dimshuffle() for t in terms if all(t.type.broadcastable)] - - if len(scalars) == 0: - # Nothing to optimize here - return - - non_scalars = [t for t in terms if not all(t.broadcastable)] - - # Perform the op only on the non-scalar inputs, if applicable - if len(non_scalars) == 0: - new_op_input_nb_elements = 1 - new_op_output = 1 - elif len(non_scalars) == 1: - new_op_input_nb_elements = non_scalars[0].size - new_op_output = node.op(non_scalars[0]) - else: - new_op_input = mul(*non_scalars) - # We assume that errors always come from the prod/mul op in the - # original computational graph, and therefore need to only - # copy over its output stacktrace. - copy_stack_trace(node.outputs, new_op_input) - - new_op_input_nb_elements = new_op_input.size - new_op_output = node.op(new_op_input) - - if len(non_scalars) != 0: - # Copy over stacktrace from previous output to new mul op, - # for same reason as above. - copy_stack_trace(node.outputs, new_op_output) - - # If `node.op` is a `Prod`, then the scalars need to be raised to - # the power of the number of elements in the input to the `Prod` - if isinstance(node.op, Prod) and new_op_input_nb_elements != 1: - - scalars = [s**new_op_input_nb_elements for s in scalars] - - # Scale the output of the op by the scalars and return as - # replacement for the original output - mul_inputs = scalars - if new_op_input_nb_elements != 1: - mul_inputs.append(new_op_output) - - if len(mul_inputs) == 1: - # Copy over stacktrace from previous output to new mul op, - # for same reason as above. - copy_stack_trace(node.outputs, mul_inputs) - - return mul_inputs - else: - ret = mul(*mul_inputs) - # Copy over stacktrace from previous output to new mul op, - # for same reason as above. - copy_stack_trace(node.outputs, [ret] + mul_inputs) - - return [ret] - - if isinstance(node.op, Sum) and node_inps.owner and node_inps.owner.op == neg: - s = node.op(node_inps.owner.inputs[0]) - ret = neg(s) - # There are never errors in the negative op, thus - # we need only to copy over stacktrace from previous output node to - # the two new ops. - copy_stack_trace(node.outputs, [s, ret]) - - return [ret] - - -@register_specialize -@local_optimizer([Elemwise]) -def local_elemwise_sub_zeros(fgraph, node): - """ - Elemwise{sub}(X,X) -> zeros_like(X) - """ - if ( - isinstance(node.op, Elemwise) - and node.op.scalar_op.nin == 2 - and node.op.scalar_op == aes.sub - and node.inputs[0] == node.inputs[1] - ): - res = zeros_like(node.inputs[0]) - # Copy over stacktrace from previous output. - # This could help for failures due to out-of-memory. - copy_stack_trace(node.outputs, res) - return [res] - - -@register_useless -@register_specialize -@register_stabilize -@register_canonicalize -@local_optimizer([Elemwise]) -def local_useless_elemwise_comparison(fgraph, node): - """... - - :note: These cases appear in the graph generated by scan. - These optimizations will make the graph easier to read. - # Comparing to itself is constant - Elemwise[{LT,GT}](X, X) -> Elemwise[zeros](X) - Elemwise[{LE,GE}](X, X) -> Elemwise[ones](X) - Elemwise[{minimum,maximum}](X, X) -> X - - # Comparing shape to 0 can be constant - Elemwise[LT](X.shape[i], 0) -> Elemwise[zeros](X) - Elemwise[GE](X.shape[i], 0) -> Elemwise[ones](X) - Elemwise[maximum](X.shape[i], 0) -> X.shape[i] - Elemwise[maximum](0, X.shape[i]) -> X.shape[i] - Elemwise[minimum](X.shape[i], 0) -> 0 - Elemwise[minimum](0, X.shape[i]) -> 0 - - # The shape can be replaced with sum of shapes - Elemwise[LT](add([anything that is shapes]), 0) -> Elemwise[zeros](X) - Elemwise[GE](add([anything that is shapes]), 0) -> Elemwise[ones](X) - - # Shapes are never negative - # Needed by Reshape.infer_shape - Elemwise[EQ](Subtensor(Shape(x)), -N) -> Elemwise[zeros](X) - - """ - if not isinstance(node.op, Elemwise): - return - if node.op.scalar_op.nin != 2: - return - - # We call zeros_like and one_like with opt=True to generate a - # cleaner graph. - dtype = node.outputs[0].dtype - - # Elemwise[{LT,GT}](X, X) -> Elemwise[zeros](X) - if ( - isinstance(node.op.scalar_op, (aes.LT, aes.GT)) - and node.inputs[0] is node.inputs[1] - ): - res = zeros_like(node.inputs[0], dtype=dtype, opt=True) - # Copy over stacktrace from previous output. - copy_stack_trace(node.outputs, res) - return [res] - # Elemwise[{LE,GE}](X, X) -> Elemwise[ones](X) - if ( - isinstance(node.op.scalar_op, (aes.LE, aes.GE)) - and node.inputs[0] is node.inputs[1] - ): - res = ones_like(node.inputs[0], dtype=dtype, opt=True) - - # Copy over stacktrace from previous output. - copy_stack_trace(node.outputs, res) - return [res] - # Elemwise[{minimum,maximum}](X, X) -> X - if ( - isinstance(node.op.scalar_op, (aes.ScalarMinimum, aes.ScalarMaximum)) - and node.inputs[0] is node.inputs[1] - ): - res = node.inputs[0] - # Copy over stacktrace from previous output. - copy_stack_trace(node.outputs, res) - return [res] - - # Elemwise[LT](X.shape[i], 0) -> Elemwise[zeros](X) - if ( - isinstance(node.op.scalar_op, aes.LT) - and node.inputs[0].owner - and isinstance(node.inputs[0].owner.op, Shape_i) - and extract_constant(node.inputs[1], only_process_constants=True) == 0 - ): - res = zeros_like(node.inputs[0], dtype=dtype, opt=True) - # Copy over stacktrace from previous output. - copy_stack_trace(node.outputs, res) - return [res] - # Elemwise[GE](X.shape[i], 0) -> Elemwise[ones](X) - if ( - isinstance(node.op.scalar_op, aes.GE) - and node.inputs[0].owner - and isinstance(node.inputs[0].owner.op, Shape_i) - and extract_constant(node.inputs[1], only_process_constants=True) == 0 - ): - res = ones_like(node.inputs[0], dtype=dtype, opt=True) - # Copy over stacktrace from previous output. - copy_stack_trace(node.outputs, res) - return [res] - # Elemwise[maximum](X.shape[i], 0) -> X.shape[i] - if ( - isinstance(node.op.scalar_op, aes.ScalarMaximum) - and node.inputs[0].owner - and isinstance(node.inputs[0].owner.op, Shape_i) - and extract_constant(node.inputs[1], only_process_constants=True) == 0 - ): - # No need to copy over stacktrace. - return [node.inputs[0]] - # Elemwise[maximum](0, X.shape[i]) -> X.shape[i] - if ( - isinstance(node.op.scalar_op, aes.ScalarMaximum) - and extract_constant(node.inputs[0], only_process_constants=True) == 0 - and node.inputs[1].owner - and isinstance(node.inputs[1].owner.op, Shape_i) - ): - # No need to copy over stacktrace. - return [node.inputs[1]] - # Elemwise[minimum](X.shape[i], 0) -> 0 - if ( - isinstance(node.op.scalar_op, aes.ScalarMinimum) - and node.inputs[0].owner - and isinstance(node.inputs[0].owner.op, Shape_i) - and extract_constant(node.inputs[1], only_process_constants=True) == 0 - ): - res = zeros_like(node.inputs[0], dtype=dtype, opt=True) - # Copy over stacktrace from previous output. - copy_stack_trace(node.outputs, res) - return [res] - - # Elemwise[minimum](0, X.shape[i]) -> 0 - if ( - isinstance(node.op.scalar_op, aes.ScalarMinimum) - and extract_constant(node.inputs[0], only_process_constants=True) == 0 - and node.inputs[1].owner - and isinstance(node.inputs[1].owner.op, Shape_i) - ): - res = zeros_like(node.inputs[1], dtype=dtype, opt=True) - # Copy over stacktrace from previous output. - copy_stack_trace(node.outputs, res) - return [res] - - # Elemwise[LT](add([anything that is shapes]), 0) -> Elemwise[zeros](X) - if ( - isinstance(node.op.scalar_op, aes.LT) - and node.inputs[0].owner - and isinstance(node.inputs[0].owner.op, Elemwise) - and isinstance(node.inputs[0].owner.op.scalar_op, aes.Add) - and all( - isinstance(var.owner and var.owner.op, Shape_i) - for var in node.inputs[0].owner.inputs - ) - and extract_constant(node.inputs[1], only_process_constants=True) == 0 - ): - res = zeros_like(node.inputs[0], dtype=dtype, opt=True) - # Copy over stacktrace from previous output. - copy_stack_trace(node.outputs, res) - return [res] - # Elemwise[GE](add([anything that is shapes]), 0) -> Elemwise[ones](X) - if ( - isinstance(node.op.scalar_op, aes.GE) - and node.inputs[0].owner - and isinstance(node.inputs[0].owner.op, Elemwise) - and isinstance(node.inputs[0].owner.op.scalar_op, aes.Add) - and all( - isinstance(var.owner and var.owner.op, Shape_i) - for var in node.inputs[0].owner.inputs - ) - and extract_constant(node.inputs[1], only_process_constants=True) == 0 - ): - res = ones_like(node.inputs[0], dtype=dtype, opt=True) - - # Copy over stacktrace from previous output. - copy_stack_trace(node.outputs, res) - return [res] - - # Elemwise[EQ](Subtensor(Shape(x)), -N) - # Elemwise[EQ](somegraph that only depend of shape, -N) - # TODO: handle the case where the -N is on either side - """ - |Elemwise{eq,no_inplace} [id B] '' - | |Subtensor{int64} [id C] '' - | | |Join [id D] '' - | | | |TensorConstant{0} [id E] - | | | |Subtensor{int64:int64:} [id F] '' - | | | | |Shape [id G] '' - """ - - def investigate(node): - "Return True if values will be shapes, so >= 0" - if isinstance(node.op, (Shape, Shape_i)): - return True - elif isinstance(node.op, Subtensor) and node.inputs[0].owner: - return investigate(node.inputs[0].owner) - elif isinstance(node.op, Join): - return all(v.owner and investigate(v.owner) for v in node.inputs[1:]) - elif isinstance(node.op, MakeVector): - return all(v.owner and investigate(v.owner) for v in node.inputs) - - if ( - isinstance(node.op.scalar_op, aes.EQ) - and node.inputs[0].owner - and investigate(node.inputs[0].owner) - ): - try: - cst = get_scalar_constant_value(node.inputs[1], only_process_constants=True) - - res = zeros_like(node.inputs[0], dtype=dtype, opt=True) - - if cst < 0: - # Copy over stacktrace from previous output. - copy_stack_trace(node.outputs, res) - - return [res] - - except NotScalarConstantError: - pass - return - - -@register_canonicalize -@register_specialize -@local_optimizer([Sum, Prod]) -def local_sum_prod_div_dimshuffle(fgraph, node): - """ - sum(a / dimshuffle{...}(b), axis=l) -> sum(a, axis={...}) / b, - if dimension l of the DimShuffle is 'x' - - or - - prod(a / dimshuffle{...}(b), axis=l) -> - prod(a, axis={...}) / b ** a.shape[l], - if dimension l of the DimShuffle is 'x' - """ - - # It does not make much sense now to extend it to the case where the - # dimshuffle is in the numerator, since elemwise inversion of the - # denominator would still be needed before the summation or production. - - if isinstance(node.op, (Sum, Prod)): - axis = node.op.axis - if axis is None: - axis = list(range(node.inputs[0].ndim)) - node_input = node.inputs[0] - if node_input.owner and node_input.owner.op == true_div: - numerator, denominator = node_input.owner.inputs - - if denominator.owner and isinstance(denominator.owner.op, DimShuffle): - dimshuffle_input = denominator.owner.inputs[0] - dimshuffle_order = denominator.owner.op.new_order - - compatible_dims = [] - incompatible_dims = [] - for ax in axis: - if ax < len(dimshuffle_order) and dimshuffle_order[ax] == "x": - compatible_dims.append(ax) - else: - incompatible_dims.append(ax) - reordered_incompatible_dims = [] - for ic_ax in incompatible_dims: - reordered_incompatible_dims.append( - ic_ax - sum(1 for c_ax in compatible_dims if c_ax < ic_ax) - ) - - if len(compatible_dims) > 0: - optimized_dimshuffle_order = list( - ax - for i, ax in enumerate(dimshuffle_order) - if (i not in axis) or (ax != "x") - ) - - # Removing leading 'x' (since it will be done automatically) - while ( - len(optimized_dimshuffle_order) > 0 - and optimized_dimshuffle_order[0] == "x" - ): - del optimized_dimshuffle_order[0] - - # if optimized_dimshuffle_order is sorted with - # not 'x', then dimshuffle is useless. - if all(i == e for i, e in enumerate(optimized_dimshuffle_order)): - optimized_dimshuffle = dimshuffle_input - else: - optimized_dimshuffle = DimShuffle( - dimshuffle_input.type.broadcastable, - optimized_dimshuffle_order, - )(dimshuffle_input) - - if isinstance(node.op, Sum): - op_on_compatible_dims = at_sum(numerator, axis=compatible_dims) - rval = true_div(op_on_compatible_dims, optimized_dimshuffle) - if len(reordered_incompatible_dims) > 0: - rval = at_sum(rval, axis=reordered_incompatible_dims) - elif isinstance(node.op, Prod): - op_on_compatible_dims = prod(numerator, axis=compatible_dims) - dtype = numerator.dtype - rval = true_div( - op_on_compatible_dims, - ( - optimized_dimshuffle - ** prod( - [ - numerator.shape[ax].astype(dtype) - for ax in compatible_dims - ] - ) - ), - ) - if len(reordered_incompatible_dims) > 0: - rval = prod(rval, axis=reordered_incompatible_dims) - return [rval] - - -@register_canonicalize -@local_optimizer([Sum, Prod]) -def local_sum_prod_all_to_none(fgraph, node): - """ - Sum{0,1,...N} -> Sum{} or - Prod{0,1,...N} -> Prod{} - - """ - if isinstance(node.op, Sum) or isinstance(node.op, Prod): - opt_type = Sum if isinstance(node.op, Sum) else Prod - # if all the axes are named, then use None as a shorthand - # this permits more merging - if node.op.axis is None: - return - if set(node.op.axis) == set(range(node.inputs[0].type.ndim)): - return [opt_type(axis=None, dtype=node.op.dtype)(node.inputs[0])] - - -@register_canonicalize -@local_optimizer([Sum, Prod]) -def local_op_of_op(fgraph, node): - """ - Prod(Prod()) -> single Prod() - or - Sum(Sum()) -> single Sum() - - """ - if isinstance(node.op, Prod) or isinstance(node.op, Sum): - opt_type = Sum if isinstance(node.op, Sum) else Prod - (node_inps,) = node.inputs - out_dtype = node.op.dtype - # We manipulate the graph so this is done to make sure the opt - # doesn't affect other computations. - if len(fgraph.clients[node_inps]) == 1: - if node_inps.owner and (isinstance(node_inps.owner.op, node.op.__class__)): - - # check to see either the inner or outer prod is doing a - # product over all axis, in which case we can remove it - if node_inps.owner.op.axis is None or node.op.axis is None: - return [opt_type(None, dtype=out_dtype)(node_inps.owner.inputs[0])] - - # figure out which axes were in the original sum - newaxis = list(tuple(node_inps.owner.op.axis)) - for i in node.op.axis: - new_i = i - for ii in node_inps.owner.op.axis: - if new_i >= ii: - new_i += 1 - assert new_i not in newaxis - newaxis.append(new_i) - - assert len(newaxis) == len( - list(node_inps.owner.op.axis) + list(node.op.axis) - ) - - combined = opt_type(newaxis, dtype=out_dtype) - return [combined(node_inps.owner.inputs[0])] - - -ALL_REDUCE = ( - [ - CAReduce, - All, - Any, - Sum, - Prod, - ProdWithoutZeros, - ] - + CAReduce.__subclasses__() - + NonZeroCAReduce.__subclasses__() -) - - -@register_canonicalize -@register_uncanonicalize # Needed for MaxAndArgmax -> CAReduce -@local_optimizer(ALL_REDUCE) -def local_reduce_join(fgraph, node): - """ - CAReduce{scalar.op}(Join(axis=0, a, b), axis=0) -> Elemwise{scalar.op}(a, b) - - Notes - ----- - Supported scalar.op are Maximum, Minimum in some cases and Add and Mul in - all cases. - - Currently we must reduce on axis 0. It is probably extensible to the case - where we join and reduce on the same set of axis. - - """ - if ( - isinstance(node.op, CAReduce) - and node.inputs[0].owner - and isinstance(node.inputs[0].owner.op, Join) - ): - join_node = node.inputs[0].owner - if extract_constant(join_node.inputs[0], only_process_constants=True) != 0: - return - - if isinstance(node.op.scalar_op, (aes.ScalarMaximum, aes.ScalarMinimum)): - # Support only 2 inputs for now - if len(join_node.inputs) != 3: - return - elif not isinstance(node.op.scalar_op, (aes.Add, aes.Mul)): - return - elif len(join_node.inputs) <= 2: - # This is a useless join, that will get removed by another opt. - return - - new_inp = [] - for inp in join_node.inputs[1:]: - inp = inp.owner - if not inp: - return - if not isinstance(inp.op, DimShuffle) or inp.op.new_order != ("x",) + tuple( - range(inp.inputs[0].ndim) - ): - return - new_inp.append(inp.inputs[0]) - ret = Elemwise(node.op.scalar_op)(*new_inp) - - if ret.dtype != node.outputs[0].dtype: - # The reduction do something about the dtype. - return - - reduce_axis = node.op.axis - if reduce_axis is None: - reduce_axis = tuple(range(node.inputs[0].ndim)) - - if len(reduce_axis) != 1 or 0 not in reduce_axis: - return - - # We add the new check late to don't add extra warning. - try: - join_axis = get_scalar_constant_value( - join_node.inputs[0], only_process_constants=True - ) - - if join_axis != reduce_axis[0]: - return - except NotScalarConstantError: - return - - return [ret] - - -@register_canonicalize("fast_compile", "local_cut_useless_reduce") -@register_useless("local_cut_useless_reduce") -@local_optimizer(ALL_REDUCE) -def local_useless_reduce(fgraph, node): - """Sum(a, axis=[]) -> a""" - if isinstance(node.op, CAReduce): - (summed,) = node.inputs - # if reduce were doing anything, the output ndim would be reduced - if summed.type == node.outputs[0].type: - return [summed] - - -@register_canonicalize -@register_uncanonicalize -@register_specialize -@local_optimizer(ALL_REDUCE) -def local_reduce_broadcastable(fgraph, node): - """Remove reduction over broadcastable dimensions.""" - if isinstance(node.op, CAReduce): - (reduced,) = node.inputs - odtype = node.outputs[0].dtype - if node.op.axis is None: - if all(reduced.broadcastable): - return [reduced.dimshuffle().astype(odtype)] - else: - axis = list(node.op.axis) - cuttable = [a for a in axis if reduced.broadcastable[a]] - if cuttable: - # -- we can remove some axes of summation. - new_axis = [] - pattern = [] - ii = 0 - for p in range(reduced.ndim): - if p not in cuttable: - if p in axis: - new_axis.append(ii) - pattern.append(p) - ii += 1 - new_reduced = reduced.dimshuffle(*pattern) - if new_axis: - if type(node.op) == CAReduce: - # This case handles `CAReduce` instances - # (e.g. generated by `scalar_elemwise`), and not the - # scalar `Op`-specific subclasses - # TODO FIXME: This highlights a major design flaw in - # `CAReduce` (or at least our use of it), and it needs - # to be fixed - new_op = node.op.__class__(node.op.scalar_op, axis=new_axis) - else: - new_op = node.op.__class__(axis=new_axis) - return [new_op(new_reduced)] - else: - # -- in this case we can remove the reduction completely - return [new_reduced.astype(odtype)] - - -@register_specialize -@local_optimizer([Sum, Prod]) -def local_opt_alloc(fgraph, node): - """ - sum(alloc(constant,shapes...)) => constant*prod(shapes) - or - prod(alloc(constant,shapes...)) => constant**prod(shapes) - - """ - if isinstance(node.op, Sum) or isinstance(node.op, Prod): - (node_inps,) = node.inputs - if node_inps.owner and isinstance(node_inps.owner.op, Alloc): - inp = node_inps.owner.inputs[0] - shapes = node_inps.owner.inputs[1:] - try: - val = get_scalar_constant_value(inp, only_process_constants=True) - assert val.size == 1 - val = val.reshape(1)[0] - # check which type of op - size = mul(*shapes) - if inp.dtype in ("float16", "float32"): - # shapes are ints and normally int64. - # We don't want to have a float64 upcast - # We don't want to downcast to float16 - # as we fear it could loose too much precision - # that will be amplified by the mul/pow below. - size = size.astype("float32") - if node.op.axis is None or node.op.axis == tuple(range(inp.ndim)): - if isinstance(node.op, Sum): - val = val * size - else: - val = val**size - # Sum can change the input dtype (upcast or bool - # -> float32) by default or by user request. - # We can ignore the acc_dtype, as there is only 1 - # elemwise we will do and not a sequence, so there is no - # accumulation of errors. - # So mostly, we just need to cast the output to the old - # dtype. - val = val.astype(node.outputs[0].dtype) - return [val] - to_prod = [shapes[i] for i in range(len(shapes)) if i in node.op.axis] - if to_prod: - size = mul(*to_prod) - if isinstance(node.op, Sum): - val *= size - else: - val = val**size - # See comments above. - val = val.astype(node.outputs[0].dtype) - return [ - alloc( - val, - *[ - shapes[i] - for i in range(len(shapes)) - if i not in node.op.axis - ], - ) - ] - except NotScalarConstantError: - pass - - -@register_specialize -@local_optimizer([neg]) -def local_neg_div_neg(fgraph, node): - """ - - (-a / b) -> a / b - - Also performs - (c / b) -> ((-c) / b) when c is a scalar constant. - - """ - if node.op == neg: - if node.inputs[0].owner and node.inputs[0].owner.op == true_div: - frac = node.inputs[0] - num, denom = frac.owner.inputs - if num.owner and num.owner.op == neg: - if len(fgraph.clients[frac]) == 1: - # No other clients of the original division - new_num = num.owner.inputs[0] - return [true_div(new_num, denom)] - elif all(num.broadcastable) and isinstance(num, Constant): - if len(fgraph.clients[frac]) == 1: - new_num = -num.data - return [true_div(new_num, denom)] - - -@register_canonicalize -@local_optimizer([mul]) -def local_mul_zero(fgraph, node): - """ - As part of canonicalization, we replace multiplication by zero - with zero. - - """ - if node.op == mul: - otype = node.outputs[0].type - - for i in node.inputs: - try: - value = get_scalar_constant_value(i) - except NotScalarConstantError: - continue - # print 'MUL by value', value, node.inputs - if value == 0: - # print '... returning zeros' - return fill_chain(_asarray(0, dtype=otype.dtype), node.inputs) - - -# TODO: Add this to the canonicalization to reduce redundancy. -@register_specialize -@local_optimizer([true_div]) -def local_div_to_reciprocal(fgraph, node): - if node.op == true_div and np.all(get_constant(node.inputs[0]) == 1.0): - out = node.outputs[0] - new_out = reciprocal(local_mul_canonizer.merge_num_denum(node.inputs[1:], [])) - # The ones could have forced upcasting - if new_out.dtype != out.dtype: - new_out = cast(new_out, dtype=out.dtype) - # The ones could have forced a specific length - if not out.type.is_super(new_out.type): - new_out = broadcast_like(new_out, out, fgraph) - return [new_out] - else: - return False - - -@register_canonicalize -@local_optimizer([reciprocal]) -def local_reciprocal_canon(fgraph, node): - if node.op == reciprocal: - return [at_pow(node.inputs[0], -1.0)] - else: - return False - - -@register_canonicalize -@local_optimizer([at_pow]) -def local_pow_canonicalize(fgraph, node): - if node.op == at_pow: - cst = get_constant(node.inputs[1]) - if cst == 0: - return [broadcast_like(1, node.outputs[0], fgraph)] - if cst == 1: - return [broadcast_like(node.inputs[0], node.outputs[0], fgraph)] - else: - return False - - -@register_specialize -@local_optimizer([mul]) -def local_mul_to_sqr(fgraph, node): - """ - x*x -> sqr(x) - """ - if node.op == mul: - if len(node.inputs) == 2: - if node.inputs[0] is node.inputs[1]: - return [sqr(node.inputs[0])] - - -@register_canonicalize -@local_optimizer([int_div]) -def local_intdiv_by_one(fgraph, node): - """x // 1 -> x""" - if node.op in [int_div]: - if isinstance(node.inputs[1], TensorConstant) and np.all( - node.inputs[1].value == 1 - ): - return [node.inputs[0].astype(node.outputs[0].dtype)] - - -@register_canonicalize -@register_specialize -@local_optimizer([int_div, true_div]) -def local_zero_div(fgraph, node): - """0 / x -> 0""" - if isinstance(node.op, Elemwise) and isinstance( - node.op.scalar_op, (aes.IntDiv, aes.TrueDiv) - ): - if get_constant(node.inputs[0]) == 0: - ret = broadcast_like(0, node.outputs[0], fgraph) - ret.tag.values_eq_approx = values_eq_approx_remove_nan - return [ret] - - -@register_specialize -@local_optimizer([at_pow]) -def local_pow_specialize(fgraph, node): - # here, we are past the point of canonicalization, so we don't want - # to put in un-necessary fills. - if node.op == at_pow: - # the idea here is that we have pow(x, y) - odtype = node.outputs[0].dtype - xsym = node.inputs[0] - ysym = node.inputs[1] - y = get_constant(ysym) - if (y is not None) and encompasses_broadcastable( - xsym.type.broadcastable, ysym.type.broadcastable - ): - rval = None - - if np.all(y == 2): - rval = [sqr(xsym)] - if np.all(y == 1): - rval = [xsym] - if np.all(y == 0): - rval = [fill(xsym, np.asarray(1, dtype=odtype))] - if np.all(y == 0.5): - rval = [sqrt(xsym)] - if np.all(y == -0.5): - rval = [reciprocal(sqrt(xsym))] - if np.all(y == -1): - rval = [reciprocal(xsym)] - if np.all(y == -2): - rval = [reciprocal(sqr(xsym))] - if rval: - rval[0] = cast(rval[0], odtype) - assert rval[0].type == node.outputs[0].type, (rval, node.outputs) - return rval - else: - return False - - -@register_specialize_device -@local_optimizer([at_pow]) -def local_pow_specialize_device(fgraph, node): - """ - This optimization is not the same on all device. We do it only on cpu here. - """ - if node.op == at_pow: - # the idea here is that we have pow(x, y) - odtype = node.outputs[0].dtype - xsym = node.inputs[0] - ysym = node.inputs[1] - y = get_constant(ysym) - - # the next line is needed to fix a strange case that I don't - # know how to make a separate test. - # That happen in the test_opt.py:test_log_erfc test. - # y is a ndarray with dtype int8 and value 2,4 or 6. This make - # the abs(y) <= 512 fail! - # taking the value outside ndarray solve the problem. - # it could be that in that case, numpy make the comparison - # into the wrong type(do in int8 that overflow.) - if isinstance(y, np.ndarray): - assert y.size == 1 - try: - y = y[0] - except IndexError: - pass - if (y is not None) and encompasses_broadcastable( - xsym.type.broadcastable, ysym.type.broadcastable - ): - rval = None - # 512 is too small for the cpu and too big for some gpu! - if abs(y) == int(abs(y)) and abs(y) <= 512: - pow2 = [xsym] - pow2_scal = [aes.get_scalar_type(xsym.dtype)()] - y_to_do = abs(y) - for i in range(int(np.log2(y_to_do))): - pow2.append(sqr(pow2[i])) - pow2_scal.append(aes.sqr(pow2_scal[i])) - rval1 = None - rval1_scal = None - while y_to_do > 0: - log_to_do = int(np.log2(y_to_do)) - if rval1: - rval1 *= pow2[log_to_do] - rval1_scal *= pow2_scal[log_to_do] - else: - rval1 = pow2[log_to_do] - rval1_scal = pow2_scal[log_to_do] - y_to_do -= 2**log_to_do - - if abs(y) > 2: - # We fuse all the pow together here to make - # compilation faster - rval1 = Elemwise( - aes.Composite([pow2_scal[0]], [rval1_scal]) - ).make_node(xsym) - if y < 0: - rval = [reciprocal(rval1)] - else: - rval = [rval1] - if rval: - rval[0] = cast(rval[0], odtype) - assert rval[0].type == node.outputs[0].type, (rval, node.outputs) - return rval - - -@register_specialize -@local_optimizer([mul]) -def local_mul_specialize(fgraph, node): - """ - Remove special-case constants from mul arguments and useless neg in inputs. - - mul(-1, x) -> neg(x) - mul(1, x, y) -> mul(x, y) - mul(0, ...) -> alloc(0, shapes...) - - This is not done if we would add more nodes in the graph, like with: - - mul(-1, x, y) -/-> neg(mul(x, y)) - - """ - # here, we are past the point of canonicalization, so we don't - # want to put in un-necessary fills. - # - # at this point [post canonicalize], mul() may have many inputs. - if node.op == mul: - # the idea here is that we have pow(x, y) - has_neg = False - new_inputs = [] - nb_neg_node = 0 - nb_cst = 0 - for inp in node.inputs: - # remove any neg arguments - while inp.owner and inp.owner.op == neg: - has_neg ^= True - inp = inp.owner.inputs[0] - nb_neg_node += 1 - - # remove special case arguments of 1, -1 or 0 - y = get_constant(inp) - if y == 1.0: - nb_cst += 1 - elif y == -1.0: - nb_cst += 1 - has_neg ^= True # toggles - elif y == 0.0: - # if we find any zero, we just return right away - return [broadcast_like(0, node.outputs[0], fgraph)] - else: - new_inputs.append(inp) - - if new_inputs != node.inputs: - if new_inputs: - if len(new_inputs) == 1: - if has_neg: - if new_inputs[0].dtype in (uint_dtypes + ["bool"]): - return - else: - rval = -new_inputs[0] - else: - rval = new_inputs[0] - else: - # The next case would cause a replace by an equivalent case. - if has_neg and nb_neg_node == 0 and nb_cst == 1: - return - elif has_neg: - # Don't add an extra neg node as we can't - # fully replace this mul by a neg. - m1 = np.asarray(-1, dtype=node.outputs[0].dtype) - new_inputs = [m1] + new_inputs - rval = mul(*new_inputs) - - return [broadcast_like(rval, node.outputs[0], fgraph)] - else: - # there are no variable inputs to mul - # N.B. this could have been constant-folded... - if has_neg: - return [broadcast_like(-1, node.outputs[0], fgraph)] - else: - return [broadcast_like(1, node.outputs[0], fgraph)] - - -@register_specialize -@local_optimizer([add]) -def local_add_specialize(fgraph, node): - """Remove zeros from ``add``s. - - TODO: This should be a canonicalization, no? - """ - # here, we are past the point of canonicalization, so we don't want - # to put in un-necessary fills. - if node.op != add: - return False - - new_inputs = [] - for inp in node.inputs: - try: - y = get_scalar_constant_value(inp) - except NotScalarConstantError: - y = inp - if np.all(y == 0.0): - continue - new_inputs.append(inp) - - if len(new_inputs) == len(node.inputs): - return False - - node_output = node.outputs[0] - dtype = node_output.type.dtype - - if len(new_inputs) == 0: - # we got rid of the entire expression! - ndim = node_output.type.ndim - # Reuse call to constant for cache() - cst = constant(np.zeros((1,) * ndim, dtype=dtype)) - assert cst.type.broadcastable == (True,) * ndim - return fill_chain(cst, node.inputs) - - if len(new_inputs) == 1: - ret = fill_chain(new_inputs[0], node.inputs) - else: - ret = fill_chain(add(*new_inputs), node.inputs) - - # The dtype should not be changed. It can happen if the input - # that was forcing upcasting was equal to 0. - if ret[0].dtype != dtype: - ret = [cast(ret[0], dtype)] - - return ret - - -mul_canonizer = in2out( - LocalOptGroup(local_mul_canonizer, local_fill_sink, apply_all_opts=True), - name="mul_canonizer_groups", -) - - -def check_for_x_over_absX(numerators, denominators): - """Convert x/abs(x) into sign(x).""" - # TODO: this function should dig/search through dimshuffles - # This won't catch a dimshuffled absolute value - for den in list(denominators): - if den.owner and den.owner.op == at_abs and den.owner.inputs[0] in numerators: - if den.owner.inputs[0].type.dtype.startswith("complex"): - # TODO: Make an Op that projects a complex number to - # have unit length but projects 0 to 0. That - # would be a weird Op, but consistent with the - # special case below. I heard there's some - # convention in Matlab that is similar to - # this... but not sure. - pass - else: - denominators.remove(den) - numerators.remove(den.owner.inputs[0]) - numerators.append(sgn(den.owner.inputs[0])) - return numerators, denominators - - -local_mul_canonizer.add_simplifier(check_for_x_over_absX, "X_over_absX") - - -@register_canonicalize -@local_optimizer([at_abs]) -def local_abs_lift(fgraph, node): - """ - Move the abs toward the input. - - This is needed for check_for_x_over_absX to apply in more case. - - """ - if node.op == at_abs and node.inputs[0].owner: - assert node.nin == 1 - if node.inputs[0].owner.op == mul: - return [mul(*[at_abs(i) for i in node.inputs[0].owner.inputs])] - if node.inputs[0].owner.op == true_div: - i = node.inputs[0].owner.inputs - return [true_div(at_abs(i[0]), at_abs(i[1]))] - - -@register_specialize -@local_optimizer([mul, true_div]) -def local_abs_merge(fgraph, node): - """ - Merge abs generated by local_abs_lift when the canonizer don't - need it anymore - - """ - if node.op == mul and sum(i.owner.op == at_abs for i in node.inputs if i.owner) > 1: - inputs = [] - for i in node.inputs: - if i.owner and i.owner.op == at_abs: - inputs.append(i.owner.inputs[0]) - elif isinstance(i, Constant): - try: - const = get_scalar_constant_value(i, only_process_constants=True) - except NotScalarConstantError: - return False - if not (const >= 0).all(): - return False - inputs.append(i) - else: - return False - return [at_abs(mul(*inputs))] - if ( - node.op == true_div - and sum(i.owner.op == at_abs for i in node.inputs if i.owner) == 2 - ): - return [ - at_abs( - true_div(node.inputs[0].owner.inputs[0], node.inputs[1].owner.inputs[0]) - ) - ] - - -@register_stabilize -@register_specialize -@local_optimizer([log]) -def local_log1p(fgraph, node): - # log(1+x) -> log1p(x) - # log(1-x) -> log1p(-x) - if node.op == log: - (log_arg,) = node.inputs - if log_arg.owner and log_arg.owner.op == add: - scalars, scalar_inputs, nonconsts = scalarconsts_rest( - log_arg.owner.inputs, only_process_constants=True - ) - # scalar_inputs are potentially dimshuffled and fill'd scalars - if scalars and np.allclose(np.sum(scalars), 1): - if nonconsts: - if len(nonconsts) > 1: - ninp = add(*nonconsts) - else: - ninp = nonconsts[0] - if ninp.dtype != log_arg.type.dtype: - ninp = ninp.astype(node.outputs[0].dtype) - return fill_chain(log1p(ninp), scalar_inputs) - - elif log_arg.owner and log_arg.owner.op == sub: - one = extract_constant(log_arg.owner.inputs[0], only_process_constants=True) - if one != 1: - return - other = log_arg.owner.inputs[1] - if other.dtype != log_arg.dtype: - other = other.astype(log_arg.dtype) - return [log1p(neg(other))] - - -@register_stabilize -@register_specialize -@local_optimizer([log]) -def local_log_add_exp(fgraph, node): - """ - ``log(exp(x)+exp(y)+exp(z)) = max + log(x-max, y-max, z-max)`` - - TODO: in canonicalize, change log10 and log2 -> log - """ - - if node.op == log: - z = node.inputs[0] - if z.owner and z.owner.op == add: - zi = z.owner.inputs - pre_exp = [x.owner.inputs[0] for x in zi if x.owner and x.owner.op == exp] - # all arguments to add are exp() - if len(pre_exp) == len(zi): - # Do not offset when max_pre = -np.inf, to avoid nan in the output - # Switch statement is placed directly inside add to break the self-symmetry - # of the returned output (otherwise the optimization would not stabilize) - max_pre = reduce(maximum, pre_exp) - ret = max_pre + log( - add( - *[ - switch(isinf(max_pre), exp(max_pre), exp(p - max_pre)) - for p in pre_exp - ] - ) - ) - return [ret] - - -@register_stabilize -@register_specialize -@local_optimizer([log]) -def local_log_sum_exp(fgraph, node): - # log(sum_i(exp(x_i))) = x_max + log(sum_i(exp(x_i - x_max))) - - if node.op != log: - return - - sum_node = node.inputs[0].owner - # If the sum has keepdims=True, there might be a dimshuffle - if sum_node and isinstance(sum_node.op, DimShuffle): - dimshuffle_op = sum_node.op - sum_node = sum_node.inputs[0].owner - else: - dimshuffle_op = None - - if not sum_node or not isinstance(sum_node.op, Sum): - return - - exp_node, axis = sum_node.inputs[0].owner, sum_node.op.axis - if not exp_node or not ( - isinstance(exp_node.op, Elemwise) and isinstance(exp_node.op.scalar_op, aes.Exp) - ): - return - - pre_exp = exp_node.inputs[0] - max_pre_exp = at_max(pre_exp, axis=axis) - max_pre_exp_keepdims = makeKeepDims(pre_exp, max_pre_exp, axis) - - # Do not offset when max_pre = -np.inf, to avoid nan in the output - # Switch statement is placed directly inside sum to break the self-symmetry - # of the returned output (otherwise the optimization would not stabilize) - ret = max_pre_exp + log( - at_sum( - switch( - isinf(max_pre_exp_keepdims), - exp(max_pre_exp_keepdims), - exp(pre_exp - max_pre_exp_keepdims), - ), - axis=axis, - ), - ) - - # Restore the dimshuffle op, if any. - if dimshuffle_op: - ret = dimshuffle_op(ret) - - return [ret] - - -def add_calculate(num, denum, aslist=False, out_type=None): - # TODO: make sure that this function and mul_calculate are similar - if out_type is None: - zero = 0.0 - else: - zero = _asarray(0, dtype=out_type.dtype) - # zero = 0.0 if out_type is None else _asarray(0, - # dtype=out_type.dtype) - if out_type and out_type.dtype == "bool": - if len(denum) == 0: - # NumPy 1.14 do not accept to do "bool - bool" - v = reduce(np.add, num, zero) - else: - raise Exception( - "bool subtraction not supported. This should not happen as" - " an earlier error should have been raised" - ) - else: - v = reduce(np.add, num, zero) - reduce(np.add, denum, zero) - if aslist: - if np.all(v == 0): - return [] - else: - return [v] - return v - - -local_add_canonizer = AlgebraicCanonizer(add, sub, neg, add_calculate) -add_canonizer = in2out( - LocalOptGroup(local_add_canonizer, local_fill_sink, apply_all_opts=True), - name="add_canonizer_group", -) - - -register_canonicalize(local_add_canonizer, name="local_add_canonizer") - - -def distribute_greedy(pos_pairs, neg_pairs, num, denum, out_type, minscore=0): - # each pair in pos_pairs and neg_pairs is a num/denum pair. this - # function attempts to add num and denum to the corresponding parts - # of each pair, and counts how many multiplications/divisions can - # be saved in that way. - - # each division is counted like div_cost multiplications - # (typically, division costs more so we are willing to multiply more - # in order to divide less) - # 1.5 was obtained through an informal test and may very well be - # platform dependent - div_cost = 1.5 - - # score is number of operations saved, higher is better - score = len(num) + div_cost * len(denum) - new_pos_pairs = list( - itertools.starmap( - local_mul_canonizer.simplify, - [(n + num, d + denum, out_type) for (n, d) in pos_pairs], - ) - ) - new_neg_pairs = list( - itertools.starmap( - local_mul_canonizer.simplify, - [(n + num, d + denum, out_type) for (n, d) in neg_pairs], - ) - ) - for (n, d), (nn, dd) in zip(pos_pairs + neg_pairs, new_pos_pairs + new_neg_pairs): - # We calculate how many operations we are saving with the new - # num and denum - score += len(n) + div_cost * len(d) - len(nn) - div_cost * len(dd) - if score <= minscore: - # the change is not applied because it adds too many operations - return False, pos_pairs, neg_pairs - return True, new_pos_pairs, new_neg_pairs - - -def attempt_distribution(factor, num, denum, out_type): - """Try to insert each `num` and each `denum` in the factor? - - Returns - ------- - changes?, new_factor, new_num, new_denum - If there are changes, `new_num` and `new_denum` contain all the - numerators and denominators that could not be distributed in the factor - - """ - pos_terms, neg_terms = local_add_canonizer.get_num_denum(factor) - if len(pos_terms) == 1 and not neg_terms: - return False, factor, num, denum - pos_pairs = list(map(local_mul_canonizer.get_num_denum, pos_terms)) - neg_pairs = list(map(local_mul_canonizer.get_num_denum, neg_terms)) - change = False - for n in list(num): - success, pos_pairs, neg_pairs = distribute_greedy( - pos_pairs, neg_pairs, [n], [], out_type - ) - if success: - change = True - num.remove(n) - for d in list(denum): - success, pos_pairs, neg_pairs = distribute_greedy( - pos_pairs, neg_pairs, [], [d], out_type - ) - if success: - change = True - denum.remove(d) - if not change: - return change, factor, num, denum - else: - return ( - change, - local_add_canonizer.merge_num_denum( - list(itertools.starmap(local_mul_canonizer.merge_num_denum, pos_pairs)), - list(itertools.starmap(local_mul_canonizer.merge_num_denum, neg_pairs)), - ), - num, - denum, - ) - - -@register_canonicalize -@register_stabilize -@local_optimizer([mul, true_div, reciprocal]) -def local_greedy_distributor(fgraph, node): - """ - Optimize by reducing the number of multiplications and/or divisions. - - This optimization tries to apply distributivity of multiplication - to addition in order to reduce the number of multiplications - and/or divisions that must be done. The algorithm weighs division - more than multiplication to account for the former's slightly - greater computational cost. - - The following expressions are simplified: - 1. ((a/x + b/y) * x * y) --> a*y + b*x - 2. ((a/x + b) * x) --> a + b*x - 3. There are other forms too where node is a true_div. - - The following expressions are not simplified: - 4. ((a + b) * x) -/-> a*x + b*x - - This optimization aims to reduce computational cost. It may also - increase numerical stability, e.g. when x and/or y tend to 0 in - example 1. - - """ - - out = node.outputs[0] - num, denum = local_mul_canonizer.get_num_denum(out) - if len(num) == 1 and not denum: - return False - - new_num, new_denum = [], [] - - change = False - - out_type = out.type - for candidate in list(num): - if candidate not in num: - continue - num.remove(candidate) - _change, candidate, num, denum = attempt_distribution( - candidate, - num, - denum, - out_type, - ) - - change |= _change - new_num.append(candidate) - - for candidate in list(denum): - if candidate not in denum: - continue - denum.remove(candidate) - _change, candidate, denum, num = attempt_distribution( - candidate, denum, num, out_type - ) - change |= _change - new_denum.append(candidate) - if not change: - return False - - new_num += num - new_denum += denum - - rval = local_mul_canonizer.merge_num_denum(new_num, new_denum) - - if rval.type != out.type: - # WHY DOES THIS HAPPEN? - return False - - return [rval] - - -get_clients_at_depth1 = partial(get_clients_at_depth, depth=1) -get_clients_at_depth2 = partial(get_clients_at_depth, depth=2) - -# 1+erf(x)=>erfc(-x) -local_one_plus_erf = PatternSub( - (add, 1, (erf, "x")), - (erfc, (neg, "x")), - allow_multiple_clients=True, - name="local_one_plus_erf", - tracks=[erf], - get_nodes=get_clients_at_depth1, -) -register_canonicalize(local_one_plus_erf) -register_stabilize(local_one_plus_erf) -register_specialize(local_one_plus_erf) - -# Only one of the two rewrites below is needed if a canonicalization is added -# for sub(x, y) -> add(x, -y) or a specialization for add(x, -y) -> sub(x, y) -# 1-erf(x)=>erfc(x) -local_one_minus_erf = PatternSub( - (sub, 1, (erf, "x")), - (erfc, "x"), - allow_multiple_clients=True, - name="local_one_minus_erf", - tracks=[erf], - get_nodes=get_clients_at_depth1, -) -register_canonicalize(local_one_minus_erf) -register_stabilize(local_one_minus_erf) -register_specialize(local_one_minus_erf) - -local_one_minus_erf2 = PatternSub( - (add, 1, (neg, (erf, "x"))), - (erfc, "x"), - allow_multiple_clients=True, - name="local_one_minus_erf2", - tracks=[erf], - get_nodes=get_clients_at_depth2, -) -register_canonicalize(local_one_minus_erf2) -register_stabilize(local_one_minus_erf2) -register_specialize(local_one_minus_erf2) - -# (-1)+erf(x) => -erfc(x) -# There is no need for erf(x)+(-1) nor erf(x) - 1, as the canonicalize will -# convert those to the matched pattern -local_erf_minus_one = PatternSub( - (add, -1, (erf, "x")), - (neg, (erfc, "x")), - allow_multiple_clients=True, - name="local_erf_minus_one", - tracks=[erf], - get_nodes=get_clients_at_depth1, -) -register_canonicalize(local_erf_minus_one) -register_stabilize(local_erf_minus_one) -register_specialize(local_erf_minus_one) - -# Only one of the two rewrites below is needed if a canonicalization is added -# for sub(x, y) -> add(x, -y) or a specialization for add(x, -y) -> sub(x, y) -# 1-erfc(x) => erf(x) -local_one_minus_erfc = PatternSub( - (sub, 1, (erfc, "x")), - (erf, "x"), - allow_multiple_clients=True, - name="local_one_minus_erfc", - tracks=[erfc], - get_nodes=get_clients_at_depth1, -) -register_canonicalize(local_one_minus_erfc) -register_stabilize(local_one_minus_erfc) -register_specialize(local_one_minus_erfc) - -local_one_minus_erfc2 = PatternSub( - (add, 1, (neg, (erfc, "x"))), - (erf, "x"), - allow_multiple_clients=True, - name="local_one_minus_erfc2", - tracks=[erfc], - get_nodes=get_clients_at_depth2, -) -register_canonicalize(local_one_minus_erfc2) -register_stabilize(local_one_minus_erfc2) -register_specialize(local_one_minus_erfc2) - -# (-1)+erfc(-x)=>erf(x) -local_erf_neg_minus_one = PatternSub( - (add, -1, (erfc, (neg, "x"))), - (erf, "x"), - allow_multiple_clients=True, - name="local_erf_neg_minus_one", - tracks=[erfc], - get_nodes=get_clients_at_depth1, -) -register_canonicalize(local_erf_neg_minus_one) -register_stabilize(local_erf_neg_minus_one) -register_specialize(local_erf_neg_minus_one) - - -@register_stabilize -@register_specialize -@local_optimizer([log]) -def local_log_erfc(fgraph, node): - """Stability optimization for `log(erfc(x))`. - - log(erfc(x)) => when x>threshold, - -x**2-log(x)-.5*log(pi)+log(1-1/(2*x**2)+3/(4*x**4)-15/(8*x**6)) - for float64: threshold=26.641747557 was chosen with: - [(i,numpy.log(scipy.special.erfc(numpy.asarray([i],dtype='float64')))) - for i in numpy.arange(26.641747557,26.6417475571,.00000000001)] - for float32: threshold=10.0541949, [(i,numpy.log(scipy.special.erfc( - numpy.asarray([i],dtype='float32')))) for i in numpy.arange( - 10.0541948,10.0541951,.0000001)] - """ - if node.op != log: - return False - if not node.inputs[0].owner or node.inputs[0].owner.op != erfc: - return False - - if hasattr(node.tag, "local_log_erfc_applied"): - # We use that flag to don't apply the optimization recursively - return False - node.tag.local_log_erfc_applied = True - - x = node.inputs[0].owner.inputs[0] - stab_value = ( - -(x**2) - - log(x) - - 0.5 * log(np.pi) - + log(1 - 1 / (2 * x**2) + 3 / (4 * x**4) - 15 / (8 * x**6)) - ) - - if node.outputs[0].dtype == "float32" or node.outputs[0].dtype == "float16": - threshold = 10.0541949 - elif node.outputs[0].dtype == "float64": - threshold = 26.641747557 - - ret = switch(x < threshold, node.outputs[0], stab_value) - ret.tag.values_eq_approx = values_eq_approx_remove_inf - return [ret] - - -@register_stabilize -@register_specialize -@local_optimizer([true_div]) -def local_grad_log_erfc_neg(fgraph, node): - """Stability optimization for the grad of `log(erfc(x))`. - - ([y*]exp(-(x**2)))/erfc(x) # The y* is optional - ([y*]exp(x**2))/erfc(-x) => [y*](when x > threshold, - sqrt(pi)*-x/(1-1/(2*x**2)+3/(4*x**4)-15/(8*x**6))) - - for float64: threshold=26.63 see at the end of the fct for the explanation - for float32: threshold=9.3 see at the end of the fct for the explanation - - TODO: remove the constraint that there are only 2 inputs to exp(x**2) - is the second. - TODO: at the test point 10 in float32, there is instability in the original - value. The original gives -30.0, the stab -20.1 and in float64 -18.1. - Make it so that the test does not generate an error in that case! - - """ - if node.op != true_div: - return False - if not node.inputs[1].owner or node.inputs[1].owner.op != erfc: - return False - - erfc_in = node.inputs[1] - erfc_x = erfc_in.owner.inputs[0] - - if not node.inputs[0].owner: - return False - - # TODO: All of this should be replaced with a single, simple unification - # The mul is optional. - if node.inputs[0].owner.op != mul: - mul_in = None - y = [] - if not node.inputs[0].owner or node.inputs[0].owner.op != exp: - return False - exp_in = node.inputs[0] - else: - mul_in = node.inputs[0] - exp_in = None - for idx, inp in enumerate(mul_in.owner.inputs): - if inp.owner and inp.owner.op == exp: - exp_in = inp - break - else: - return False - - if len(mul_in.owner.inputs) == 2: - y = [mul_in.owner.inputs[1 - idx]] - else: - y = mul_in.owner.inputs[:] - del y[idx] - - if not exp_in.owner.inputs[0].owner: - return False - - if exp_in.owner.inputs[0].owner.op == neg: - neg_in = exp_in.owner.inputs[0] - if not neg_in.owner.inputs[0].owner or neg_in.owner.inputs[0].owner.op != sqr: - return False - sqr_in = neg_in.owner.inputs[0] - x = sqr_in.owner.inputs[0] - elif exp_in.owner.inputs[0].owner.op == mul: - # We should compare that -(erfc_x**2) is equivalent to mul_neg. - # There is currently no easy way to do this in the general case, - # so we implement some common case for now. - - # In many cases the neg are replaced by mul in the graph. - # This also allows to stabilize log(erfc(cst*x)). - mul_neg = exp_in.owner.inputs[0] - - # In case that multiple mul are not fused together, we do it here. - def check_input(inputs): - new_inputs = [] - for i in inputs: - if i.owner and i.owner.op == mul: - new_inputs.extend(check_input(i.owner.inputs)) - else: - new_inputs.append(i) - return new_inputs - - mul_inputs = check_input(mul_neg.owner.inputs) - - # Put the constant first. - for i in range(len(mul_inputs)): - if isinstance(i, Constant): - if i == 0: - break - else: - tmp = mul_inputs[0] - mul_inputs[0] = mul_inputs[i] - mul_inputs[i] = tmp - break - mul_neg = mul(*mul_inputs) - - try: - cst2 = get_scalar_constant_value( - mul_neg.owner.inputs[0], only_process_constants=True - ) - except NotScalarConstantError: - return False - - if len(mul_neg.owner.inputs) == 2: - if ( - not mul_neg.owner.inputs[1].owner - or mul_neg.owner.inputs[1].owner.op != sqr - ): - return False - sqr_in = mul_neg.owner.inputs[1] - x = sqr_in.owner.inputs[0] - elif len(mul_neg.owner.inputs) == 3: - if mul_neg.owner.inputs[1] is not mul_neg.owner.inputs[2]: - return False - x = mul_neg.owner.inputs[1] - else: - return False - - if cst2 != -1: - if ( - not erfc_x.owner - or erfc_x.owner.op != mul - or len(erfc_x.owner.inputs) != 2 - ): - # todo implement that case - return False - if erfc_x.owner.inputs[1] is not mul_neg.owner.inputs[1]: - return False - - x = erfc_x - try: - cst = get_scalar_constant_value( - erfc_x.owner.inputs[0], only_process_constants=True - ) - except NotScalarConstantError: - return False - if cst2 != -cst * 2: - return False - - # The constant is valid. Must check that the - elif erfc_x is not x: - return False - - else: - return False - - if hasattr(node.tag, "local_grad_log_erfc_neg"): - # We use that flag to don't apply the optimization recursively - return False - - if erfc_x is not x: - return None - - # we move the y outside the div. - true_div_no_mul = true_div(exp_in, erfc_in) - true_div_no_mul.owner.tag.local_grad_log_erfc_neg = True - - # aaron value - stab_value = ( - x - * at_pow(1 - 1 / (2 * (x**2)) + 3 / (4 * (x**4)) - 15 / (8 * (x**6)), -1) - * cast(sqrt(np.pi), dtype=x.dtype) - ) - - if x.dtype == "float32" or x.dtype == "float16": - threshold = 9.3 - # threshold = 10.1 - elif x.dtype == "float64": - threshold = 26.641747557 - - ret = switch(x < threshold, true_div_no_mul, stab_value) - - if y: - ret = mul(ret, *y) - - ret.tag.values_eq_approx = values_eq_approx_remove_inf_nan - - return [ret] - - -def local_add_mul_fusion(fgraph, node): - """Fuse consecutive add or mul in one such node with more inputs. - - It is better to fuse add/mul that way then in a Composite node as - this make the inner graph of the Composite smaller. This allow to - put more computation in a Composite before hitting the max - recursion limit when pickling Composite. - - """ - if not isinstance(node.op, Elemwise) or not isinstance( - node.op.scalar_op, (aes.Add, aes.Mul) - ): - return False - - s_op = node.op.scalar_op.__class__ - new_inp = [] - fused = False - nb_inputs = len(node.inputs) - max_inputs = float("inf") - if hasattr(node.op, "max_inputs"): - max_inputs = node.op.max_inputs(node) - for inp in node.inputs: - if ( - inp.owner - and isinstance(inp.owner.op, Elemwise) - and isinstance(inp.owner.op.scalar_op, s_op) - and - # Do not duplicate the operation. - len(fgraph.clients[inp]) == 1 - and (nb_inputs + len(inp.owner.inputs) - 1) <= max_inputs - ): - new_inp.extend(inp.owner.inputs) - fused = True - else: - new_inp.append(inp) - - # We can not compare the number of inputs as Mul and Add could have - # 0 or 1 inputs in some corner cases. - if fused: - output = node.op(*new_inp) - copy_stack_trace(node.outputs[0], output) - - # Do the recursion here to help lower the number of - # FusionOptimizer iteration. - if output.owner: - output2 = local_add_mul_fusion(fgraph, output.owner) - if output2: - return output2 - return [output] - - -fuse_seqopt.register( - "local_add_mul_fusion", - FusionOptimizer(local_add_mul_fusion), - "fast_run", - "fusion", - position=0, -) - - -def _skip_mul_1(r): - if r.owner and r.owner.op == mul: - not_is_1 = [i for i in r.owner.inputs if not _is_1(i)] - if len(not_is_1) == 1: - return not_is_1[0] - - -def _is_1(expr): - """ - - Returns - ------- - bool - True iff expr is a constant close to 1. - - """ - try: - v = get_scalar_constant_value(expr) - return np.allclose(v, 1) - except NotScalarConstantError: - return False - - -logsigm_to_softplus = PatternSub( - (log, (sigmoid, "x")), - (neg, (softplus, (neg, "x"))), - allow_multiple_clients=True, - values_eq_approx=values_eq_approx_remove_inf, - skip_identities_fn=_skip_mul_1, - tracks=[sigmoid], - get_nodes=get_clients_at_depth1, -) -log1msigm_to_softplus = PatternSub( - (log, (sub, dict(pattern="y", constraint=_is_1), (sigmoid, "x"))), - (neg, (softplus, "x")), - allow_multiple_clients=True, - values_eq_approx=values_eq_approx_remove_inf, - skip_identities_fn=_skip_mul_1, - tracks=[sigmoid], - get_nodes=get_clients_at_depth2, -) -log1pexp_to_softplus = PatternSub( - (log1p, (exp, "x")), - (softplus, "x"), - values_eq_approx=values_eq_approx_remove_inf, - allow_multiple_clients=True, -) -log1p_neg_sigmoid = PatternSub( - (log1p, (neg, (sigmoid, "x"))), - (neg, (softplus, "x")), - values_eq_approx=values_eq_approx_remove_inf, - allow_multiple_clients=True, - tracks=[sigmoid], - get_nodes=get_clients_at_depth2, -) - -register_stabilize(logsigm_to_softplus, name="logsigm_to_softplus") -register_stabilize(log1msigm_to_softplus, name="log1msigm_to_softplus") -register_stabilize(log1pexp_to_softplus, name="log1pexp_to_softplus") -register_stabilize(log1p_neg_sigmoid, name="log1p_neg_sigmoid") -register_specialize(log1p_neg_sigmoid, name="log1p_neg_sigmoid") - - -def is_1pexp(t, only_process_constants=True): - """ - - Returns - ------- - object - If 't' is of the form (1+exp(x)), return (False, x). - Else return None. - - """ - if t.owner and t.owner.op == add: - scalars, scalar_inputs, nonconsts = scalarconsts_rest( - t.owner.inputs, only_process_constants=only_process_constants - ) - # scalar_inputs are potentially dimshuffled and filled with scalars - if len(nonconsts) == 1: - maybe_exp = nonconsts[0] - if maybe_exp.owner and maybe_exp.owner.op == exp: - # Verify that the constant terms sum to 1. - if scalars: - scal_sum = scalars[0] - for s in scalars[1:]: - scal_sum = scal_sum + s - if np.allclose(scal_sum, 1): - return False, maybe_exp.owner.inputs[0] - return None - - -def is_exp(var): - """ - Match a variable with either of the `exp(x)` or `-exp(x)` patterns. - - Parameters - ---------- - var - The Variable to analyze. - - Returns - ------- - tuple - A pair (b, x) with `b` a boolean set to True if `var` is of the - form `-exp(x)` and False if `var` is of the form `exp(x)`. If `var` - cannot be cast into either form, then return `None`. - - """ - _neg = False - neg_info = is_neg(var) - if neg_info is not None: - _neg = True - var = neg_info - if var.owner and var.owner.op == exp: - return _neg, var.owner.inputs[0] - - -def is_mul(var): - """ - Match a variable with `x * y * z * ...`. - - Parameters - ---------- - var - The Variable to analyze. - - Returns - ------- - object - A list [x, y, z, ...] if `var` is of the form `x * y * z * ...`, - or None if `var` cannot be cast into this form. - - """ - if var.owner and var.owner.op == mul: - return var.owner.inputs - else: - return None - - -def partition_num_or_denom(r, f): - if r.owner and r.owner.op == mul: - a = r.owner.inputs - else: - a = [r] - - # ugly 2.4-compatible thing - f_terms = [] - _neg = False - rest = [] - for t in a: - f_t = f(t) - if f_t is None: - rest.append(t) - else: - neg_t, f_t = f_t - f_terms.append(f_t) - _neg ^= neg_t # bit flip if neg_t is true - return f_terms, rest, _neg - - -def is_neg(var): - """ - Match a variable with the `-x` pattern. - - Parameters - ---------- - var - The Variable to analyze. - - Returns - ------- - object - `x` if `var` is of the form `-x`, or None otherwise. - - """ - var_node = var.owner - if not var_node: - return None - # First match against `neg`. - if var_node.op == neg: - return var_node.inputs[0] - # Then match against a multiplication by -1. - if var_node.op == mul and len(var_node.inputs) >= 2: - for idx, mul_input in enumerate(var_node.inputs): - try: - constant = get_scalar_constant_value(mul_input) - is_minus_1 = np.allclose(constant, -1) - except NotScalarConstantError: - is_minus_1 = False - if is_minus_1: - # Found a multiplication by -1. - if len(var_node.inputs) == 2: - # Only return the other input. - return var_node.inputs[1 - idx] - else: - # Return the multiplication of all other inputs. - return mul(*(var_node.inputs[0:idx] + var_node.inputs[idx + 1 :])) - # No match. - return None - - -@register_stabilize -@local_optimizer([true_div]) -def local_exp_over_1_plus_exp(fgraph, node): - """ - exp(x)/(1+exp(x)) -> sigm(x) - c/(1+exp(x)) -> c*sigm(-x) - - """ - # this optimization should be done for numerical stability - # so we don't care to check client counts - if node.op == true_div: - - # find all the exp() terms in the numerator - num, denom = node.inputs - num_exp_x, num_rest, num_neg = partition_num_or_denom(num, is_exp) - denom_1pexp, denom_rest, denom_neg = partition_num_or_denom(denom, is_1pexp) - - sigmoids = [] - for t in denom_1pexp: - if t in num_exp_x: - # case: exp(x) /(1+exp(x)) - sigmoids.append(sigmoid(t)) - del num_exp_x[num_exp_x.index(t)] - else: - # case: 1/(1+exp(x)) - sigmoids.append(sigmoid(-t)) - copy_stack_trace(node.outputs[0], sigmoids[-1]) - - if not sigmoids: # we didn't find any. abort - return - # put the new numerator together - new_num = sigmoids + [exp(t) for t in num_exp_x] + num_rest - if len(new_num) == 1: - new_num = new_num[0] - else: - new_num = mul(*new_num) - - if num_neg ^ denom_neg: - new_num = -new_num - - copy_stack_trace(num, new_num) - - if len(denom_rest) == 0: - return [new_num] - elif len(denom_rest) == 1: - out = new_num / denom_rest[0] - else: - out = new_num / mul(*denom_rest) - - copy_stack_trace(node.outputs[0], out) - return [out] - - -def parse_mul_tree(root): - """ - Parse a tree of multiplications starting at the given root. - - Parameters - ---------- - root - The variable at the root of the tree. - - Returns - ------- - object - A tree where each non-leaf node corresponds to a multiplication - in the computation of `root`, represented by the list of its inputs. - Each input is a pair [n, x] with `n` a boolean value indicating whether - sub-tree `x` should be negated. - - Examples - -------- - - .. code-block:: python - - x * y -> [False, [[False, x], [False, y]]] - -(x * y) -> [True, [[False, x], [False, y]]] - -x * y -> [False, [[True, x], [False, y]]] - -x -> [True, x] - (x * y) * -z -> [False, [[False, [[False, x], [False, y]]], - [True, z]]] - - """ - # Is it a multiplication? - mul_info = is_mul(root) - if mul_info is None: - # Is it a negation? - neg_info = is_neg(root) - if neg_info is None: - # Keep the root "as is". - return [False, root] - else: - # Recurse, inverting the negation. - neg, sub_tree = parse_mul_tree(neg_info) - return [not neg, sub_tree] - else: - # Recurse into inputs. - return [False, list(map(parse_mul_tree, mul_info))] - - -def replace_leaf(arg, leaves, new_leaves, op, neg): - """ - Attempt to replace a leaf of a multiplication tree. - - We search for a leaf in `leaves` whose argument is `arg`, and if we find - one, we remove it from `leaves` and add to `new_leaves` a leaf with - argument `arg` and variable `op(arg)`. - - Parameters - ---------- - arg - The argument of the leaf we are looking for. - leaves - List of leaves to look into. Each leaf should be a pair - (x, l) with `x` the argument of the Op found in the leaf, and `l` the - actual leaf as found in a multiplication tree output by `parse_mul_tree` - (i.e. a pair [boolean, variable]). - new_leaves - If a replacement occurred, then the leaf is removed from `leaves` - and added to the list `new_leaves` (after being modified by `op`). - op - A function that, when applied to `arg`, returns the Variable - we want to replace the original leaf variable with. - neg : bool - If True, then the boolean value associated to the leaf should - be swapped. If False, then this value should remain unchanged. - - Returns - ------- - bool - True if a replacement occurred, or False otherwise. - - """ - for idx, x in enumerate(leaves): - if x[0] == arg: - x[1][0] ^= neg - x[1][1] = op(arg) - leaves.pop(idx) - new_leaves.append(x) - return True - return False - - -def simplify_mul(tree): - """ - Simplify a multiplication tree. - - Parameters - ---------- - tree - A multiplication tree (as output by `parse_mul_tree`). - - Returns - ------- - object - A multiplication tree computing the same output as `tree` but without - useless multiplications by 1 nor -1 (identified by leaves of the form - [False, None] or [True, None] respectively). Useless multiplications - (with less than two inputs) are also removed from the tree. - - """ - neg, inputs = tree - if isinstance(inputs, list): - # Recurse through inputs. - s_inputs = [] - for s_i in map(simplify_mul, inputs): - if s_i[1] is None: - # Multiplication by +/-1. - neg ^= s_i[0] - else: - s_inputs.append(s_i) - if not s_inputs: - # The multiplication is empty. - rval = [neg, None] - elif len(s_inputs) == 1: - # The multiplication has a single input. - s_inputs[0][0] ^= neg - rval = s_inputs[0] - else: - rval = [neg, s_inputs] - else: - rval = tree - # print 'simplify_mul: %s -> %s' % (tree, rval) - return rval - - -def compute_mul(tree): - """ - Compute the Variable that is the output of a multiplication tree. - - This is the inverse of the operation performed by `parse_mul_tree`, i.e. - compute_mul(parse_mul_tree(tree)) == tree. - - Parameters - ---------- - tree - A multiplication tree (as output by `parse_mul_tree`). - - Returns - ------- - object - A Variable that computes the multiplication represented by the tree. - - """ - neg, inputs = tree - if inputs is None: - raise AssertionError( - "Function `compute_mul` found a missing leaf, did you forget to " - "call `simplify_mul` on the tree first?" - ) - elif isinstance(inputs, list): - # Recurse through inputs. - rval = mul(*list(map(compute_mul, inputs))) - else: - rval = inputs - if neg: - rval = -rval - return rval - - -def perform_sigm_times_exp( - tree, - exp_x=None, - exp_minus_x=None, - sigm_x=None, - sigm_minus_x=None, - parent=None, - child_idx=None, - full_tree=None, -): - """ - Core processing of the `local_sigm_times_exp` optimization. - - This recursive function operates on a multiplication tree as output by - `parse_mul_tree`. It walks through the tree and modifies it in-place - by replacing matching pairs (exp, sigmoid) with the desired optimized - version. - - Parameters - ---------- - tree - The sub-tree to operate on. - exp_x - List of arguments x so that `exp(x)` exists somewhere in the whole - multiplication tree. Each argument is a pair (x, leaf) with `x` the - argument of the exponential, and `leaf` the corresponding leaf in the - multiplication tree (of the form [n, exp(x)] -- see `parse_mul_tree`). - If None, this argument is initialized to an empty list. - exp_minus_x - Similar to `exp_x`, but for `exp(-x)`. - sigm_x - Similar to `exp_x`, but for `sigmoid(x)`. - sigm_minus_x - Similar to `exp_x`, but for `sigmoid(-x)`. - parent - Parent of `tree` (None if `tree` is the global root). - child_idx - Index of `tree` in its parent's inputs (None if `tree` is the global - root). - full_tree - The global multiplication tree (should not be set except by recursive - calls to this function). Used for debugging only. - - Returns - ------- - bool - True if a modification was performed somewhere in the whole multiplication - tree, or False otherwise. - - """ - if exp_x is None: - exp_x = [] - if exp_minus_x is None: - exp_minus_x = [] - if sigm_x is None: - sigm_x = [] - if sigm_minus_x is None: - sigm_minus_x = [] - if full_tree is None: - full_tree = tree - if False: # Debug code. - print("") - print(f" full_tree = {full_tree}") - print(f" tree = {tree}") - print(f" exp_x = {exp_x}") - print(f" exp_minus_x = {exp_minus_x}") - print(f" sigm_x = {sigm_x}") - print(f" sigm_minus_x= {sigm_minus_x}") - neg, inputs = tree - if isinstance(inputs, list): - # Recurse through inputs of the multiplication. - rval = False - for sub_idx, sub_tree in enumerate(inputs): - rval |= perform_sigm_times_exp( - tree=sub_tree, - parent=tree, - child_idx=sub_idx, - exp_x=exp_x, - exp_minus_x=exp_minus_x, - sigm_x=sigm_x, - sigm_minus_x=sigm_minus_x, - full_tree=full_tree, - ) - return rval - else: - # Reached a leaf: if it is an exponential or a sigmoid, then we - # first attempt to find a match in leaves already visited. - # If there is such a match, we modify the already-visited leaf - # accordingly: for instance if we visited a leaf sigmoid(x), then - # find later a -exp(-x), we replace the previous leaf by - # -sigmoid(-x) and remove the -exp(-x) from the tree. - # If no match is found, then we register this leaf so that it can - # be found later while walking the tree. - var = inputs - keep_it = False - exp_info = is_exp(var) - if exp_info is not None: - exp_neg, exp_arg = exp_info - neg ^= exp_neg - neg_arg = is_neg(exp_arg) - if neg_arg is None: - if not replace_leaf(exp_arg, sigm_minus_x, sigm_x, sigmoid, neg): - exp_x.append((exp_arg, tree)) - keep_it = True - else: - if not replace_leaf( - neg_arg, sigm_x, sigm_minus_x, lambda x: sigmoid(-x), neg - ): - exp_minus_x.append((neg_arg, tree)) - keep_it = True - elif var.owner and var.owner.op == sigmoid: - sigm_arg = var.owner.inputs[0] - neg_arg = is_neg(sigm_arg) - if neg_arg is None: - if not replace_leaf( - sigm_arg, exp_minus_x, sigm_minus_x, lambda x: sigmoid(-x), neg - ): - sigm_x.append((sigm_arg, tree)) - keep_it = True - else: - if not replace_leaf(neg_arg, exp_x, sigm_x, sigmoid, neg): - sigm_minus_x.append((neg_arg, tree)) - keep_it = True - else: - # It is not an exponential nor a sigmoid. - keep_it = True - if not keep_it: - # Delete this leaf, i.e. replace it by [False, None] (corresponding - # to a multiplication by 1). - assert parent is not None - parent[1][child_idx] = [False, None] - return not keep_it - - -@register_stabilize -@local_optimizer([mul]) -def local_sigm_times_exp(fgraph, node): - """ - exp(x) * sigm(-x) -> sigm(x) - exp(-x) * sigm(x) -> sigm(-x) - - todo: add stack traces to the intermediate variables - """ - # Bail early if it is not a multiplication. - if node.op != mul: - return None - # Obtain tree of multiplications starting at this node. - mul_tree = parse_mul_tree(node.outputs[0]) - # Perform core optimization. - did_something = perform_sigm_times_exp(mul_tree) - if not did_something: - # No change. - return None - # The optimization may have introduced multiplications by 1 in the tree: - # get rid of them. - mul_tree = simplify_mul(mul_tree) - # Recompute final output based on the updated tree. - out = compute_mul(mul_tree) - # keep the stack trace - copy_stack_trace(node.outputs[0], out) - return [out] - - -@register_stabilize -@local_optimizer([reciprocal]) -def local_reciprocal_1_plus_exp(fgraph, node): - """``reciprocal(1+exp(x)) -> sigm(-x)`` - - TODO: This is redundant; we can just decided on *one* canonical form - for division (e.g. either `true_div` or `reciprocal`) and have this - taken care of with existing rewrites. - """ - # this optimization should be done for numerical stability - # so we don't care to check client counts - if node.op == reciprocal: - reciprocal_arg = node.inputs[0] - if reciprocal_arg.owner and reciprocal_arg.owner.op == add: - scalars_, scalar_inputs, nonconsts = scalarconsts_rest( - reciprocal_arg.owner.inputs, only_process_constants=True - ) - # scalar_inputs are potentially dimshuffled and fill'd scalars - if len(nonconsts) == 1: - if nonconsts[0].owner and nonconsts[0].owner.op == exp: - if scalars_ and np.allclose(np.sum(scalars_), 1): - out = fill_chain( - sigmoid(neg(nonconsts[0].owner.inputs[0])), - scalar_inputs, - ) - # keep combined stack traces of - # exp(x): nonconsts[0], - # 1 + exp(x): reciprocal_arg, - # 1 / (1 + exp(x)): node.outputs[0] - copy_stack_trace( - [nonconsts[0], reciprocal_arg, node.outputs[0]], out - ) - return out - - -# 1 - sigmoid(x) -> sigmoid(-x) -local_1msigmoid = PatternSub( - (sub, dict(pattern="y", constraint=_is_1), (sigmoid, "x")), - (sigmoid, (neg, "x")), - tracks=[sigmoid], - get_nodes=get_clients_at_depth1, - name="local_1msigmoid", -) -register_stabilize(local_1msigmoid) -register_specialize(local_1msigmoid) - - -log1pmexp_to_log1mexp = PatternSub( - (log1p, (neg, (exp, "x"))), - (log1mexp, "x"), - allow_multiple_clients=True, -) -register_stabilize(log1pmexp_to_log1mexp, name="log1pmexp_to_log1mexp") - - -# log(sigmoid(x) / (1 - sigmoid(x))) -> x -# i.e logit(sigmoid(x)) -> x -local_logit_sigmoid = PatternSub( - (log, (true_div, (sigmoid, "x"), (sub, 1, (sigmoid, "x")))), - "x", - tracks=[sigmoid], - get_nodes=get_clients_at_depth2, - allow_multiple_clients=True, - name="local_logit_sigmoid", +warnings.warn( + "The module `aesara.tensor.math_opt` is deprecated; use `aesara.tensor.rewriting.math` instead.", + DeprecationWarning, + stacklevel=2, ) -register_canonicalize(local_logit_sigmoid) -register_specialize(local_logit_sigmoid) - - -# sigmoid(log(x / (1-x)) -> x -# i.e., sigmoid(logit(x)) -> x -local_sigmoid_logit = PatternSub( - (sigmoid, (log, (true_div, "x", (sub, 1, "x")))), - "x", - allow_multiple_clients=True, - name="local_sigmoid_logit", -) -register_canonicalize(local_sigmoid_logit) -register_specialize(local_sigmoid_logit) - -@register_canonicalize -@register_useless -@local_optimizer([_conj]) -def local_useless_conj(fgraph, node): - r"""Remove `conj` `Op`\s applied to non-imaginary variable types.""" - x = node.inputs[0] - if x.type.dtype not in complex_dtypes: - return [x] +from aesara.tensor.rewriting.math import * # noqa: F401 E402 F403 diff --git a/aesara/tensor/nnet/__init__.py b/aesara/tensor/nnet/__init__.py index 57d7c8f186..a1e9a9d4bf 100644 --- a/aesara/tensor/nnet/__init__.py +++ b/aesara/tensor/nnet/__init__.py @@ -1,6 +1,6 @@ import warnings -import aesara.tensor.nnet.opt +import aesara.tensor.nnet.rewriting from aesara.tensor.nnet.abstract_conv import ( abstract_conv2d, conv2d, diff --git a/aesara/tensor/nnet/abstract_conv.py b/aesara/tensor/nnet/abstract_conv.py index 885b24898b..dbfc0b7b69 100644 --- a/aesara/tensor/nnet/abstract_conv.py +++ b/aesara/tensor/nnet/abstract_conv.py @@ -2763,13 +2763,9 @@ def grad(self, inp, grads): class AbstractConv_gradWeights(BaseAbstractConv): - """Gradient wrt. filters for `AbstractConv`. - Refer to :func:`BaseAbstractConv ` - for a more detailed documentation. + """Gradient with respect to filters for `AbstractConv`. - :note: You will not want to use this directly, but rely on - Aesara's automatic differentiation or graph optimization to - use it as needed. + Refer to :class:`BaseAbstractConv` for more detailed documentation. """ @@ -2991,13 +2987,9 @@ def infer_shape(self, fgraph, node, input_shapes): class AbstractConv2d_gradWeights(AbstractConv_gradWeights): - """Gradient wrt. filters for `AbstractConv2d`. - Refer to :func:`BaseAbstractConv ` - for a more detailed documentation. + """Gradient with respect to filters for `AbstractConv2d`. - :note: You will not want to use this directly, but rely on - Aesara's automatic differentiation or graph optimization to - use it as needed. + Refer to :class:`BaseAbstractConv` for more detailed documentation. """ @@ -3058,13 +3050,9 @@ def grad(self, inp, grads): class AbstractConv3d_gradWeights(AbstractConv_gradWeights): - """Gradient wrt. filters for `AbstractConv3d`. - Refer to :func:`BaseAbstractConv ` - for a more detailed documentation. + """Gradient with respect to filters for `AbstractConv3d`. - :note: You will not want to use this directly, but rely on - Aesara's automatic differentiation or graph optimization to - use it as needed. + Refer to :class:`BaseAbstractConv` for more detailed documentation. """ @@ -3121,13 +3109,9 @@ def grad(self, inp, grads): class AbstractConv_gradInputs(BaseAbstractConv): - """Gradient wrt. inputs for `AbstractConv`. - Refer to :func:`BaseAbstractConv ` - for a more detailed documentation. + """Gradient with respect to inputs for `AbstractConv`. - :note: You will not want to use this directly, but rely on - Aesara's automatic differentiation or graph optimization to - use it as needed. + Refer to :class:`BaseAbstractConv` for more detailed documentation. """ @@ -3373,13 +3357,9 @@ def infer_shape(self, fgraph, node, input_shapes): class AbstractConv2d_gradInputs(AbstractConv_gradInputs): - """Gradient wrt. inputs for `AbstractConv2d`. - Refer to :func:`BaseAbstractConv ` - for a more detailed documentation. + """Gradient with respect to inputs for `AbstractConv2d`. - :note: You will not want to use this directly, but rely on - Aesara's automatic differentiation or graph optimization to - use it as needed. + Refer to :class:`BaseAbstractConv` for more detailed documentation. """ @@ -3440,13 +3420,9 @@ def grad(self, inp, grads): class AbstractConv3d_gradInputs(AbstractConv_gradInputs): - """Gradient wrt. inputs for `AbstractConv3d`. - Refer to :func:`BaseAbstractConv ` - for a more detailed documentation. + """Gradient with respect to inputs for `AbstractConv3d`. - :note: You will not want to use this directly, but rely on - Aesara's automatic differentiation or graph optimization to - use it as needed. + Refer to :class:`BaseAbstractConv` for more detailed documentation. """ diff --git a/aesara/tensor/nnet/basic.py b/aesara/tensor/nnet/basic.py index 3cb3261da6..af1e3d0bd1 100644 --- a/aesara/tensor/nnet/basic.py +++ b/aesara/tensor/nnet/basic.py @@ -18,17 +18,12 @@ from aesara.gradient import DisconnectedType, grad_not_implemented from aesara.graph.basic import Apply from aesara.graph.op import Op -from aesara.graph.opt import copy_stack_trace, local_optimizer, optimizer +from aesara.graph.rewriting.basic import copy_stack_trace, graph_rewriter, node_rewriter from aesara.link.c.op import COp from aesara.raise_op import Assert from aesara.scalar import UnaryScalarOp from aesara.tensor import basic as at from aesara.tensor.basic import ARange, as_tensor_variable -from aesara.tensor.basic_opt import ( - register_canonicalize, - register_specialize, - register_stabilize, -) from aesara.tensor.elemwise import DimShuffle, Elemwise from aesara.tensor.exceptions import NotScalarConstantError from aesara.tensor.extra_ops import Unique @@ -50,8 +45,13 @@ ) from aesara.tensor.math import sum as at_sum from aesara.tensor.math import tanh, tensordot, true_div -from aesara.tensor.math_opt import local_mul_canonizer from aesara.tensor.nnet.blocksparse import sparse_block_dot +from aesara.tensor.rewriting.basic import ( + register_canonicalize, + register_specialize, + register_stabilize, +) +from aesara.tensor.rewriting.math import local_mul_canonizer from aesara.tensor.shape import Shape, shape_padleft from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedSubtensor from aesara.tensor.type import ( @@ -1046,7 +1046,7 @@ def c_code_cache_version(): # This is not registered in stabilize, as it cause some crossentropy # optimization to not be inserted. @register_specialize("stabilize", "fast_compile") -@local_optimizer([Elemwise]) +@node_rewriter([Elemwise]) def local_logsoftmax(fgraph, node): """ Detect Log(Softmax(x)) and replace it with LogSoftmax(x) @@ -1071,7 +1071,7 @@ def local_logsoftmax(fgraph, node): # This is not registered in stabilize, as it cause some crossentropy # optimization to not be inserted. @register_specialize("stabilize", "fast_compile") -@local_optimizer([SoftmaxGrad]) +@node_rewriter([SoftmaxGrad]) def local_logsoftmax_grad(fgraph, node): """ Detect Log(Softmax(x))'s grad and replace it with LogSoftmax(x)'s grad @@ -1150,7 +1150,7 @@ def logsoftmax(c, axis=UNSET_AXIS): @register_specialize("fast_compile") -@local_optimizer([softmax_legacy]) +@node_rewriter([softmax_legacy]) def local_softmax_with_bias(fgraph, node): """ Try to turn softmax(sum_of_stuff) -> softmax_w_bias(matrix, bias). @@ -1847,18 +1847,8 @@ def grad(self, inp, grads): @register_stabilize("fast_compile") @register_specialize("fast_compile") -@optimizer +@graph_rewriter def crossentropy_to_crossentropy_with_softmax_with_bias(fgraph): - """ - This is a stabilization optimization. - - Notes - ----- - Not a local optimization because we are replacing outputs - from several nodes at once. - - """ - def search_make_one_sub(): for node in fgraph.toposort(): if node.op == crossentropy_categorical_1hot: @@ -1884,21 +1874,16 @@ def search_make_one_sub(): return -@optimizer +@graph_rewriter def crossentropy_to_crossentropy_with_softmax(fgraph): """ - This is a stabilization optimization that is more general than - crossentropy_to_crossentropy_with_softmax_with_bias. - - It must be executed after local_softmax_with_bias optimization in - specialize. - - TODO : This is a stabilization optimization! How to make this more cleanly? + This is a stabilization rewrite that is more general than + `crossentropy_to_crossentropy_with_softmax_with_bias`. Notes ----- - Not a local optimization because we are replacing outputs from several - nodes at once. + It must be executed after `local_softmax_with_bias` during the + specialization passes. """ @@ -1954,7 +1939,7 @@ def search_make_one_sub(): @register_specialize( "fast_compile", "local_crossentropy_to_crossentropy_with_softmax_grad" ) # old name -@local_optimizer([softmax_grad_legacy]) +@node_rewriter([softmax_grad_legacy]) def local_softmax_grad_to_crossentropy_with_softmax_grad(fgraph, node): if node.op == softmax_grad_legacy and node.inputs[1].ndim == 2: g_coding_dist, coding_dist = node.inputs @@ -1971,7 +1956,7 @@ def local_softmax_grad_to_crossentropy_with_softmax_grad(fgraph, node): @register_specialize("fast_compile") -@local_optimizer([MaxAndArgmax]) +@node_rewriter([MaxAndArgmax]) def local_argmax_pushdown(fgraph, node): if ( isinstance(node.op, MaxAndArgmax) @@ -2060,7 +2045,7 @@ def _is_const(z, val, approx=False): @register_specialize("fast_compile") -@local_optimizer([AdvancedSubtensor, log]) +@node_rewriter([AdvancedSubtensor, log]) def local_advanced_indexing_crossentropy_onehot(fgraph, node): log_op = None sm = None @@ -2108,7 +2093,7 @@ def local_advanced_indexing_crossentropy_onehot(fgraph, node): @register_specialize("fast_compile") -@local_optimizer([softmax_grad_legacy]) +@node_rewriter([softmax_grad_legacy]) def local_advanced_indexing_crossentropy_onehot_grad(fgraph, node): if not (node.op == softmax_grad_legacy and node.inputs[1].ndim == 2): return @@ -2323,7 +2308,7 @@ def local_advanced_indexing_crossentropy_onehot_grad(fgraph, node): @register_specialize("fast_compile") -@local_optimizer([softmax_with_bias]) +@node_rewriter([softmax_with_bias]) def graph_merge_softmax_with_crossentropy_softmax(fgraph, node): if node.op == softmax_with_bias: x, b = node.inputs @@ -2340,7 +2325,7 @@ def graph_merge_softmax_with_crossentropy_softmax(fgraph, node): @register_specialize @register_stabilize @register_canonicalize -@local_optimizer([CrossentropySoftmax1HotWithBiasDx]) +@node_rewriter([CrossentropySoftmax1HotWithBiasDx]) def local_useless_crossentropy_softmax_1hot_with_bias_dx_alloc(fgraph, node): """ Replace a CrossentropySoftmax1HotWithBiasDx op, whose incoming gradient is diff --git a/aesara/tensor/nnet/batchnorm.py b/aesara/tensor/nnet/batchnorm.py index 1f6bedc6af..2edf3675e4 100644 --- a/aesara/tensor/nnet/batchnorm.py +++ b/aesara/tensor/nnet/batchnorm.py @@ -4,14 +4,14 @@ from aesara.configdefaults import config from aesara.graph.basic import Apply from aesara.graph.op import Op -from aesara.graph.opt import copy_stack_trace, local_optimizer +from aesara.graph.rewriting.basic import copy_stack_trace, node_rewriter from aesara.scalar import Composite, add, as_common_dtype, mul, sub, true_div from aesara.tensor import basic as at from aesara.tensor.basic import as_tensor_variable -from aesara.tensor.basic_opt import register_specialize_device from aesara.tensor.elemwise import Elemwise from aesara.tensor.math import mean, prod, reciprocal, sqrt from aesara.tensor.math import sum as at_sum +from aesara.tensor.rewriting.basic import register_specialize_device from aesara.tensor.shape import specify_broadcastable from aesara.tensor.type import TensorType @@ -778,7 +778,7 @@ def perform(self, node, inputs, output_storage): output_storage[2][0] = g_wrt_bias -@local_optimizer([AbstractBatchNormTrain]) +@node_rewriter([AbstractBatchNormTrain]) def local_abstract_batch_norm_train(fgraph, node): if not isinstance(node.op, AbstractBatchNormTrain): return None @@ -832,7 +832,7 @@ def local_abstract_batch_norm_train(fgraph, node): return results -@local_optimizer([AbstractBatchNormTrainGrad]) +@node_rewriter([AbstractBatchNormTrainGrad]) def local_abstract_batch_norm_train_grad(fgraph, node): if not isinstance(node.op, AbstractBatchNormTrainGrad): return None @@ -866,7 +866,7 @@ def local_abstract_batch_norm_train_grad(fgraph, node): return results -@local_optimizer([AbstractBatchNormInference]) +@node_rewriter([AbstractBatchNormInference]) def local_abstract_batch_norm_inference(fgraph, node): if not isinstance(node.op, AbstractBatchNormInference): return None @@ -896,7 +896,7 @@ def local_abstract_batch_norm_inference(fgraph, node): # Register Cpu Optimization -bn_groupopt = aesara.graph.optdb.LocalGroupDB() +bn_groupopt = aesara.graph.rewriting.db.LocalGroupDB() bn_groupopt.__name__ = "batchnorm_opts" register_specialize_device(bn_groupopt, "fast_compile", "fast_run") diff --git a/aesara/tensor/nnet/conv.py b/aesara/tensor/nnet/conv.py index 35bdb24919..0dbb0240c1 100644 --- a/aesara/tensor/nnet/conv.py +++ b/aesara/tensor/nnet/conv.py @@ -46,12 +46,13 @@ def conv2d( subsample=(1, 1), **kargs, ): - """ - Deprecated, old conv2d interface. - This function will build the symbolic graph for convolving a stack of - input images with a set of filters. The implementation is modelled after - Convolutional Neural Networks (CNN). It is simply a wrapper to the ConvOp - but provides a much cleaner interface. + """Build the symbolic graph for convolving a stack of input images with a set of filters. + + The implementation is modelled after Convolutional Neural Networks + (CNN). It is simply a wrapper to the `ConvOp` but provides a much cleaner + interface. + + This is deprecated. Parameters ---------- @@ -402,8 +403,7 @@ def getOutputShape(inshp, kshp, stride=(1, 1), mode="valid"): # with s=1 for mode=='full' and s=-1 for mode=='valid'. # To support symbolic shapes, we express this with integer arithmetic. warnings.warn( - "The method `getOutputShape` is deprecated use" - "`get_conv_output_shape` instead.", + "`getOutputShape` is deprecated; use `get_conv_output_shape` instead.", DeprecationWarning, stacklevel=2, ) diff --git a/aesara/tensor/nnet/conv3d2d.py b/aesara/tensor/nnet/conv3d2d.py index 4ba007ee85..044211d6a0 100644 --- a/aesara/tensor/nnet/conv3d2d.py +++ b/aesara/tensor/nnet/conv3d2d.py @@ -3,7 +3,11 @@ from aesara.gradient import DisconnectedType from aesara.graph.basic import Apply from aesara.graph.op import Op -from aesara.graph.opt import TopoOptimizer, copy_stack_trace, local_optimizer +from aesara.graph.rewriting.basic import ( + WalkingGraphRewriter, + copy_stack_trace, + node_rewriter, +) def get_diagonal_subtensor_view(x, i0, i1): @@ -102,7 +106,10 @@ def __init__(self, inplace=False): def make_node(self, x, i0, i1): _i0 = at.as_tensor_variable(i0) _i1 = at.as_tensor_variable(i1) - return Apply(self, [x, _i0, _i1], [x.type()]) + # TODO: We could produce a more precise static shape output type + type_shape = (1 if shape == 1 else None for shape in x.type.shape) + out_type = at.TensorType(x.type.dtype, shape=type_shape) + return Apply(self, [x, _i0, _i1], [out_type()]) def perform(self, node, inputs, output_storage): xview = get_diagonal_subtensor_view(*inputs) @@ -296,7 +303,7 @@ def conv3d( return out_5d -@local_optimizer([DiagonalSubtensor, IncDiagonalSubtensor]) +@node_rewriter([DiagonalSubtensor, IncDiagonalSubtensor]) def local_inplace_DiagonalSubtensor(fgraph, node): """Also work for IncDiagonalSubtensor.""" if ( @@ -312,8 +319,9 @@ def local_inplace_DiagonalSubtensor(fgraph, node): aesara.compile.optdb.register( "local_inplace_DiagonalSubtensor", - TopoOptimizer( - local_inplace_DiagonalSubtensor, failure_callback=TopoOptimizer.warn_inplace + WalkingGraphRewriter( + local_inplace_DiagonalSubtensor, + failure_callback=WalkingGraphRewriter.warn_inplace, ), "fast_run", "inplace", diff --git a/aesara/tensor/nnet/ctc.py b/aesara/tensor/nnet/ctc.py index 3c7d523526..d3aebeb422 100644 --- a/aesara/tensor/nnet/ctc.py +++ b/aesara/tensor/nnet/ctc.py @@ -5,12 +5,12 @@ from aesara.configdefaults import config from aesara.gradient import grad_undefined from aesara.graph.basic import Apply -from aesara.graph.opt import local_optimizer +from aesara.graph.rewriting.basic import node_rewriter from aesara.link.c.cmodule import GCC_compiler from aesara.link.c.op import ExternalCOp, OpenMPOp -from aesara.tensor.basic_opt import register_canonicalize from aesara.tensor.blas import batched_dot from aesara.tensor.extra_ops import cpu_contiguous +from aesara.tensor.rewriting.basic import register_canonicalize from aesara.tensor.type import ftensor3, fvector @@ -249,7 +249,7 @@ def ctc(activations, labels, input_lengths): # Disable gradient computation if not needed @register_canonicalize("fast_compile") -@local_optimizer([ConnectionistTemporalClassification]) +@node_rewriter([ConnectionistTemporalClassification]) def local_ctc_no_grad(fgraph, node): if isinstance(node.op, ConnectionistTemporalClassification): if len(node.outputs) > 1: diff --git a/aesara/tensor/nnet/opt.py b/aesara/tensor/nnet/opt.py index ccd314967a..0ff97c5217 100644 --- a/aesara/tensor/nnet/opt.py +++ b/aesara/tensor/nnet/opt.py @@ -1,600 +1,10 @@ -""" -Optimizations addressing the ops in nnet root directory -""" - -import aesara -from aesara import compile -from aesara.compile import optdb -from aesara.configdefaults import config -from aesara.graph.opt import ( - LocalMetaOptimizerSkipAssertionError, - TopoOptimizer, - copy_stack_trace, - in2out, - local_optimizer, -) -from aesara.tensor.basic_opt import register_specialize_device -from aesara.tensor.nnet.abstract_conv import ( - AbstractConv2d, - AbstractConv2d_gradInputs, - AbstractConv2d_gradWeights, - AbstractConv3d, - AbstractConv3d_gradInputs, - AbstractConv3d_gradWeights, - get_conv_output_shape, -) -from aesara.tensor.nnet.blocksparse import ( - SparseBlockGemv, - SparseBlockOuter, - sparse_block_gemv_inplace, - sparse_block_outer_inplace, -) - -# Cpu implementation -from aesara.tensor.nnet.conv import ConvOp, conv2d -from aesara.tensor.nnet.corr import CorrMM, CorrMM_gradInputs, CorrMM_gradWeights -from aesara.tensor.nnet.corr3d import Corr3dMM, Corr3dMMGradInputs, Corr3dMMGradWeights -from aesara.tensor.type import TensorType - - -@local_optimizer([SparseBlockGemv], inplace=True) -def local_inplace_sparse_block_gemv(fgraph, node): - """ - SparseBlockGemv(inplace=False) -> SparseBlockGemv(inplace=True) - """ - if isinstance(node.op, SparseBlockGemv) and not node.op.inplace: - new_node = sparse_block_gemv_inplace(*node.inputs) - copy_stack_trace(node.outputs[0], new_node) - return [new_node] - return False - - -compile.optdb.register( - "local_inplace_sparse_block_gemv", - TopoOptimizer( - local_inplace_sparse_block_gemv, failure_callback=TopoOptimizer.warn_inplace - ), - "fast_run", - "inplace", - position=60, -) # DEBUG - - -@local_optimizer([SparseBlockOuter], inplace=True) -def local_inplace_sparse_block_outer(fgraph, node): - """ - SparseBlockOuter(inplace=False) -> SparseBlockOuter(inplace=True) - """ - if isinstance(node.op, SparseBlockOuter) and not node.op.inplace: - new_node = sparse_block_outer_inplace(*node.inputs) - copy_stack_trace(node.outputs[0], new_node) - return [new_node] - return False - - -compile.optdb.register( - "local_inplace_sparse_block_outer", - TopoOptimizer( - local_inplace_sparse_block_outer, - failure_callback=TopoOptimizer.warn_inplace, - ), - "fast_run", - "inplace", - position=60, -) # DEBUG - - -# Conv opts -@local_optimizer([AbstractConv2d]) -def local_abstractconv_gemm(fgraph, node): - # If config.blas__ldflags is empty, Aesara will use - # a NumPy C implementation of [sd]gemm_. - if config.cxx == "" or node.inputs[0].dtype == "float16": - return - if not isinstance(node.op, AbstractConv2d): - return None - img, kern = node.inputs - if not isinstance(img.type, TensorType) or not isinstance(kern.type, TensorType): - return None - - # need to flip the kernel if necessary - if node.op.filter_flip: - flip = (slice(None),) * (kern.ndim - 2) + (slice(None, None, -1),) * 2 - kern = kern[flip] - rval = CorrMM( - border_mode=node.op.border_mode, - subsample=node.op.subsample, - filter_dilation=node.op.filter_dilation, - num_groups=node.op.num_groups, - unshared=node.op.unshared, - )(img, kern) - copy_stack_trace(node.outputs[0], rval) - - return [rval] - - -@local_optimizer([AbstractConv3d]) -def local_abstractconv3d_gemm(fgraph, node): - # If config.blas__ldflags is empty, Aesara will use - # a NumPy C implementation of [sd]gemm_. - if config.cxx == "" or node.inputs[0].dtype == "float16": - return - if not isinstance(node.op, AbstractConv3d): - return None - img, kern = node.inputs - if not isinstance(img.type, TensorType) or not isinstance(kern.type, TensorType): - return None - - # need to flip the kernel if necessary - if node.op.filter_flip: - kern = kern[:, :, ::-1, ::-1, ::-1] - rval = Corr3dMM( - border_mode=node.op.border_mode, - subsample=node.op.subsample, - filter_dilation=node.op.filter_dilation, - num_groups=node.op.num_groups, - )(img, kern) - copy_stack_trace(node.outputs[0], rval) - - return [rval] - - -@local_optimizer([AbstractConv2d_gradWeights]) -def local_abstractconv_gradweight_gemm(fgraph, node): - # If config.blas__ldflags is empty, Aesara will use - # a NumPy C implementation of [sd]gemm_. - if config.cxx == "" or node.inputs[0].dtype == "float16": - return - if not isinstance(node.op, AbstractConv2d_gradWeights): - return None - img, topgrad, shape = node.inputs - if not isinstance(img.type, TensorType) or not isinstance(topgrad.type, TensorType): - return None - - rval = CorrMM_gradWeights( - border_mode=node.op.border_mode, - subsample=node.op.subsample, - filter_dilation=node.op.filter_dilation, - num_groups=node.op.num_groups, - unshared=node.op.unshared, - )(img, topgrad, shape) - copy_stack_trace(node.outputs[0], rval) - - # need to flip the kernel if necessary - if node.op.filter_flip: - flip = (slice(None),) * (rval.ndim - 2) + (slice(None, None, -1),) * 2 - rval = rval[flip] - copy_stack_trace(node.outputs[0], rval) - - return [rval] - - -@local_optimizer([AbstractConv3d_gradWeights]) -def local_abstractconv3d_gradweight_gemm(fgraph, node): - # If config.blas__ldflags is empty, Aesara will use - # a NumPy C implementation of [sd]gemm_. - if config.cxx == "" or node.inputs[0].dtype == "float16": - return - if not isinstance(node.op, AbstractConv3d_gradWeights): - return None - img, topgrad, shape = node.inputs - if not isinstance(img.type, TensorType) or not isinstance(topgrad.type, TensorType): - return None - - rval = Corr3dMMGradWeights( - border_mode=node.op.border_mode, - subsample=node.op.subsample, - filter_dilation=node.op.filter_dilation, - num_groups=node.op.num_groups, - )(img, topgrad, shape) - copy_stack_trace(node.outputs[0], rval) - - # need to flip the kernel if necessary - if node.op.filter_flip: - rval = rval[:, :, ::-1, ::-1, ::-1] - copy_stack_trace(node.outputs[0], rval) - - return [rval] - - -@local_optimizer([AbstractConv2d_gradInputs]) -def local_abstractconv_gradinputs_gemm(fgraph, node): - # If config.blas__ldflags is empty, Aesara will use - # a NumPy C implementation of [sd]gemm_. - if config.cxx == "" or node.inputs[0].dtype == "float16": - return - if not isinstance(node.op, AbstractConv2d_gradInputs): - return None - kern, topgrad, shape = node.inputs - if not isinstance(kern.type, TensorType) or not isinstance( - topgrad.type, TensorType - ): - return None - - # need to flip the kernel if necessary - if node.op.filter_flip: - flip = (slice(None),) * (kern.ndim - 2) + (slice(None, None, -1),) * 2 - kern = kern[flip] - rval = CorrMM_gradInputs( - border_mode=node.op.border_mode, - subsample=node.op.subsample, - filter_dilation=node.op.filter_dilation, - num_groups=node.op.num_groups, - unshared=node.op.unshared, - )(kern, topgrad, shape) - copy_stack_trace(node.outputs[0], rval) - - return [rval] - - -@local_optimizer([AbstractConv3d_gradInputs]) -def local_abstractconv3d_gradinputs_gemm(fgraph, node): - # If config.blas__ldflags is empty, Aesara will use - # a NumPy C implementation of [sd]gemm_. - if config.cxx == "" or node.inputs[0].dtype == "float16": - return - if not isinstance(node.op, AbstractConv3d_gradInputs): - return None - kern, topgrad, shape = node.inputs - if not isinstance(kern.type, TensorType) or not isinstance( - topgrad.type, TensorType - ): - return None - - # need to flip the kernel if necessary - if node.op.filter_flip: - kern = kern[:, :, ::-1, ::-1, ::-1] - rval = Corr3dMMGradInputs( - border_mode=node.op.border_mode, - subsample=node.op.subsample, - filter_dilation=node.op.filter_dilation, - num_groups=node.op.num_groups, - )(kern, topgrad, shape) - copy_stack_trace(node.outputs[0], rval) - - return [rval] - - -@local_optimizer([AbstractConv2d]) -def local_conv2d_cpu(fgraph, node): - - if not isinstance(node.op, AbstractConv2d) or node.inputs[0].dtype == "float16": - return None - - img, kern = node.inputs - if not isinstance(img.type, TensorType) or not isinstance(kern.type, TensorType): - return None - if node.op.border_mode not in ("full", "valid"): - return None - if not node.op.filter_flip: - # Not tested yet - return None - if node.op.num_groups > 1 or node.op.unshared: - return None - if node.op.filter_dilation != (1, 1): - return None - - rval = conv2d( - img, - kern, - node.op.imshp, - node.op.kshp, - border_mode=node.op.border_mode, - subsample=node.op.subsample, - ) - - copy_stack_trace(node.outputs[0], rval) - return [rval] +import warnings -@local_optimizer([AbstractConv2d_gradWeights]) -def local_conv2d_gradweight_cpu(fgraph, node): - if ( - not isinstance(node.op, AbstractConv2d_gradWeights) - or node.inputs[0].dtype == "float16" - ): - return None - - img, topgrad, shape = node.inputs - - if not isinstance(img.type, TensorType) or not isinstance(topgrad.type, TensorType): - return None - if node.op.border_mode not in ("full", "valid"): - return None - if not node.op.filter_flip: - # Not tested yet - return - if node.op.num_groups > 1 or node.op.unshared: - return None - - if node.op.border_mode == "valid" and (node.op.subsample != (1, 1)): - return None - - dx, dy = node.op.subsample - if dx not in (1, 2) or dy not in (1, 2): - # Not implemented in the gradient of ConvOp - return None - - if node.op.imshp is None: - op_imshp = (None, None, None, None) - else: - op_imshp = node.op.imshp - - if node.op.kshp is None: - op_kshp = (None, None, None, None) - else: - op_kshp = node.op.kshp - - if None in op_imshp or None in op_kshp: - if (dx, dy) != (1, 1): - # We cannot infer the shapes - return None - - # Determine gradient on kernels - assert len(op_imshp) == 4 and len(op_kshp) == 4 - - outshp = get_conv_output_shape( - op_imshp, - op_kshp, - node.op.border_mode, - node.op.subsample, - node.op.filter_dilation, - )[2:] - fulloutshp = get_conv_output_shape(op_imshp, op_kshp, node.op.border_mode, (1, 1))[ - 2: - ] - - newimg = img.dimshuffle((1, 0, 2, 3)) - newtopgrad = topgrad.dimshuffle((1, 0, 2, 3)) - - if node.op.border_mode == "valid": - (img, filters) = (newimg, newtopgrad) - kshp_logical = fulloutshp - kshp_logical_top_aligned = False - imshp_logical = None - (bsize, nkern) = (op_imshp[1], op_kshp[0]) - imshp = (op_imshp[0], op_imshp[2], op_imshp[3]) - kshp = outshp - elif node.op.border_mode == "full": - (img, filters) = (newtopgrad, newimg) - kshp_logical = None - kshp_logical_top_aligned = True - imshp_logical = (op_imshp[0], fulloutshp[0], fulloutshp[1]) - (bsize, nkern) = (op_kshp[0], op_imshp[1]) - imshp = (op_imshp[0], outshp[0], outshp[1]) - kshp = op_imshp[2:] - else: - raise NotImplementedError("Only [full,valid] modes are currently supported.") - - # Flip the kernels - filters = filters[:, :, ::-1, ::-1] - - dw = ConvOp( - imshp, - kshp, - nkern, - bsize, - 1, - 1, - output_mode="valid", - unroll_batch=None, - unroll_kern=None, - unroll_patch=None, - imshp_logical=imshp_logical, - kshp_logical=kshp_logical, - kshp_logical_top_aligned=kshp_logical_top_aligned, - direction_hint="bprop weights", - ) - res = dw(img, filters) - copy_stack_trace(node.outputs[0], res) - - if node.op.border_mode == "valid": - res = res.dimshuffle((1, 0, 2, 3)) - res = res[:, :, ::-1, ::-1] - copy_stack_trace(node.outputs[0], res) - - return [res] - - -@local_optimizer([AbstractConv2d_gradInputs]) -def local_conv2d_gradinputs_cpu(fgraph, node): - if ( - not isinstance(node.op, AbstractConv2d_gradInputs) - or node.inputs[0].dtype == "float16" - ): - return None - - kern, topgrad, shape = node.inputs - - if not isinstance(kern.type, TensorType) or not isinstance( - topgrad.type, TensorType - ): - return None - if node.op.border_mode not in ("full", "valid"): - return None - if not node.op.filter_flip: - # Not tested yet - return None - if node.op.num_groups > 1 or node.op.unshared: - return None - - # Conv 3d implementation, needed when subsample > 2 - if node.op.border_mode == "valid" and node.op.subsample != (1, 1): - # The op don't support that anymore. - return False - - # Conv2d Implementation - dx, dy = node.op.subsample - if dx not in (1, 2) or dy not in (1, 2): - # Not implemented in the gradient of ConvOp - return None - - if node.op.imshp is None: - op_imshp = (None, None, None, None) - else: - op_imshp = node.op.imshp - - if node.op.kshp is None: - op_kshp = (None, None, None, None) - else: - op_kshp = node.op.kshp - - if None in op_imshp or None in op_kshp: - if (dx, dy) != (1, 1): - return None - - mode = "valid" - if node.op.border_mode != "full": - mode = "full" - filters = kern.dimshuffle((1, 0, 2, 3)) - filters = filters[:, :, ::-1, ::-1] - - outshp = get_conv_output_shape( - op_imshp, - op_kshp, - node.op.border_mode, - node.op.subsample, - node.op.filter_dilation, - )[2:] - fulloutshp = get_conv_output_shape(op_imshp, op_kshp, node.op.border_mode, (1, 1))[ - 2: - ] - - nkern = op_imshp[1] - imshp = (op_kshp[0], outshp[0], outshp[1]) - imshp_logical = (op_kshp[0], fulloutshp[0], fulloutshp[1]) - din = ConvOp( - imshp, - op_kshp[2:], - nkern, - op_imshp[0], - 1, - 1, - output_mode=mode, - unroll_batch=None, - unroll_kern=None, - unroll_patch=None, - imshp_logical=imshp_logical, - kshp_logical=None, - version=-1, - direction_hint="bprop inputs", - ) - din = din(topgrad, filters) - copy_stack_trace(node.outputs[0], din) - return [din] - - -# Register Cpu Optimization -conv_groupopt = aesara.graph.optdb.LocalGroupDB() -conv_groupopt.__name__ = "conv_opts" -register_specialize_device(conv_groupopt, "fast_compile", "fast_run") - -# GEMM-based convolution -# It can be disabled by excluding 'conv_gemm'. -conv_groupopt.register( - "local_abstractconv_gemm", - local_abstractconv_gemm, - "conv_gemm", - "fast_compile", - "fast_run", - position=30, -) -conv_groupopt.register( - "local_abstractconv_gradweight_gemm", - local_abstractconv_gradweight_gemm, - "conv_gemm", - "fast_compile", - "fast_run", - position=30, -) -conv_groupopt.register( - "local_abstractconv_gradinputs_gemm", - local_abstractconv_gradinputs_gemm, - "conv_gemm", - "fast_compile", - "fast_run", - position=30, -) -conv_groupopt.register( - "local_abstractconv3d_gemm", - local_abstractconv3d_gemm, - "conv_gemm", - "fast_compile", - "fast_run", - position=30, -) -conv_groupopt.register( - "local_abstractconv3d_gradweight_gemm", - local_abstractconv3d_gradweight_gemm, - "conv_gemm", - "fast_compile", - "fast_run", - position=30, -) -conv_groupopt.register( - "local_abstractconv3d_gradinputs_gemm", - local_abstractconv3d_gradinputs_gemm, - "conv_gemm", - "fast_compile", - "fast_run", - position=30, -) - -# Legacy convolution -conv_groupopt.register( - "local_conv2d_cpu", local_conv2d_cpu, "fast_compile", "fast_run", position=40 -) -conv_groupopt.register( - "local_conv2d_gradweight_cpu", - local_conv2d_gradweight_cpu, - "fast_compile", - "fast_run", - position=40, -) -conv_groupopt.register( - "local_conv2d_gradinputs_cpu", - local_conv2d_gradinputs_cpu, - "fast_compile", - "fast_run", - position=40, +warnings.warn( + "The module `aesara.tensor.nnet.opt` is deprecated; use `aesara.tensor.nnet.rewriting` instead.", + DeprecationWarning, + stacklevel=2, ) - -# Verify that no AbstractConv are present in the graph -@local_optimizer( - [ - AbstractConv2d, - AbstractConv2d_gradWeights, - AbstractConv2d_gradInputs, - AbstractConv3d, - AbstractConv3d_gradWeights, - AbstractConv3d_gradInputs, - ] -) -def local_abstractconv_check(fgraph, node): - if isinstance( - node.op, - ( - AbstractConv2d, - AbstractConv2d_gradWeights, - AbstractConv2d_gradInputs, - AbstractConv3d, - AbstractConv3d_gradWeights, - AbstractConv3d_gradInputs, - ), - ): - raise LocalMetaOptimizerSkipAssertionError( - f"{node.op.__class__.__name__} Aesara optimization failed: there is no implementation " - "available supporting the requested options. If on CPU, " - "do you have a BLAS library installed Aesara can link against? " - "On the CPU we do not support float16." - ) - - -optdb.register( - "AbstractConvCheck", - in2out(local_abstractconv_check, name="AbstractConvCheck"), - "fast_compile", - "fast_run", - position=48.7, -) +from aesara.tensor.nnet.rewriting import * # noqa: F401 E402 F403 diff --git a/aesara/tensor/nnet/rewriting.py b/aesara/tensor/nnet/rewriting.py new file mode 100644 index 0000000000..3a32e557c7 --- /dev/null +++ b/aesara/tensor/nnet/rewriting.py @@ -0,0 +1,601 @@ +""" +Optimizations addressing the ops in nnet root directory +""" + +import aesara +from aesara import compile +from aesara.compile import optdb +from aesara.configdefaults import config +from aesara.graph.rewriting.basic import ( + MetaNodeRewriterSkip, + WalkingGraphRewriter, + copy_stack_trace, + in2out, + node_rewriter, +) +from aesara.tensor.nnet.abstract_conv import ( + AbstractConv2d, + AbstractConv2d_gradInputs, + AbstractConv2d_gradWeights, + AbstractConv3d, + AbstractConv3d_gradInputs, + AbstractConv3d_gradWeights, + get_conv_output_shape, +) +from aesara.tensor.nnet.blocksparse import ( + SparseBlockGemv, + SparseBlockOuter, + sparse_block_gemv_inplace, + sparse_block_outer_inplace, +) + +# Cpu implementation +from aesara.tensor.nnet.conv import ConvOp, conv2d +from aesara.tensor.nnet.corr import CorrMM, CorrMM_gradInputs, CorrMM_gradWeights +from aesara.tensor.nnet.corr3d import Corr3dMM, Corr3dMMGradInputs, Corr3dMMGradWeights +from aesara.tensor.rewriting.basic import register_specialize_device +from aesara.tensor.type import TensorType + + +@node_rewriter([SparseBlockGemv], inplace=True) +def local_inplace_sparse_block_gemv(fgraph, node): + """ + SparseBlockGemv(inplace=False) -> SparseBlockGemv(inplace=True) + """ + if isinstance(node.op, SparseBlockGemv) and not node.op.inplace: + new_node = sparse_block_gemv_inplace(*node.inputs) + copy_stack_trace(node.outputs[0], new_node) + return [new_node] + return False + + +compile.optdb.register( + "local_inplace_sparse_block_gemv", + WalkingGraphRewriter( + local_inplace_sparse_block_gemv, + failure_callback=WalkingGraphRewriter.warn_inplace, + ), + "fast_run", + "inplace", + position=60, +) + + +@node_rewriter([SparseBlockOuter], inplace=True) +def local_inplace_sparse_block_outer(fgraph, node): + """ + SparseBlockOuter(inplace=False) -> SparseBlockOuter(inplace=True) + """ + if isinstance(node.op, SparseBlockOuter) and not node.op.inplace: + new_node = sparse_block_outer_inplace(*node.inputs) + copy_stack_trace(node.outputs[0], new_node) + return [new_node] + return False + + +compile.optdb.register( + "local_inplace_sparse_block_outer", + WalkingGraphRewriter( + local_inplace_sparse_block_outer, + failure_callback=WalkingGraphRewriter.warn_inplace, + ), + "fast_run", + "inplace", + position=60, +) + + +# Conv opts +@node_rewriter([AbstractConv2d]) +def local_abstractconv_gemm(fgraph, node): + # If config.blas__ldflags is empty, Aesara will use + # a NumPy C implementation of [sd]gemm_. + if config.cxx == "" or node.inputs[0].dtype == "float16": + return + if not isinstance(node.op, AbstractConv2d): + return None + img, kern = node.inputs + if not isinstance(img.type, TensorType) or not isinstance(kern.type, TensorType): + return None + + # need to flip the kernel if necessary + if node.op.filter_flip: + flip = (slice(None),) * (kern.ndim - 2) + (slice(None, None, -1),) * 2 + kern = kern[flip] + rval = CorrMM( + border_mode=node.op.border_mode, + subsample=node.op.subsample, + filter_dilation=node.op.filter_dilation, + num_groups=node.op.num_groups, + unshared=node.op.unshared, + )(img, kern) + copy_stack_trace(node.outputs[0], rval) + + return [rval] + + +@node_rewriter([AbstractConv3d]) +def local_abstractconv3d_gemm(fgraph, node): + # If config.blas__ldflags is empty, Aesara will use + # a NumPy C implementation of [sd]gemm_. + if config.cxx == "" or node.inputs[0].dtype == "float16": + return + if not isinstance(node.op, AbstractConv3d): + return None + img, kern = node.inputs + if not isinstance(img.type, TensorType) or not isinstance(kern.type, TensorType): + return None + + # need to flip the kernel if necessary + if node.op.filter_flip: + kern = kern[:, :, ::-1, ::-1, ::-1] + rval = Corr3dMM( + border_mode=node.op.border_mode, + subsample=node.op.subsample, + filter_dilation=node.op.filter_dilation, + num_groups=node.op.num_groups, + )(img, kern) + copy_stack_trace(node.outputs[0], rval) + + return [rval] + + +@node_rewriter([AbstractConv2d_gradWeights]) +def local_abstractconv_gradweight_gemm(fgraph, node): + # If config.blas__ldflags is empty, Aesara will use + # a NumPy C implementation of [sd]gemm_. + if config.cxx == "" or node.inputs[0].dtype == "float16": + return + if not isinstance(node.op, AbstractConv2d_gradWeights): + return None + img, topgrad, shape = node.inputs + if not isinstance(img.type, TensorType) or not isinstance(topgrad.type, TensorType): + return None + + rval = CorrMM_gradWeights( + border_mode=node.op.border_mode, + subsample=node.op.subsample, + filter_dilation=node.op.filter_dilation, + num_groups=node.op.num_groups, + unshared=node.op.unshared, + )(img, topgrad, shape) + copy_stack_trace(node.outputs[0], rval) + + # need to flip the kernel if necessary + if node.op.filter_flip: + flip = (slice(None),) * (rval.ndim - 2) + (slice(None, None, -1),) * 2 + rval = rval[flip] + copy_stack_trace(node.outputs[0], rval) + + return [rval] + + +@node_rewriter([AbstractConv3d_gradWeights]) +def local_abstractconv3d_gradweight_gemm(fgraph, node): + # If config.blas__ldflags is empty, Aesara will use + # a NumPy C implementation of [sd]gemm_. + if config.cxx == "" or node.inputs[0].dtype == "float16": + return + if not isinstance(node.op, AbstractConv3d_gradWeights): + return None + img, topgrad, shape = node.inputs + if not isinstance(img.type, TensorType) or not isinstance(topgrad.type, TensorType): + return None + + rval = Corr3dMMGradWeights( + border_mode=node.op.border_mode, + subsample=node.op.subsample, + filter_dilation=node.op.filter_dilation, + num_groups=node.op.num_groups, + )(img, topgrad, shape) + copy_stack_trace(node.outputs[0], rval) + + # need to flip the kernel if necessary + if node.op.filter_flip: + rval = rval[:, :, ::-1, ::-1, ::-1] + copy_stack_trace(node.outputs[0], rval) + + return [rval] + + +@node_rewriter([AbstractConv2d_gradInputs]) +def local_abstractconv_gradinputs_gemm(fgraph, node): + # If config.blas__ldflags is empty, Aesara will use + # a NumPy C implementation of [sd]gemm_. + if config.cxx == "" or node.inputs[0].dtype == "float16": + return + if not isinstance(node.op, AbstractConv2d_gradInputs): + return None + kern, topgrad, shape = node.inputs + if not isinstance(kern.type, TensorType) or not isinstance( + topgrad.type, TensorType + ): + return None + + # need to flip the kernel if necessary + if node.op.filter_flip: + flip = (slice(None),) * (kern.ndim - 2) + (slice(None, None, -1),) * 2 + kern = kern[flip] + rval = CorrMM_gradInputs( + border_mode=node.op.border_mode, + subsample=node.op.subsample, + filter_dilation=node.op.filter_dilation, + num_groups=node.op.num_groups, + unshared=node.op.unshared, + )(kern, topgrad, shape) + copy_stack_trace(node.outputs[0], rval) + + return [rval] + + +@node_rewriter([AbstractConv3d_gradInputs]) +def local_abstractconv3d_gradinputs_gemm(fgraph, node): + # If config.blas__ldflags is empty, Aesara will use + # a NumPy C implementation of [sd]gemm_. + if config.cxx == "" or node.inputs[0].dtype == "float16": + return + if not isinstance(node.op, AbstractConv3d_gradInputs): + return None + kern, topgrad, shape = node.inputs + if not isinstance(kern.type, TensorType) or not isinstance( + topgrad.type, TensorType + ): + return None + + # need to flip the kernel if necessary + if node.op.filter_flip: + kern = kern[:, :, ::-1, ::-1, ::-1] + rval = Corr3dMMGradInputs( + border_mode=node.op.border_mode, + subsample=node.op.subsample, + filter_dilation=node.op.filter_dilation, + num_groups=node.op.num_groups, + )(kern, topgrad, shape) + copy_stack_trace(node.outputs[0], rval) + + return [rval] + + +@node_rewriter([AbstractConv2d]) +def local_conv2d_cpu(fgraph, node): + + if not isinstance(node.op, AbstractConv2d) or node.inputs[0].dtype == "float16": + return None + + img, kern = node.inputs + if not isinstance(img.type, TensorType) or not isinstance(kern.type, TensorType): + return None + if node.op.border_mode not in ("full", "valid"): + return None + if not node.op.filter_flip: + # Not tested yet + return None + if node.op.num_groups > 1 or node.op.unshared: + return None + if node.op.filter_dilation != (1, 1): + return None + + rval = conv2d( + img, + kern, + node.op.imshp, + node.op.kshp, + border_mode=node.op.border_mode, + subsample=node.op.subsample, + ) + + copy_stack_trace(node.outputs[0], rval) + return [rval] + + +@node_rewriter([AbstractConv2d_gradWeights]) +def local_conv2d_gradweight_cpu(fgraph, node): + if ( + not isinstance(node.op, AbstractConv2d_gradWeights) + or node.inputs[0].dtype == "float16" + ): + return None + + img, topgrad, shape = node.inputs + + if not isinstance(img.type, TensorType) or not isinstance(topgrad.type, TensorType): + return None + if node.op.border_mode not in ("full", "valid"): + return None + if not node.op.filter_flip: + # Not tested yet + return + if node.op.num_groups > 1 or node.op.unshared: + return None + + if node.op.border_mode == "valid" and (node.op.subsample != (1, 1)): + return None + + dx, dy = node.op.subsample + if dx not in (1, 2) or dy not in (1, 2): + # Not implemented in the gradient of ConvOp + return None + + if node.op.imshp is None: + op_imshp = (None, None, None, None) + else: + op_imshp = node.op.imshp + + if node.op.kshp is None: + op_kshp = (None, None, None, None) + else: + op_kshp = node.op.kshp + + if None in op_imshp or None in op_kshp: + if (dx, dy) != (1, 1): + # We cannot infer the shapes + return None + + # Determine gradient on kernels + assert len(op_imshp) == 4 and len(op_kshp) == 4 + + outshp = get_conv_output_shape( + op_imshp, + op_kshp, + node.op.border_mode, + node.op.subsample, + node.op.filter_dilation, + )[2:] + fulloutshp = get_conv_output_shape(op_imshp, op_kshp, node.op.border_mode, (1, 1))[ + 2: + ] + + newimg = img.dimshuffle((1, 0, 2, 3)) + newtopgrad = topgrad.dimshuffle((1, 0, 2, 3)) + + if node.op.border_mode == "valid": + (img, filters) = (newimg, newtopgrad) + kshp_logical = fulloutshp + kshp_logical_top_aligned = False + imshp_logical = None + (bsize, nkern) = (op_imshp[1], op_kshp[0]) + imshp = (op_imshp[0], op_imshp[2], op_imshp[3]) + kshp = outshp + elif node.op.border_mode == "full": + (img, filters) = (newtopgrad, newimg) + kshp_logical = None + kshp_logical_top_aligned = True + imshp_logical = (op_imshp[0], fulloutshp[0], fulloutshp[1]) + (bsize, nkern) = (op_kshp[0], op_imshp[1]) + imshp = (op_imshp[0], outshp[0], outshp[1]) + kshp = op_imshp[2:] + else: + raise NotImplementedError("Only [full,valid] modes are currently supported.") + + # Flip the kernels + filters = filters[:, :, ::-1, ::-1] + + dw = ConvOp( + imshp, + kshp, + nkern, + bsize, + 1, + 1, + output_mode="valid", + unroll_batch=None, + unroll_kern=None, + unroll_patch=None, + imshp_logical=imshp_logical, + kshp_logical=kshp_logical, + kshp_logical_top_aligned=kshp_logical_top_aligned, + direction_hint="bprop weights", + ) + res = dw(img, filters) + copy_stack_trace(node.outputs[0], res) + + if node.op.border_mode == "valid": + res = res.dimshuffle((1, 0, 2, 3)) + res = res[:, :, ::-1, ::-1] + copy_stack_trace(node.outputs[0], res) + + return [res] + + +@node_rewriter([AbstractConv2d_gradInputs]) +def local_conv2d_gradinputs_cpu(fgraph, node): + if ( + not isinstance(node.op, AbstractConv2d_gradInputs) + or node.inputs[0].dtype == "float16" + ): + return None + + kern, topgrad, shape = node.inputs + + if not isinstance(kern.type, TensorType) or not isinstance( + topgrad.type, TensorType + ): + return None + if node.op.border_mode not in ("full", "valid"): + return None + if not node.op.filter_flip: + # Not tested yet + return None + if node.op.num_groups > 1 or node.op.unshared: + return None + + # Conv 3d implementation, needed when subsample > 2 + if node.op.border_mode == "valid" and node.op.subsample != (1, 1): + # The op don't support that anymore. + return False + + # Conv2d Implementation + dx, dy = node.op.subsample + if dx not in (1, 2) or dy not in (1, 2): + # Not implemented in the gradient of ConvOp + return None + + if node.op.imshp is None: + op_imshp = (None, None, None, None) + else: + op_imshp = node.op.imshp + + if node.op.kshp is None: + op_kshp = (None, None, None, None) + else: + op_kshp = node.op.kshp + + if None in op_imshp or None in op_kshp: + if (dx, dy) != (1, 1): + return None + + mode = "valid" + if node.op.border_mode != "full": + mode = "full" + filters = kern.dimshuffle((1, 0, 2, 3)) + filters = filters[:, :, ::-1, ::-1] + + outshp = get_conv_output_shape( + op_imshp, + op_kshp, + node.op.border_mode, + node.op.subsample, + node.op.filter_dilation, + )[2:] + fulloutshp = get_conv_output_shape(op_imshp, op_kshp, node.op.border_mode, (1, 1))[ + 2: + ] + + nkern = op_imshp[1] + imshp = (op_kshp[0], outshp[0], outshp[1]) + imshp_logical = (op_kshp[0], fulloutshp[0], fulloutshp[1]) + din = ConvOp( + imshp, + op_kshp[2:], + nkern, + op_imshp[0], + 1, + 1, + output_mode=mode, + unroll_batch=None, + unroll_kern=None, + unroll_patch=None, + imshp_logical=imshp_logical, + kshp_logical=None, + version=-1, + direction_hint="bprop inputs", + ) + din = din(topgrad, filters) + copy_stack_trace(node.outputs[0], din) + return [din] + + +# Register Cpu Optimization +conv_groupopt = aesara.graph.rewriting.db.LocalGroupDB() +conv_groupopt.__name__ = "conv_opts" +register_specialize_device(conv_groupopt, "fast_compile", "fast_run") + +# GEMM-based convolution +# It can be disabled by excluding 'conv_gemm'. +conv_groupopt.register( + "local_abstractconv_gemm", + local_abstractconv_gemm, + "conv_gemm", + "fast_compile", + "fast_run", + position=30, +) +conv_groupopt.register( + "local_abstractconv_gradweight_gemm", + local_abstractconv_gradweight_gemm, + "conv_gemm", + "fast_compile", + "fast_run", + position=30, +) +conv_groupopt.register( + "local_abstractconv_gradinputs_gemm", + local_abstractconv_gradinputs_gemm, + "conv_gemm", + "fast_compile", + "fast_run", + position=30, +) +conv_groupopt.register( + "local_abstractconv3d_gemm", + local_abstractconv3d_gemm, + "conv_gemm", + "fast_compile", + "fast_run", + position=30, +) +conv_groupopt.register( + "local_abstractconv3d_gradweight_gemm", + local_abstractconv3d_gradweight_gemm, + "conv_gemm", + "fast_compile", + "fast_run", + position=30, +) +conv_groupopt.register( + "local_abstractconv3d_gradinputs_gemm", + local_abstractconv3d_gradinputs_gemm, + "conv_gemm", + "fast_compile", + "fast_run", + position=30, +) + +# Legacy convolution +conv_groupopt.register( + "local_conv2d_cpu", local_conv2d_cpu, "fast_compile", "fast_run", position=40 +) +conv_groupopt.register( + "local_conv2d_gradweight_cpu", + local_conv2d_gradweight_cpu, + "fast_compile", + "fast_run", + position=40, +) +conv_groupopt.register( + "local_conv2d_gradinputs_cpu", + local_conv2d_gradinputs_cpu, + "fast_compile", + "fast_run", + position=40, +) + + +# Verify that no AbstractConv are present in the graph +@node_rewriter( + [ + AbstractConv2d, + AbstractConv2d_gradWeights, + AbstractConv2d_gradInputs, + AbstractConv3d, + AbstractConv3d_gradWeights, + AbstractConv3d_gradInputs, + ] +) +def local_abstractconv_check(fgraph, node): + if isinstance( + node.op, + ( + AbstractConv2d, + AbstractConv2d_gradWeights, + AbstractConv2d_gradInputs, + AbstractConv3d, + AbstractConv3d_gradWeights, + AbstractConv3d_gradInputs, + ), + ): + raise MetaNodeRewriterSkip( + f"{node.op.__class__.__name__} Aesara rewriting failed: there is no implementation " + "available supporting the requested options. If on CPU, " + "do you have a BLAS library installed Aesara can link against? " + "On the CPU we do not support float16." + ) + + +optdb.register( + "AbstractConvCheck", + in2out(local_abstractconv_check, name="AbstractConvCheck"), + "fast_compile", + "fast_run", + position=48.7, +) diff --git a/aesara/tensor/nnet/sigm.py b/aesara/tensor/nnet/sigm.py index 6c9f93244a..a351192741 100644 --- a/aesara/tensor/nnet/sigm.py +++ b/aesara/tensor/nnet/sigm.py @@ -1,6 +1,4 @@ """ -Ops and optimizations: sigmoid, softplus. - These functions implement special cases of exp and log to improve numerical stability. @@ -9,7 +7,7 @@ import aesara from aesara import printing from aesara import scalar as aes -from aesara.graph.opt import copy_stack_trace, local_optimizer +from aesara.graph.rewriting.basic import copy_stack_trace, node_rewriter from aesara.printing import pprint from aesara.scalar import sigmoid as scalar_sigmoid from aesara.scalar.math import Sigmoid @@ -98,8 +96,7 @@ def c_code_cache_version(): pprint.assign(ultra_fast_sigmoid, printing.FunctionPrinter(["ultra_fast_sigmoid"])) -# @opt.register_uncanonicalize -@local_optimizer(None) +@node_rewriter(None) def local_ultra_fast_sigmoid(fgraph, node): """ When enabled, change all sigmoid to ultra_fast_sigmoid. @@ -158,8 +155,7 @@ def hard_sigmoid(x): return x -# @opt.register_uncanonicalize -@local_optimizer([sigmoid]) +@node_rewriter([sigmoid]) def local_hard_sigmoid(fgraph, node): if isinstance(node.op, Elemwise) and node.op.scalar_op == scalar_sigmoid: out = hard_sigmoid(node.inputs[0]) diff --git a/aesara/tensor/opt_uncanonicalize.py b/aesara/tensor/opt_uncanonicalize.py index 3d798461fe..7994fc8457 100644 --- a/aesara/tensor/opt_uncanonicalize.py +++ b/aesara/tensor/opt_uncanonicalize.py @@ -1,256 +1,10 @@ -""" -This file implement specialization optimization that break the -canonization form of the graph. +import warnings -Currently there is problem with the order of optimization and the -definition of definition of canonized graph. -Right now there is a canonization optimization phase that try to make -all equivalent graph identical. This is not always the case, but it do -many of the basic stuff canonical. We need to extend the definition of -canonization to make this true more often. +warnings.warn( + "The module `aesara.tensor.opt_uncanonicalize` is deprecated; use `aesara.tensor.rewriting.uncanonicalize` instead.", + DeprecationWarning, + stacklevel=2, +) -The problem this file indent to fix in the future is that in the -"Equilibrium" specialization optimization phase, there is optimization -that request that the graph is canonical, some other request that this -is not true, and some other that break the canonicalization for some -optimization. As we can't control the order of those optimization, there -is case that some optimization requesting a canonical graph won't be -applied as optimization that break the canonicalization form of the -graph executed before. - -To fix this, we need to split the specialization phase into a phase -where optimization can't break the canonicalization form and one where -this is allowed. This is also needed for the stabilized optimization -phase, but as it happen before the specialization phase, this cause less -problem. - -Also, we should make the fgraph refuse optimization that break the -canonization of the graph in the optimizations phases where the graph is -supposed to be canonical. - -""" - -import logging - -from aesara import scalar as aes -from aesara.graph.opt import copy_stack_trace, local_optimizer -from aesara.tensor.basic import Alloc, alloc, constant -from aesara.tensor.basic_opt import register_uncanonicalize -from aesara.tensor.elemwise import CAReduce, DimShuffle -from aesara.tensor.math import Argmax, Max, MaxAndArgmax, Min, neg -from aesara.tensor.shape import Reshape, reshape -from aesara.tensor.subtensor import Subtensor - - -_logger = logging.getLogger("aesara.tensor.opt_uncanonicalize") - - -@register_uncanonicalize -@local_optimizer([MaxAndArgmax]) -def local_max_and_argmax(fgraph, node): - """ - If we don't use the argmax, change it to a max only. - """ - if isinstance(node.op, MaxAndArgmax): - axis = node.op.get_params(node) - if len(fgraph.clients[node.outputs[1]]) == 0: - new = Max(axis)(node.inputs[0]) - copy_stack_trace(node.outputs[0], new) - return [new, None] - - if len(fgraph.clients[node.outputs[0]]) == 0: - new = Argmax(axis)(node.inputs[0]) - copy_stack_trace(node.outputs[0], new) - return [None, new] - - -@register_uncanonicalize -@local_optimizer([neg]) -def local_max_to_min(fgraph, node): - """ - Change -(max(-x)) to min. - - This is tested in tensor/tests/test_basic.py:test_min_max. - - Notes - ----- - We don't need an opt that will do the reverse as by default - the interface put only MaxAndArgmax into the graph. - - """ - if node.op == neg and node.inputs[0].owner: - max = node.inputs[0] - if ( - max.owner - and isinstance(max.owner.op, CAReduce) - and max.owner.op.scalar_op == aes.scalar_maximum - ): - neg_node = max.owner.inputs[0] - if neg_node.owner and neg_node.owner.op == neg: - new = Min(max.owner.op.axis)(neg_node.owner.inputs[0]) - return [copy_stack_trace(node.outputs[0], new)] - - return False - - -@register_uncanonicalize -@local_optimizer([Alloc]) -def local_alloc_dimshuffle(fgraph, node): - """ - If a dimshuffle is inside an alloc and only adds dimension to the - left, remove it. - - Alloc(DimShuffle(x), ...) - > Alloc(x, ...) - """ - if isinstance(node.op, Alloc): - input_ = node.inputs[0] - if input_.owner and isinstance(input_.owner.op, DimShuffle): - # check if it only adds dimension to the left - new_order = input_.owner.op.new_order - expected_new_order = ("x",) * ( - input_.ndim - input_.owner.inputs[0].ndim - ) + tuple(range(input_.owner.inputs[0].ndim)) - if new_order != expected_new_order: - return False - return [alloc(input_.owner.inputs[0], *node.inputs[1:])] - return False - - -@register_uncanonicalize -@local_optimizer([Reshape]) -def local_reshape_dimshuffle(fgraph, node): - """ - If a dimshuffle is inside a reshape and does not change the order - of dimensions, remove it. - - Reshape(Dimshuffle(x), shp) -> Reshape(x, shp) - """ - if isinstance(node.op, Reshape): - input_ = node.inputs[0] - if input_.owner and isinstance(input_.owner.op, DimShuffle): - new_order = input_.owner.op.new_order - offset = 0 - for dim in new_order: - if dim == "x": - continue - elif dim != offset: - return False - else: - offset += 1 - return [ - reshape( - input_.owner.inputs[0], node.inputs[1], ndim=node.outputs[0].ndim - ) - ] - return False - - -@register_uncanonicalize -@local_optimizer([DimShuffle]) -def local_dimshuffle_alloc(fgraph, node): - """ - If an alloc is inside a dimshuffle which only adds dimension to the left, - scrap the dimshuffle and adds 1 into the alloc - - dimshuffle{x, 0, 1}(alloc([3 4], 3, 2) => alloc([3 4], 1, 3, 2) - """ - if isinstance(node.op, DimShuffle) and node.inputs[0].owner: - input_ = node.inputs[0] - if isinstance(input_.owner.op, Alloc): - # check if it only adds dimension to the left - new_order = node.op.new_order - expected_new_order = ("x",) * (len(new_order) - input_.ndim) + tuple( - range(input_.ndim) - ) - if new_order != expected_new_order: - return False - - # count numbers of 'x' - nb_new_dims = len(new_order) - input_.ndim - new_shape_input = (1,) * nb_new_dims + tuple(input_.owner.inputs[1:]) - - return [alloc(input_.owner.inputs[0], *new_shape_input)] - return False - - -@register_uncanonicalize -@local_optimizer([DimShuffle]) -def local_dimshuffle_subtensor(fgraph, node): - """If a subtensor is inside a dimshuffle which only drop - broadcastable dimensions, scrap the dimshuffle and index the - subtensor with 0 - - x[i:j, :, k:l].dimshuffle(0, 2) => - x[i:j, 0, k:l] if x.broadcastable == (False, True, False) - - """ - if isinstance(node.op, DimShuffle) and node.inputs[0].owner: - # the dimshuffle can only drop dimensions (cannot reshape nor add 'x') - if "x" in node.op.new_order: - return False - new_order = node.op.new_order - # new order could be empty - # Verif that we don't change dimensions order. - if len(new_order) > 1: - past_dim = new_order[0] - for dim in new_order[1:]: - if not dim > past_dim: - return False - else: - past_dim = dim - - input_ = node.inputs[0] - if isinstance(input_.owner.op, Subtensor): - # the arguments missing from the dimshuffles must be dims - # that are broadcastable - broadcastable = input_.broadcastable - - missing_dims = list(range(input_.ndim)) - for dim in new_order: - missing_dims.remove(dim) - - if not all(broadcastable[i] for i in missing_dims): - return False - - # create a new idx_list for a new Subtensor object - # have to loop on idx_list and inputs - # inputs has the length of sum of non None elements of idx_list - # (check in slice!). - # len(missing_dims) can be < len(idx_list), this happens if - # tensor was indexed such as x[scalar, :, :], check that as well - new_idx_list = list(input_.owner.op.idx_list) - new_inputs = [input_.owner.inputs[0]] - zero = constant(0) - slice_attr_list = ["start", "stop", "step"] - j = 0 - slice_i = -1 - subtensor_removed_dims = 0 - for i, idx in enumerate(input_.owner.op.idx_list): - if isinstance(idx, slice): - past_j = j - slice_i += 1 - for slice_attr in slice_attr_list: - if getattr(idx, slice_attr) is not None: - new_inputs += [input_.owner.inputs[1 + j]] - j += 1 - # if past_j == j indicates a slice(None, None, None), - # that's where we want to index with 0 if it is also at - # the same spot of a missing dim - if past_j == j and slice_i in missing_dims: - new_idx_list[i] = zero - new_inputs += [zero] - else: - new_inputs += [input_.owner.inputs[1 + j]] - j += 1 - subtensor_removed_dims += 1 - # Verify the trailing dimensions the subtensor didn't look at. - for idx in range(len(input_.owner.op.idx_list), new_inputs[0].ndim): - if (idx - subtensor_removed_dims) in missing_dims: - while len(new_idx_list) < idx: - new_idx_list.append(slice(None)) - - new_idx_list.append(zero) - new_inputs.append(zero) - return [Subtensor(new_idx_list)(*new_inputs)] - return False +from aesara.tensor.rewriting.uncanonicalize import * # noqa: F401 E402 F403 diff --git a/aesara/tensor/random/__init__.py b/aesara/tensor/random/__init__.py index f46f144e26..234e5520dd 100644 --- a/aesara/tensor/random/__init__.py +++ b/aesara/tensor/random/__init__.py @@ -1,5 +1,5 @@ -# Initialize `RandomVariable` optimizations -import aesara.tensor.random.opt +# Initialize `RandomVariable` rewrites +import aesara.tensor.random.rewriting import aesara.tensor.random.utils from aesara.tensor.random.basic import * from aesara.tensor.random.op import RandomState, default_rng diff --git a/aesara/tensor/random/basic.py b/aesara/tensor/random/basic.py index f4963276d8..6b79151451 100644 --- a/aesara/tensor/random/basic.py +++ b/aesara/tensor/random/basic.py @@ -70,6 +70,19 @@ def rng_fn(cls, *args, **kwargs): class UniformRV(RandomVariable): + r"""A uniform continuous random variable. + + The probability density function for `uniform` within the interval :math:`[l, h)` is: + + .. math:: + \begin{split} + f(x; l, h) = \begin{cases} + \frac{1}{h-l}\quad \text{for $l \leq x \leq h$},\\ + 0\quad \text{otherwise}. + \end{cases} + \end{split} + + """ name = "uniform" ndim_supp = 0 ndims_params = [0, 0] @@ -77,6 +90,25 @@ class UniformRV(RandomVariable): _print_name = ("U", "\\operatorname{U}") def __call__(self, low=0.0, high=1.0, size=None, **kwargs): + r"""Draw samples from a uniform distribution. + + The results are undefined when `high < low`. + + Parameters + ---------- + low + Lower boundary :math:`l` of the output interval; all values generated + will be greater than or equal to `low`. + high + Upper boundary :math:`h` of the output interval; all values generated + will be less than or equal to `high`. + size + Sample shape. If the given size is, e.g. `(m, n, k)` then `m * n * k` + independent, identically distributed random variables are + returned. Default is `None` in which case a single random variable + is returned. + + """ return super().__call__(low, high, size=size, **kwargs) @@ -84,28 +116,113 @@ def __call__(self, low=0.0, high=1.0, size=None, **kwargs): class TriangularRV(RandomVariable): + r"""A triangular continuous random variable. + + The probability density function for `triangular` within the interval :math:`[l, r)` + and mode :math:`m` (where the peak of the distribution occurs) is: + + .. math:: + + \begin{split} + f(x; l, m, r) = \begin{cases} + \frac{2(x-l)}{(r-l)(m-l)}\quad \text{for $l \leq x \leq m$},\\ + \frac{2(r-x)}{(r-l)(r-m)}\quad \text{for $m \leq x \leq r$},\\ + 0\quad \text{otherwise}. + \end{cases} + \end{split} + + """ name = "triangular" ndim_supp = 0 ndims_params = [0, 0, 0] dtype = "floatX" _print_name = ("Triang", "\\operatorname{Triang}") + def __call__(self, left, mode, right, size=None, **kwargs): + r"""Draw samples from a triangular distribution. + + Parameters + ---------- + left + Lower boundary :math:`l` of the output interval; all values generated + will be greater than or equal to `left`. + mode + Mode :math:`m` of the distribution, where the peak occurs. Must be such + that `left <= mode <= right`. + right + Upper boundary :math:`r` of the output interval; all values generated + will be less than or equal to `right`. Must be larger than `left`. + size + Sample shape. If the given size is, e.g. `(m, n, k)` then `m * n * k` + independent, identically distributed random variables are + returned. Default is `None` in which case a single random variable + is returned. + + """ + return super().__call__(left, mode, right, size=size, **kwargs) + triangular = TriangularRV() class BetaRV(RandomVariable): + r"""A beta continuous random variable. + + The probability density function for `beta` in terms of its parameters :math:`\alpha` + and :math:`\beta` is: + + .. math:: + + f(x; \alpha, \beta) = \frac{1}{B(\alpha, \beta)} x^{\alpha-1} (1-x)^{\beta-1} + + for :math:`0 \leq x \leq 1`. :math:`B` is the beta function defined as: + + .. math:: + + B(\alpha, \beta) = \int_0^1 t^{\alpha-1} (1-t)^{\beta-1} \mathrm{d}t + + """ name = "beta" ndim_supp = 0 ndims_params = [0, 0] dtype = "floatX" _print_name = ("Beta", "\\operatorname{Beta}") + def __call__(self, alpha, beta, size=None, **kwargs): + r"""Draw samples from a beta distribution. + + Parameters + ---------- + alpha + Alpha parameter :math:`\alpha` of the distribution. Must be positive. + beta + Beta parameter :math:`\beta` of the distribution. Must be positive. + size + Sample shape. If the given size is, e.g. `(m, n, k)` then `m * n * k` + independent, identically distributed random variables are + returned. Default is `None` in which case a single random variable + is returned. + + """ + return super().__call__(alpha, beta, size=size, **kwargs) + beta = BetaRV() class NormalRV(RandomVariable): + r"""A normal continuous random variable. + + The probability density function for `normal` in terms of its location parameter (mean) + :math:`\mu` and scale parameter (standard deviation) :math:`\sigma` is: + + .. math:: + + f(x; \mu, \sigma) = \frac{1}{\sqrt{2 \pi \sigma^2}} e^{-\frac{(x-\mu)^2}{2\sigma^2}} + + for :math:`\sigma > 0`. + + """ name = "normal" ndim_supp = 0 ndims_params = [0, 0] @@ -113,6 +230,21 @@ class NormalRV(RandomVariable): _print_name = ("N", "\\operatorname{N}") def __call__(self, loc=0.0, scale=1.0, size=None, **kwargs): + r"""Draw samples from a normal distribution. + + Parameters + ---------- + loc + Mean :math:`\mu` of the normal distribution. + scale + Standard deviation :math:`\sigma` of the normal distribution. Must be positive. + size + Sample shape. If the given size is, e.g. `(m, n, k)` then `m * n * k` + independent, identically distributed random variables are + returned. Default is `None` in which case a single random variable + is returned. + + """ return super().__call__(loc, scale, size=size, **kwargs) @@ -120,7 +252,28 @@ def __call__(self, loc=0.0, scale=1.0, size=None, **kwargs): class StandardNormalRV(NormalRV): + r"""A standard normal continuous random variable. + + The probability density function for `standard_normal` is: + + .. math:: + + f(x) = \frac{1}{\sqrt{2 \pi}} e^{-\frac{x^2}{2}} + + """ + def __call__(self, size=None, **kwargs): + """Draw samples from a standard normal distribution. + + Parameters + ---------- + size + Sample shape. If the given size is, e.g. `(m, n, k)` then `m * n * k` + independent, identically distributed random variables are + returned. Default is `None` in which case a single random variable + is returned. + + """ return super().__call__(loc=0.0, scale=1.0, size=size, **kwargs) @@ -128,6 +281,18 @@ def __call__(self, size=None, **kwargs): class HalfNormalRV(ScipyRandomVariable): + r"""A half-normal continuous random variable. + + The probability density function for `halfnormal` in terms of its location parameter + :math:`\mu` and scale parameter :math:`\sigma` is: + + .. math:: + + f(x; \mu, \sigma) = \frac{1}{\sqrt{2 \pi \sigma^2}} e^{-\frac{(x-\mu)^2}{2\sigma^2}} + + for :math:`x \geq 0` and :math:`\sigma > 0`. + + """ name = "halfnormal" ndim_supp = 0 ndims_params = [0, 0] @@ -135,10 +300,40 @@ class HalfNormalRV(ScipyRandomVariable): _print_name = ("N**+", "\\operatorname{N^{+}}") def __call__(self, loc=0.0, scale=1.0, size=None, **kwargs): + r"""Draw samples from a half-normal distribution. + + Parameters + ---------- + loc + Location parameter :math:`\mu` of the distribution. + scale + Scale parameter :math:`\sigma` of the distribution. + size + Sample shape. If the given size is, e.g. `(m, n, k)` then `m * n * k` + independent, identically distributed random variables are + returned. Default is `None` in which case a single random variable + is returned. + + """ return super().__call__(loc, scale, size=size, **kwargs) @classmethod def rng_fn_scipy(cls, rng, loc, scale, size): + r"""Draw sample from a half-normal distribution using Scipy's generator. + + Parameters + ---------- + loc + Location parameter :math:`\mu` of the distribution. + scale + Scale parameter :math:`\sigma` of the distribution. + size + Sample shape. If the given size is, e.g. `(m, n, k)` then `m * n * k` + independent, identically distributed random variables are + returned. Default is `None` in which case a single random variable + is returned. + + """ return stats.halfnorm.rvs(loc, scale, random_state=rng, size=size) @@ -146,6 +341,18 @@ def rng_fn_scipy(cls, rng, loc, scale, size): class LogNormalRV(RandomVariable): + r"""A lognormal continuous random variable. + + The probability density function for `lognormal` in terms of the mean + parameter :math:`\mu` and sigma parameter :math:`\sigma` is: + + .. math:: + + f(x; \mu, \sigma) = \frac{1}{x \sqrt{2 \pi \sigma^2}} e^{-\frac{(\ln(x)-\mu)^2}{2\sigma^2}} + + for :math:`x > 0` and :math:`\sigma > 0`. + + """ name = "lognormal" ndim_supp = 0 ndims_params = [0, 0] @@ -153,6 +360,21 @@ class LogNormalRV(RandomVariable): _print_name = ("LogN", "\\operatorname{LogN}") def __call__(self, mean=0.0, sigma=1.0, size=None, **kwargs): + r"""Draw sample from a lognormal distribution. + + Parameters + ---------- + mean + Mean :math:`\mu` of the random variable's natural logarithm. + sigma + Standard deviation :math:`\sigma` of the random variable's natural logarithm. + size + Sample shape. If the given size is, e.g. `(m, n, k)` then `m * n * k` + independent, identically distributed random variables are + returned. Default is `None` in which case a single random variable + is returned. + + """ return super().__call__(mean, sigma, size=size, **kwargs) @@ -162,15 +384,19 @@ def __call__(self, mean=0.0, sigma=1.0, size=None, **kwargs): class GammaRV(ScipyRandomVariable): r"""A gamma continuous random variable. - The probability density function for `gamma` in terms of `shape = alpha` and `rate = beta` is: + The probability density function for `gamma` in terms of the shape parameter + :math:`\alpha` and rate parameter :math:`\beta` is: .. math:: - f(x, \alpha, \beta) = \frac{\beta^\alpha}{\Gamma(\alpha)}x^{\alpha-1}e^{-\beta x} + f(x; \alpha, \beta) = \frac{\beta^\alpha}{\Gamma(\alpha)}x^{\alpha-1}e^{-\beta x} + + for :math:`x \geq 0`, :math:`\alpha > 0` and :math:`\beta > 0`. :math:`\Gamma` is + the gamma function: - for :math:`x \geq 0`, :math:`\alpha > 0` and :math:`\beta > 0`. `gamma` - takes ``shape`` as a shape parameter for :math:`\alpha` and ``rate`` as a - rate parameter for :math:`\beta`. + .. math:: + + \Gamma(x) = \int_0^{\infty} t^{x-1} e^{-t} \mathrm{d}t """ name = "gamma" @@ -180,14 +406,14 @@ class GammaRV(ScipyRandomVariable): _print_name = ("Gamma", "\\operatorname{Gamma}") def __call__(self, shape, rate, size=None, **kwargs): - """Return gamma-distributed random variables. + r"""Draw samples from a gamma distribution. Parameters ---------- shape - The shape of the gamma distribution. Must be positive. + The shape :math:`\alpha` of the gamma distribution. Must be positive. rate - The rate of the gamma distribution. Must be positive. + The rate :math:`\beta` of the gamma distribution. Must be positive. size Sample shape. If the given size is, e.g. `(m, n, k)` then `m * n * k` independent, identically distributed random variables are @@ -195,7 +421,6 @@ def __call__(self, shape, rate, size=None, **kwargs): is returned. """ - return super().__call__(shape, 1.0 / rate, size=size, **kwargs) @classmethod @@ -207,17 +432,65 @@ def rng_fn_scipy(cls, rng, shape, scale, size): class ChiSquareRV(RandomVariable): + r"""A chi square continuous random variable. + + The probability density function for `chisquare` in terms of the number of degrees of + freedom :math:`k` is: + + .. math:: + + f(x; k) = \frac{(1/2)^{k/2}}{\Gamma(k/2)} x^{k/2-1} e^{-x/2} + + for :math:`k > 2`. :math:`\Gamma` is the gamma function: + + .. math:: + + \Gamma(x) = \int_0^{\infty} t^{x-1} e^{-t} \mathrm{d}t + + + This variable is obtained by summing the squares :math:`k` independent, standard normally + distributed random variables. + + """ name = "chisquare" ndim_supp = 0 ndims_params = [0] dtype = "floatX" _print_name = ("ChiSquare", "\\operatorname{ChiSquare}") + def __call__(self, df, size=None, **kwargs): + r"""Draw samples from a chisquare distribution. + + Parameters + ---------- + df + The number :math:`k` of degrees of freedom. Must be positive. + size + Sample shape. If the given size is, e.g. `(m, n, k)` then `m * n * k` + independent, identically distributed random variables are + returned. Default is `None` in which case a single random variable + is returned. + + """ + return super().__call__(df, size=size, **kwargs) + chisquare = ChiSquareRV() class ParetoRV(ScipyRandomVariable): + r"""A pareto continuous random variable. + + The probability density function for `pareto` in terms of its shape parameter :math:`b` and + scale parameter :math:`x_m` is: + + .. math:: + + f(x; b, x_m) = \frac{b x_m^b}{x^{b+1}} + + and is defined for :math:`x \geq x_m`. + + """ name = "pareto" ndim_supp = 0 ndims_params = [0, 0] @@ -225,6 +498,21 @@ class ParetoRV(ScipyRandomVariable): _print_name = ("Pareto", "\\operatorname{Pareto}") def __call__(self, b, scale=1.0, size=None, **kwargs): + r"""Draw samples from a pareto distribution. + + Parameters + ---------- + b + The shape :math:`b` (or exponent) of the pareto distribution. Must be positive. + scale + The scale :math:`x_m` of the pareto distribution. Must be positive. + size + Sample shape. If the given size is, e.g. `(m, n, k)` then `m * n * k` + independent, identically distributed random variables are + returned. Default is `None` in which case a single random variable + is returned. + + """ return super().__call__(b, scale, size=size, **kwargs) @classmethod @@ -236,6 +524,18 @@ def rng_fn_scipy(cls, rng, b, scale, size): class GumbelRV(ScipyRandomVariable): + r"""A gumbel continuous random variable. + + The probability density function for `gumbel` in terms of its location parameter :math:`\mu` and + scale parameter :math:`\beta` is: + + .. math:: + + f(x; \mu, \beta) = \frac{\exp(-(x + e^{(x-\mu)/\beta})}{\beta} + + for :math:`\beta > 0`. + + """ name = "gumbel" ndim_supp = 0 ndims_params = [0, 0] @@ -249,6 +549,21 @@ def __call__( size: Optional[Union[List[int], int]] = None, **kwargs, ) -> RandomVariable: + r"""Draw samples from a gumbel distribution. + + Parameters + ---------- + loc + The location parameter :math:`\mu` of the distribution. + scale + The scale :math:`\beta` of the distribution. Must be positive. + size + Sample shape. If the given size is, e.g. `(m, n, k)` then `m * n * k` + independent, identically distributed random variables are + returned. Default is `None` in which case a single random variable + is returned. + + """ return super().__call__(loc, scale, size=size, **kwargs) @classmethod @@ -266,6 +581,17 @@ def rng_fn_scipy( class ExponentialRV(RandomVariable): + r"""An exponential continuous random variable. + + The probability density function for `exponential` in terms of its scale parameter :math:`\beta` is: + + .. math:: + + f(x; \beta) = \frac{\exp(-x / \beta)}{\beta} + + for :math:`x \geq 0` and :math:`\beta > 0`. + + """ name = "exponential" ndim_supp = 0 ndims_params = [0] @@ -273,6 +599,19 @@ class ExponentialRV(RandomVariable): _print_name = ("Exp", "\\operatorname{Exp}") def __call__(self, scale=1.0, size=None, **kwargs): + r"""Draw samples from an exponential distribution. + + Parameters + ---------- + scale + The scale :math:`\beta` of the distribution. Must be positive. + size + Sample shape. If the given size is, e.g. `(m, n, k)` then `m * n * k` + independent, identically distributed random variables are + returned. Default is `None` in which case a single random variable + is returned. + + """ return super().__call__(scale, size=size, **kwargs) @@ -280,17 +619,56 @@ def __call__(self, scale=1.0, size=None, **kwargs): class WeibullRV(RandomVariable): + r"""A weibull continuous random variable. + + The probability density function for `weibull` in terms of its shape parameter :math:`k` is : + + .. math:: + + f(x; k) = k x^{k-1} e^{-x^k} + + for :math:`x \geq 0` and :math:`k > 0`. + + """ name = "weibull" ndim_supp = 0 ndims_params = [0] dtype = "floatX" _print_name = ("Weibull", "\\operatorname{Weibull}") + def __call__(self, shape, size=None, **kwargs): + r"""Draw samples from a weibull distribution. + + Parameters + ---------- + shape + The shape :math:`k` of the distribution. Must be positive. + size + Sample shape. If the given size is, e.g. `(m, n, k)` then `m * n * k` + independent, identically distributed random variables are + returned. Default is `None` in which case a single random variable + is returned. + + """ + return super().__call__(shape, size=size, **kwargs) + weibull = WeibullRV() class LogisticRV(RandomVariable): + r"""A logistic continuous random variable. + + The probability density function for `logistic` in terms of its location parameter :math:`\mu` and + scale parameter :math:`s` is : + + .. math:: + + f(x; \mu, s) = \frac{e^{-(x-\mu)/s}}{s(1+e^{-(x-\mu)/s})^2} + + for :math:`s > 0`. + + """ name = "logistic" ndim_supp = 0 ndims_params = [0, 0] @@ -298,6 +676,21 @@ class LogisticRV(RandomVariable): _print_name = ("Logistic", "\\operatorname{Logistic}") def __call__(self, loc=0, scale=1, size=None, **kwargs): + r"""Draw samples from a logistic distribution. + + Parameters + ---------- + loc + The location parameter :math:`\mu` of the distribution. + scale + The scale :math:`s` of the distribution. Must be positive. + size + Sample shape. If the given size is, e.g. `(m, n, k)` then `m * n * k` + independent, identically distributed random variables are + returned. Default is `None` in which case a single random variable + is returned. + + """ return super().__call__(loc, scale, size=size, **kwargs) @@ -305,12 +698,43 @@ def __call__(self, loc=0, scale=1, size=None, **kwargs): class VonMisesRV(RandomVariable): + r"""A von Misses continuous random variable. + + The probability density function for `vonmisses` in terms of its mode :math:`\mu` and + dispersion parameter :math:`\kappa` is : + + .. math:: + + f(x; \mu, \kappa) = \frac{e^{\kappa \cos(x-\mu)}}{2 \pi I_0(\kappa)} + + for :math:`x \in [-\pi, \pi]` and :math:`\kappa > 0`. :math:`I_0` is the modified Bessel + function of order 0. + + """ name = "vonmises" ndim_supp = 0 ndims_params = [0, 0] dtype = "floatX" _print_name = ("VonMises", "\\operatorname{VonMises}") + def __call__(self, mu, kappa, size=None, **kwargs): + r"""Draw samples from a von Mises distribution. + + Parameters + ---------- + mu + The mode :math:`\mu` of the distribution. + kappa + The dispersion parameter :math:`\kappa` of the distribution. Must be positive. + size + Sample shape. If the given size is, e.g. `(m, n, k)` then `m * n * k` + independent, identically distributed random variables are + returned. Default is `None` in which case a single random variable + is returned. + + """ + return super().__call__(mu, kappa, size=size, **kwargs) + vonmises = VonMisesRV() @@ -337,6 +761,18 @@ def safe_multivariate_normal(mean, cov, size=None, rng=None): class MvNormalRV(RandomVariable): + r"""A multivariate normal random variable. + + The probability density function for `multivariate_normal` in term of its location parameter + :math:`\boldsymbol{\mu}` and covariance matrix :math:`\Sigma` is + + .. math:: + + f(\boldsymbol{x}; \boldsymbol{\mu}, \Sigma) = \det(2 \pi \Sigma)^{-1/2} \exp\left(-\frac{1}{2} (\boldsymbol{x} - \boldsymbol{\mu})^T \Sigma (\boldsymbol{x} - \boldsymbol{\mu})\right) + + where :math:`\Sigma` is a positive semi-definite matrix. + + """ name = "multivariate_normal" ndim_supp = 1 ndims_params = [1, 2] @@ -344,7 +780,23 @@ class MvNormalRV(RandomVariable): _print_name = ("N", "\\operatorname{N}") def __call__(self, mean=None, cov=None, size=None, **kwargs): + r""" "Draw samples from a multivariate normal distribution. + + Parameters + ---------- + mean + Location parameter (mean) :math:`\boldsymbol{\mu}` of the distribution. Vector + of length `N`. + cov + Covariance matrix :math:`\Sigma` of the distribution. Must be a symmetric + and positive-semidefinite `NxN` matrix. + size + Given a size of, for example, `(m, n, k)`, `m * n * k` independent, + identically distributed samples are generated. Because each sample + is `N`-dimensional, the output shape is `(m, n, k, N)`. If no shape + is specified, a single `N`-dimensional sample is returned. + """ dtype = aesara.config.floatX if self.dtype == "floatX" else self.dtype if mean is None: @@ -382,12 +834,42 @@ def rng_fn(cls, rng, mean, cov, size): class DirichletRV(RandomVariable): + r"""A Dirichlet continuous random variable. + + The probability density function for `dirichlet` in terms of the vector of + concentration parameters :math:`\boldsymbol{\alpha}` is: + + .. math:: + + f(x; \boldsymbol{\alpha}) = \prod_{i=1}^k x_i^{\alpha_i-1} + + where :math:`x` is a vector, such that :math:`x_i > 0\;\forall i` and + :math:`\sum_{i=1}^k x_i = 1`. + + """ name = "dirichlet" ndim_supp = 1 ndims_params = [1] dtype = "floatX" _print_name = ("Dir", "\\operatorname{Dir}") + def __call__(self, alphas, size=None, **kwargs): + r"""Draw samples from a dirichlet distribution. + + Parameters + ---------- + alphas + A sequence of concentration parameters :math:`\boldsymbol{\alpha}` of the + distribution. A sequence of length `k` will produce samples of length `k`. + size + Given a size of, for example, `(r, s, t)`, `r * s * t` independent, + identically distributed samples are generated. Because each sample + is `k`-dimensional, the output shape is `(r, s, t, k)`. If no shape + is specified, a single `k`-dimensional sample is returned. + + """ + return super().__call__(alphas, size=size, **kwargs) + @classmethod def rng_fn(cls, rng, alphas, size): if alphas.ndim > 1: @@ -413,6 +895,18 @@ def rng_fn(cls, rng, alphas, size): class PoissonRV(RandomVariable): + r"""A poisson discrete random variable. + + The probability mass function for `poisson` in terms of the expected number + of events :math:`\lambda` is: + + .. math:: + + f(k; \lambda) = \frac{\lambda^k e^{-\lambda}}{k!} + + for :math:`\lambda > 0`. + + """ name = "poisson" ndim_supp = 0 ndims_params = [0] @@ -420,35 +914,119 @@ class PoissonRV(RandomVariable): _print_name = ("Pois", "\\operatorname{Pois}") def __call__(self, lam=1.0, size=None, **kwargs): + r"""Draw samples from a poisson distribution. + + Parameters + ---------- + lam + Expected number of events :math:`\lambda`. Must be positive. + size + Sample shape. If the given size is, e.g. `(m, n, k)` then `m * n * k` + independent, identically distributed random variables are + returned. Default is `None` in which case a single random variable + is returned. + + """ return super().__call__(lam, size=size, **kwargs) -poisson = PoissonRV() +poisson = PoissonRV() + + +class GeometricRV(RandomVariable): + r"""A geometric discrete random variable. + + The probability mass function for `geometric` for the number of successes :math:`k` + before the first failure in terms of the probability of success :math:`p` of a single + trial is: + + .. math:: + + f(k; p) = p^{k-1}(1-p) + for :math:`0 \geq p \geq 1`. -class GeometricRV(RandomVariable): + """ name = "geometric" ndim_supp = 0 ndims_params = [0] dtype = "int64" _print_name = ("Geom", "\\operatorname{Geom}") + def __call__(self, p, size=None, **kwargs): + r"""Draw samples from a geometric distribution. + + Parameters + ---------- + p + Probability of success :math:`p` of an individual trial. + size + Sample shape. If the given size is, e.g. `(m, n, k)` then `m * n * + k` independent, identically distributed samples are returned. + Default is `None` in which case a single sample is returned. + + """ + return super().__call__(p, size=size, **kwargs) + geometric = GeometricRV() class HyperGeometricRV(RandomVariable): + r"""A hypergeometric discrete random variable. + + The probability mass function for `hypergeometric` for the number of + successes :math:`k` in :math:`n` draws without replacement, from a + finite population of size :math:`N` with :math:`K` desired items is: + + .. math:: + + f(k; n, N, K) = \frac{{K \choose k} {N-K \choose n-k}}{{N \choose n}} + + """ name = "hypergeometric" ndim_supp = 0 ndims_params = [0, 0, 0] dtype = "int64" _print_name = ("HyperGeom", "\\operatorname{HyperGeom}") + def __call__(self, ngood, nbad, nsample, size=None, **kwargs): + r"""Draw samples from a geometric distribution. + + Parameters + ---------- + ngood + Number :math:`K` of desirable items in the population. Positive integer. + nbad + Number :math:`N-K` of undesirable items in the population. Positive integer. + nsample + Number :math:`n` of items sampled. Must be less than :math:`N`, + i.e. `ngood + nbad`.` Positive integer. + size + Sample shape. If the given size is, e.g. `(m, n, k)` then `m * n * k` + independent, identically distributed samples are returned. Default is + `None` in which case a single sample is returned. + + """ + return super().__call__(ngood, nbad, nsample, size=size, **kwargs) + hypergeometric = HyperGeometricRV() class CauchyRV(ScipyRandomVariable): + r"""A Cauchy continuous random variable. + + The probability density function for `cauchy` in terms of its location + parameter :math:`x_0` and scale parameter :math:`\gamma` is: + + .. math:: + + f(x; x_0, \gamma) = \frac{1}{\pi \gamma \left(1 + (\frac{x-x_0}{\gamma})^2\right)} + + where :math:`\gamma > 0`. + + """ name = "cauchy" ndim_supp = 0 ndims_params = [0, 0] @@ -456,6 +1034,21 @@ class CauchyRV(ScipyRandomVariable): _print_name = ("C", "\\operatorname{C}") def __call__(self, loc=0.0, scale=1.0, size=None, **kwargs): + r"""Draw samples from a Cauchy distribution. + + Parameters + ---------- + loc + Location parameter :math:`x_0` of the distribution. + scale + Scale parameter :math:`\gamma` of the distribution. Must be + positive. + size + Sample shape. If the given size is `(m, n, k)`, then `m * n * k` + independent, identically distributed samples are returned. Default is + `None` in which case a single sample is returned. + + """ return super().__call__(loc, scale, size=size, **kwargs) @classmethod @@ -467,6 +1060,18 @@ def rng_fn_scipy(cls, rng, loc, scale, size): class HalfCauchyRV(ScipyRandomVariable): + r"""A half-Cauchy continuous random variable. + + The probability density function for `halfcauchy` in terms of its location + parameter :math:`x_0` and scale parameter :math:`\gamma` is: + + .. math:: + + f(x; x_0, \gamma) = \frac{1}{\pi \gamma \left(1 + (\frac{x-x_0}{\gamma})^2\right)} + + for :math:`x \geq 0` where :math:`\gamma > 0`. + + """ name = "halfcauchy" ndim_supp = 0 ndims_params = [0, 0] @@ -474,6 +1079,21 @@ class HalfCauchyRV(ScipyRandomVariable): _print_name = ("C**+", "\\operatorname{C^{+}}") def __call__(self, loc=0.0, scale=1.0, size=None, **kwargs): + r"""Draw samples from a half-Cauchy distribution. + + Parameters + ---------- + loc + Location parameter :math:`x_0` of the distribution. + scale + Scale parameter :math:`\gamma` of the distribution. Must be + positive. + size + Sample shape. If the given size is `(m, n, k)`, then `m * n * k` + independent, identically distributed samples are returned. Default is + `None`, in which case a single sample is returned. + + """ return super().__call__(loc, scale, size=size, **kwargs) @classmethod @@ -485,21 +1105,67 @@ def rng_fn_scipy(cls, rng, loc, scale, size): class InvGammaRV(ScipyRandomVariable): + r"""An inverse-gamma continuous random variable. + + The probability density function for `invgamma` in terms of its shape + parameter :math:`\alpha` and scale parameter :math:`\beta` is: + + .. math:: + + f(x; \alpha, \beta) = \frac{\beta^\alpha}{\Gamma(\alpha)} x^{-(\alpha+1)} \exp\left(-\frac{\beta}{x}\right) + + for :math:`x > 0`, where :math:`\alpha > 0` and :math:`\beta > 0`. :math:`Gamma` is the gamma function : + + .. math:: + + \Gamma(x) = \int_0^{\infty} t^{x-1} e^{-t} \mathrm{d}t + + """ name = "invgamma" ndim_supp = 0 ndims_params = [0, 0] dtype = "floatX" _print_name = ("InvGamma", "\\operatorname{Gamma^{-1}}") + def __call__(self, shape, scale, size=None, **kwargs): + r"""Draw samples from an inverse-gamma distribution. + + Parameters + ---------- + shape + Shape parameter :math:`\alpha` of the distribution. Must be positive. + scale + Scale parameter :math:`\beta` of the distribution. Must be + positive. + size + Sample shape. If the given size is `(m, n, k)`, then `m * n * k` + independent, identically distributed sample are returned. Default is + `None`, in which case a single sample is returned. + + """ + return super().__call__(shape, scale, size=size, **kwargs) + @classmethod - def rng_fn_scipy(cls, rng, shape, rate, size): - return stats.invgamma.rvs(shape, scale=rate, size=size, random_state=rng) + def rng_fn_scipy(cls, rng, shape, scale, size): + return stats.invgamma.rvs(shape, scale=scale, size=size, random_state=rng) invgamma = InvGammaRV() class WaldRV(RandomVariable): + r"""A Wald (or inverse Gaussian) continuous random variable. + + The probability density function for `wald` in terms of its mean + parameter :math:`\mu` and shape parameter :math:`\lambda` is: + + .. math:: + + f(x; \mu, \lambda) = \sqrt{\frac{\lambda}{2 \pi x^3}} \exp\left(-\frac{\lambda (x-\mu)^2}{2 \mu^2 x}\right) + + for :math:`x > 0`, where :math:`\mu > 0` and :math:`\lambda > 0`. + + """ name = "wald" ndim_supp = 0 ndims_params = [0, 0] @@ -507,6 +1173,21 @@ class WaldRV(RandomVariable): _print_name_ = ("Wald", "\\operatorname{Wald}") def __call__(self, mean=1.0, scale=1.0, size=None, **kwargs): + r"""Draw samples from a Wald distribution. + + Parameters + ---------- + mean + Mean parameter :math:`\mu` of the distribution. Must be positive. + shape + Shape parameter :math:`\lambda` of the distribution. Must be + positive. + size + Sample shape. If the given size is `(m, n, k)`, then `m * n * k` + independent, identically distributed samples are returned. Default is + `None`, in which case a single sample is returned. + + """ return super().__call__(mean, scale, size=size, **kwargs) @@ -514,12 +1195,45 @@ def __call__(self, mean=1.0, scale=1.0, size=None, **kwargs): class TruncExponentialRV(ScipyRandomVariable): + r"""A truncated exponential continuous random variable. + + The probability density function for `truncexp` in terms of its shape + parameter :math:`b`, location parameter :math:`\alpha` and scale + parameter :math:`\beta` is: + + .. math:: + + f(x; b, \alpha, \beta) = \frac{\exp(-(x-\alpha)/\beta)}{\beta (1-\exp(-b))} + + for :math:`0 \leq x \leq b` and :math:`\beta > 0`. + + """ name = "truncexpon" ndim_supp = 0 ndims_params = [0, 0, 0] dtype = "floatX" _print_name = ("TruncExp", "\\operatorname{TruncExp}") + def __call__(self, b, loc=0.0, scale=1.0, size=None, **kwargs): + r"""Draw samples from a truncated exponential distribution. + + Parameters + ---------- + b + Shape parameter :math:`b` of the distribution. Must be positive. + loc + Location parameter :math:`\alpha` of the distribution. + scale + Scale parameter :math:`\beta` of the distribution. Must be + positive. + size + Sample shape. If the given size is `(m, n, k)`, then `m * n * k` + independent, identically distributed samples are returned. Default is + `None` in which case a single sample is returned. + + """ + return super().__call__(b, loc, scale, size=size, **kwargs) + @classmethod def rng_fn_scipy(cls, rng, b, loc, scale, size): return stats.truncexpon.rvs( @@ -531,12 +1245,45 @@ def rng_fn_scipy(cls, rng, b, loc, scale, size): class BernoulliRV(ScipyRandomVariable): + r"""A Bernoulli discrete random variable. + + The probability mass function for `bernoulli` in terms of the probability + of success :math:`p` of a single trial is: + + + .. math:: + + \begin{split} + f(k; p) = \begin{cases} + (1-p)\quad \text{if $k = 0$},\\ + p\quad \text{if $k=1$}\\ + \end{cases} + \end{split} + + where :math:`0 \leq p \leq 1`. + + """ name = "bernoulli" ndim_supp = 0 ndims_params = [0] dtype = "int64" _print_name = ("Bern", "\\operatorname{Bern}") + def __call__(self, p, size=None, **kwargs): + r"""Draw samples from a Bernoulli distribution. + + Parameters + ---------- + p + Probability of success :math:`p` of a single trial. + size + Sample shape. If the given size is `(m, n, k)`, then `m * n * k` + independent, identically distributed samples are returned. Default + is `None` in which case a single sample is returned. + + """ + return super().__call__(p, size=size, **kwargs) + @classmethod def rng_fn_scipy(cls, rng, p, size): return stats.bernoulli.rvs(p, size=size, random_state=rng) @@ -546,34 +1293,120 @@ def rng_fn_scipy(cls, rng, p, size): class LaplaceRV(RandomVariable): + r"""A Laplace continuous random variable. + + The probability density function for `laplace` in terms of its location + parameter :math:`\mu` and scale parameter :math:`\lambda` is: + + .. math:: + + f(x; \mu, \lambda) = \frac{1}{2 \lambda} \exp\left(-\frac{|x-\mu|}{\lambda}\right) + + with :math:`\lambda > 0`. + + """ name = "laplace" ndim_supp = 0 ndims_params = [0, 0] dtype = "floatX" _print_name = ("Laplace", "\\operatorname{Laplace}") + def __call__(self, loc=0.0, scale=1.0, size=None, **kwargs): + r"""Draw samples from a Laplace distribution. + + Parameters + ---------- + loc + Location parameter :math:`\mu` of the distribution. + scale + Scale parameter :math:`\lambda` of the distribution. Must be + positive. + size + Sample shape. If the given size is `(m, n, k)`, then `m * n * k` + independent, identically distributed samples are returned. Default + is `None` in which case a single sample is returned. + + """ + return super().__call__(loc, scale, size=size, **kwargs) + laplace = LaplaceRV() class BinomialRV(RandomVariable): + r"""A binomial discrete random variable. + + The probability mass function for `binomial` for the number :math:`k` of successes + in terms of the probability of success :math:`p` of a single trial and the number + :math:`n` of trials is: + + .. math:: + + f(k; p, n) = {n \choose k} p^k (1-p)^{n-k} + + """ name = "binomial" ndim_supp = 0 ndims_params = [0, 0] dtype = "int64" _print_name = ("Binom", "\\operatorname{Binom}") + def __call__(self, n, p, size=None, **kwargs): + r"""Draw samples from a binomial distribution. + + Parameters + ---------- + n + Number of trials :math:`n`. Must be a positive integer. + p + Probability of success :math:`p` of a single trial. + size + Sample shape. If the given size is `(m, n, k)`, then `m * n * k` + independent, identically distributed samples are returned. Default + is `None` in which case a single sample is returned. + + """ + return super().__call__(n, p, size=size, **kwargs) + binomial = BinomialRV() class NegBinomialRV(ScipyRandomVariable): + r"""A negative binomial discrete random variable. + + The probability mass function for `nbinom` for the number :math:`k` of draws + before observing the :math:`n`\th success in terms of the probability of + success :math:`p` of a single trial is: + + .. math:: + + f(k; p, n) = {k+n-1 \choose n-1} p^n (1-p)^{k} + + """ name = "nbinom" ndim_supp = 0 ndims_params = [0, 0] dtype = "int64" _print_name = ("NB", "\\operatorname{NB}") + def __call__(self, n, p, size=None, **kwargs): + r"""Draw samples from a negative binomial distribution. + + Parameters + ---------- + n + Number of successes :math:`n`. Must be a positive integer. + p + Probability of success :math:`p` of a single trial. + size + Sample shape. If the given size is `(m, n, k)`, then `m * n * k` + independent, identically distributed samples are returned. Default + is `None` in which case a single sample is returned. + + """ + return super().__call__(n, p, size=size, **kwargs) + @classmethod def rng_fn_scipy(cls, rng, n, p, size): return stats.nbinom.rvs(n, p, size=size, random_state=rng) @@ -584,12 +1417,48 @@ def rng_fn_scipy(cls, rng, n, p, size): class BetaBinomialRV(ScipyRandomVariable): + r"""A beta-binomial discrete random variable. + + The probability mass function for `betabinom` in terms of its shape + parameters :math:`n \geq 0`, :math:`a > 0`, :math:`b > 0` and the probability + :math:`p` is: + + .. math:: + + f(k; p, n, a, b) = {n \choose k} \frac{\operatorname{B}(k+a, n-k+b)}{\operatorname{B}(a,b)} + + where :math:`\operatorname{B}` is the beta function: + + .. math:: + + \operatorname{B}(a, b) = \int_0^1 t^{a-1} (1-t)^{b-1} \mathrm{d}t + + """ name = "beta_binomial" ndim_supp = 0 ndims_params = [0, 0, 0] dtype = "int64" _print_name = ("BetaBinom", "\\operatorname{BetaBinom}") + def __call__(self, n, a, b, size=None, **kwargs): + r"""Draw samples from a beta-binomial distribution. + + Parameters + ---------- + n + Shape parameter :math:`n`. Must be a positive integer. + a + Shape parameter :math:`a`. Must be positive. + b + Shape parameter :math:`b`. Must be positive. + size + Sample shape. If the given size is `(m, n, k)`, then `m * n * k` + independent, identically distributed samples are returned. Default + is `None` in which case a single sample is returned. + + """ + return super().__call__(n, a, b, size=size, **kwargs) + @classmethod def rng_fn_scipy(cls, rng, n, a, b, size): return stats.betabinom.rvs(n, a, b, size=size, random_state=rng) @@ -599,6 +1468,18 @@ def rng_fn_scipy(cls, rng, n, a, b, size): class GenGammaRV(ScipyRandomVariable): + r"""A generalized gamma continuous random variable. + + The probability density function of `gengamma` in terms of its scale parameter + :math:`\alpha` and other parameters :math:`p` and :math:`\lambda` is: + + .. math:: + + f(x; \alpha, \lambda, p) = \frac{p/\lambda^\alpha}{\Gamma(\alpha/p)} x^{\alpha-1} e^{-(x/\lambda)^p} + + for :math:`x > 0`, where :math:`\alpha, \lambda, p > 0`. + + """ name = "gengamma" ndim_supp = 0 ndims_params = [0, 0, 0] @@ -606,6 +1487,23 @@ class GenGammaRV(ScipyRandomVariable): _print_name = ("GG", "\\operatorname{GG}") def __call__(self, alpha=1.0, p=1.0, lambd=1.0, size=None, **kwargs): + r"""Draw samples from a generalized gamma distribution. + + Parameters + ---------- + alpha + Parameter :math:`\alpha`. Must be positive. + p + Parameter :math:`p`. Must be positive. + lambd + Scale parameter :math:`\lambda`. Must be positive. + size + Sample shape. If the given size is `(m, n, k)`, then `m * n * k` + independent, identically distributed samples are + returned. Default is `None` in which case a single sample + is returned. + + """ return super().__call__(alpha, p, lambd, size=size, **kwargs) @classmethod @@ -619,20 +1517,50 @@ def rng_fn_scipy(cls, rng, alpha, p, lambd, size): class MultinomialRV(RandomVariable): - """A Multinomial random variable type. + r"""A multinomial discrete random variable. + + The probability mass function of `multinomial` in terms of the number + of experiments :math:`n` and the probabilities :math:`p_1, \dots, p_k` + of the :math:`k` different possible outcomes is: + + + .. math:: + + f(x_1,\dots,x_k; n, p_1, \dots, p_k) = \frac{n!}{x_1! \dots x_k!} \prod_{i=1}^k x_i^{p_i} + + + where :math:`n>0` and :math:`\sum_{i=1}^k p_i = 1`. Notes ----- The length of the support dimension is determined by the last dimension in the *second* parameter (i.e. the probabilities vector). - """ + """ name = "multinomial" ndim_supp = 1 ndims_params = [0, 1] dtype = "int64" _print_name = ("MN", "\\operatorname{MN}") + def __call__(self, n, p, size=None, **kwargs): + r"""Draw samples from a discrete multinomial distribution. + + Parameters + ---------- + n + Number of experiments :math:`n`. Must be a positive integer. + p + Probabilities of each of the :math:`k` different outcomes. + size + Given a size of, for example, `(r, s, t)`, `r * s * t` independent, + identically distributed samples are generated. Because each sample + is `k`-dimensional, the output shape is `(r, s, t, k)`. If no shape + is specified, a single `k`-dimensional sample is returned. + + """ + return super().__call__(n, p, size=size, **kwargs) + def _supp_shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None): return default_supp_shape_from_params( self.ndim_supp, dist_params, rep_param_idx, param_shapes @@ -663,29 +1591,55 @@ def rng_fn(cls, rng, n, p, size): class CategoricalRV(RandomVariable): + r"""A categorical discrete random variable. + + The probability mass function of `categorical` in terms of its :math:`N` event + probabilities :math:`p_1, \dots, p_N` is: + + .. math:: + + P(k=i) = p_k + + where :math:`\sum_i p_i = 1`. + + """ + name = "categorical" ndim_supp = 0 ndims_params = [1] dtype = "int64" _print_name = ("Cat", "\\operatorname{Cat}") - @classmethod - def rng_fn(cls, rng, p, size): - if size is None: - size = () + def __call__(self, p, size=None, **kwargs): + r"""Draw samples from a discrete categorical distribution. - size = tuple(np.atleast_1d(size)) - ind_shape = p.shape[:-1] + Parameters + ---------- + p + An array that contains the :math:`N` event probabilities. + size + Sample shape. If the given size is `(m, n, k)`, then `m * n * k` + independent, identically distributed random samples are + returned. Default is `None`, in which case a single sample + is returned. - if len(ind_shape) > 0: - if len(size) > 0 and size[-len(ind_shape) :] != ind_shape: - raise ValueError("Parameters shape and size do not match.") + """ + return super().__call__(p, size=size, **kwargs) - samples_shape = size[: -len(ind_shape)] + ind_shape + @classmethod + def rng_fn(cls, rng, p, size): + if size is None: + size = p.shape[:-1] else: - samples_shape = size - - unif_samples = rng.uniform(size=samples_shape) + # Check that `size` does not define a shape that would be broadcasted + # to `p.shape[:-1]` in the call to `vsearchsorted` below. + if len(size) < (p.ndim - 1): + raise ValueError("`size` is incompatible with the shape of `p`") + for s, ps in zip(reversed(size), reversed(p.shape[:-1])): + if s == 1 and ps != 1: + raise ValueError("`size` is incompatible with the shape of `p`") + + unif_samples = rng.uniform(size=size) samples = vsearchsorted(p.cumsum(axis=-1), unif_samples) return samples @@ -695,6 +1649,12 @@ def rng_fn(cls, rng, p, size): class RandIntRV(RandomVariable): + r"""A discrete uniform random variable. + + Only available for `RandomStateType`. Use `integers` with `RandomGeneratorType`\s. + + """ + name = "randint" ndim_supp = 0 ndims_params = [0, 0] @@ -702,6 +1662,25 @@ class RandIntRV(RandomVariable): _print_name = ("randint", "\\operatorname{randint}") def __call__(self, low, high=None, size=None, **kwargs): + r"""Draw samples from a discrete uniform distribution. + + Parameters + ---------- + low + Lower boundary of the output interval. All values generated will + be greater than or equal to `low`, unless `high=None`, in which case + all values generated are greater than or equal to `0` and + smaller than `low` (exclusive). + high + Upper boundary of the output interval. All values generated + will be smaller than `high` (exclusive). + size + Sample shape. If the given size is `(m, n, k)`, then `m * n * k` + independent, identically distributed samples are + returned. Default is `None`, in which case a single + sample is returned. + + """ if high is None: low, high = 0, low return super().__call__(low, high, size=size, **kwargs) @@ -718,6 +1697,11 @@ def make_node(self, rng, *args, **kwargs): class IntegersRV(RandomVariable): + r"""A discrete uniform random variable. + + Only available for `RandomGeneratorType`. Use `randint` with `RandomStateType`\s. + + """ name = "integers" ndim_supp = 0 ndims_params = [0, 0] @@ -725,6 +1709,23 @@ class IntegersRV(RandomVariable): _print_name = ("integers", "\\operatorname{integers}") def __call__(self, low, high=None, size=None, **kwargs): + r"""Draw samples from a discrete uniform distribution. + + Parameters + ---------- + low + Lower boundary of the output interval. All values generated + will be greater than or equal to `low` (inclusive). + high + Upper boundary of the output interval. All values generated + will be smaller than `high` (exclusive). + size + Sample shape. If the given size is `(m, n, k)`, then `m * n * k` + independent, identically distributed samples are + returned. Default is `None`, in which case a single sample + is returned. + + """ if high is None: low, high = 0, low return super().__call__(low, high, size=size, **kwargs) @@ -742,6 +1743,8 @@ def make_node(self, rng, *args, **kwargs): class ChoiceRV(RandomVariable): + """Randomly choose an element in a sequence.""" + name = "choice" ndim_supp = 0 ndims_params = [1, 1, 0] @@ -756,14 +1759,40 @@ def _supp_shape_from_params(self, *args, **kwargs): raise NotImplementedError() def _infer_shape(self, size, dist_params, param_shapes=None): - return size + (a, p, _) = dist_params + + if isinstance(p.type, aesara.tensor.type_other.NoneTypeT): + shape = super()._infer_shape(size, (a,), param_shapes) + else: + shape = super()._infer_shape(size, (a, p), param_shapes) + + return shape def __call__(self, a, size=None, replace=True, p=None, **kwargs): + r"""Generate a random sample from an array. + + Parameters + ---------- + a + The array from which to randomly sample an element. If an int, + a sample is generated from `aesara.tensor.arange(a)`. + size + Sample shape. If the given size is `(m, n, k)`, then `m * n * + k` independent samples are returned. Default is `None`, in + which case a single sample is returned. + replace + When ``True``, sampling is performed with replacement. + p + The probabilities associated with each entry in `a`. If not + given, all elements have equal probability. + """ + a = as_tensor_variable(a) - a = as_tensor_variable(a, ndim=1) + if a.ndim == 0: + a = aesara.tensor.arange(a) if p is None: - p = aesara.tensor.type_other.NoneConst.clone() + p = aesara.tensor.type_other.NoneConst if isinstance(replace, bool): replace = aesara.tensor.constant(np.array(replace)) @@ -775,6 +1804,8 @@ def __call__(self, a, size=None, replace=True, p=None, **kwargs): class PermutationRV(RandomVariable): + """Randomly shuffle a sequence.""" + name = "permutation" ndim_supp = 1 ndims_params = [1] @@ -798,6 +1829,15 @@ def _infer_shape(self, size, dist_params, param_shapes=None): return x_shape def __call__(self, x, **kwargs): + r"""Randomly permute a sequence or a range of values. + + Parameters + ---------- + x + If `x` is an integer, randomly permute `np.arange(x)`. If `x` is a sequence, + shuffle its elements randomly. + + """ x = as_tensor_variable(x) return super().__call__(x, dtype=x.dtype, **kwargs) diff --git a/aesara/tensor/random/opt.py b/aesara/tensor/random/opt.py index 2321ab901c..69f4cab0da 100644 --- a/aesara/tensor/random/opt.py +++ b/aesara/tensor/random/opt.py @@ -1,411 +1,10 @@ -from aesara.compile import optdb -from aesara.configdefaults import config -from aesara.graph.op import compute_test_value -from aesara.graph.opt import in2out, local_optimizer -from aesara.tensor.basic import constant, get_vector_length -from aesara.tensor.elemwise import DimShuffle -from aesara.tensor.extra_ops import broadcast_to -from aesara.tensor.math import sum as at_sum -from aesara.tensor.random.op import RandomVariable -from aesara.tensor.random.utils import broadcast_params -from aesara.tensor.shape import Shape, Shape_i -from aesara.tensor.subtensor import ( - AdvancedSubtensor, - AdvancedSubtensor1, - Subtensor, - as_index_variable, - get_idx_list, - indexed_result_shape, -) - - -def is_rv_used_in_graph(base_rv, node, fgraph): - """Determine whether or not `base_rv` is used by a node other than `node` in `fgraph`. - - If a node uses `Shape` or `Shape_i` on the `base_rv`, we ignore it, because - those `Op`s don't rely on the actual sample values of `base_rv`. - - TODO: We should apply all the shape rewrites before these rewrites, since - that would properly remove the unnecessary dependencies on `base_rv` (when - possible). - - """ - - def _node_check(n, i): - if n == "output": - n = fgraph.outputs[i].owner - return n == node or isinstance(n.op, (Shape, Shape_i)) - - return not all(_node_check(n, i) for n, i in fgraph.clients.get(base_rv, ())) - - -@local_optimizer([RandomVariable], inplace=True) -def random_make_inplace(fgraph, node): - op = node.op - - if isinstance(op, RandomVariable) and not op.inplace: - props = op._props_dict() - props["inplace"] = True - new_op = type(op)(**props) - return new_op.make_node(*node.inputs).outputs +import warnings - return False - -optdb.register( - "random_make_inplace", - in2out(random_make_inplace, ignore_newtrees=True), - "fast_run", - "inplace", - position=99, +warnings.warn( + "The module `aesara.tensor.random.opt` is deprecated; use `aesara.tensor.random.rewriting` instead.", + DeprecationWarning, + stacklevel=2, ) - -@local_optimizer(tracks=None) -def local_rv_size_lift(fgraph, node): - """Lift the ``size`` parameter in a ``RandomVariable``. - - In other words, this will broadcast the distribution parameters by adding - the extra dimensions implied by the ``size`` parameter, and remove the - ``size`` parameter in the process. - - For example, ``normal(0, 1, size=(1, 2))`` becomes - ``normal([[0, 0]], [[1, 1]], size=())``. - - """ - - if not isinstance(node.op, RandomVariable): - return - - rng, size, dtype, *dist_params = node.inputs - - dist_params = broadcast_params(dist_params, node.op.ndims_params) - - if get_vector_length(size) > 0: - dist_params = [ - broadcast_to( - p, - ( - tuple(size) - + ( - tuple(p.shape)[-node.op.ndims_params[i] :] - if node.op.ndims_params[i] > 0 - else () - ) - ) - if node.op.ndim_supp > 0 - else size, - ) - for i, p in enumerate(dist_params) - ] - else: - return - - new_node = node.op.make_node(rng, None, dtype, *dist_params) - - if config.compute_test_value != "off": - compute_test_value(new_node) - - return new_node.outputs - - -@local_optimizer([DimShuffle]) -def local_dimshuffle_rv_lift(fgraph, node): - """Lift a ``DimShuffle`` through ``RandomVariable`` inputs. - - For example, ``normal(mu, std).T == normal(mu.T, std.T)``. - - The basic idea behind this rewrite is that we need to separate the - ``DimShuffle``-ing into distinct ``DimShuffle``s that each occur in two - distinct sub-spaces: the (set of independent) parameters and ``size`` - (i.e. replications) sub-spaces. - - If a ``DimShuffle`` exchanges dimensions across those two sub-spaces, then we - don't do anything. - - Otherwise, if the ``DimShuffle`` only exchanges dimensions within each of - those sub-spaces, we can break it apart and apply the parameter-space - ``DimShuffle`` to the distribution parameters, and then apply the - replications-space ``DimShuffle`` to the ``size`` tuple. The latter is a - particularly simple rearranging of a tuple, but the former requires a - little more work. - - TODO: Currently, multivariate support for this rewrite is disabled. - - """ - - ds_op = node.op - - if not isinstance(ds_op, DimShuffle): - return False - - base_rv = node.inputs[0] - rv_node = base_rv.owner - - if not ( - rv_node and isinstance(rv_node.op, RandomVariable) and rv_node.op.ndim_supp == 0 - ): - return False - - # If no one else is using the underlying `RandomVariable`, then we can - # do this; otherwise, the graph would be internally inconsistent. - if is_rv_used_in_graph(base_rv, node, fgraph): - return False - - rv_op = rv_node.op - rng, size, dtype, *dist_params = rv_node.inputs - - # We need to know the dimensions that were *not* added by the `size` - # parameter (i.e. the dimensions corresponding to independent variates with - # different parameter values) - num_ind_dims = None - if len(dist_params) == 1: - num_ind_dims = dist_params[0].ndim - else: - # When there is more than one distribution parameter, assume that all - # of them will broadcast to the maximum number of dimensions - num_ind_dims = max(d.ndim for d in dist_params) - - # If the indices in `ds_new_order` are entirely within the replication - # indices group or the independent variates indices group, then we can apply - # this rewrite. - - ds_new_order = ds_op.new_order - # Create a map from old index order to new/`DimShuffled` index order - dim_orders = [(n, d) for n, d in enumerate(ds_new_order) if isinstance(d, int)] - - # Find the index at which the replications/independents split occurs - reps_ind_split_idx = len(dim_orders) - (num_ind_dims + rv_op.ndim_supp) - - ds_reps_new_dims = dim_orders[:reps_ind_split_idx] - ds_ind_new_dims = dim_orders[reps_ind_split_idx:] - ds_in_ind_space = ds_ind_new_dims and all( - d >= reps_ind_split_idx for n, d in ds_ind_new_dims - ) - - if ds_in_ind_space or (not ds_ind_new_dims and not ds_reps_new_dims): - - # Update the `size` array to reflect the `DimShuffle`d dimensions, - # since the trailing dimensions in `size` represent the independent - # variates dimensions (for univariate distributions, at least) - has_size = get_vector_length(size) > 0 - new_size = ( - [constant(1, dtype="int64") if o == "x" else size[o] for o in ds_new_order] - if has_size - else size - ) - - # Compute the new axes parameter(s) for the `DimShuffle` that will be - # applied to the `RandomVariable` parameters (they need to be offset) - if ds_ind_new_dims: - rv_params_new_order = [ - d - reps_ind_split_idx if isinstance(d, int) else d - for d in ds_new_order[ds_ind_new_dims[0][0] :] - ] - - if not has_size and len(ds_new_order[: ds_ind_new_dims[0][0]]) > 0: - # Additional broadcast dimensions need to be added to the - # independent dimensions (i.e. parameters), since there's no - # `size` to which they can be added - rv_params_new_order = ( - list(ds_new_order[: ds_ind_new_dims[0][0]]) + rv_params_new_order - ) - else: - # This case is reached when, for example, `ds_new_order` only - # consists of new broadcastable dimensions (i.e. `"x"`s) - rv_params_new_order = ds_new_order - - # Lift the `DimShuffle`s into the parameters - # NOTE: The parameters might not be broadcasted against each other, so - # we can only apply the parts of the `DimShuffle` that are relevant. - new_dist_params = [] - for d in dist_params: - if d.ndim < len(ds_ind_new_dims): - _rv_params_new_order = [ - o - for o in rv_params_new_order - if (isinstance(o, int) and o < d.ndim) or o == "x" - ] - else: - _rv_params_new_order = rv_params_new_order - - new_dist_params.append( - type(ds_op)(d.type.broadcastable, _rv_params_new_order)(d) - ) - new_node = rv_op.make_node(rng, new_size, dtype, *new_dist_params) - - if config.compute_test_value != "off": - compute_test_value(new_node) - - out = new_node.outputs[1] - if base_rv.name: - out.name = f"{base_rv.name}_lifted" - return [out] - - ds_in_reps_space = ds_reps_new_dims and all( - d < reps_ind_split_idx for n, d in ds_reps_new_dims - ) - - if ds_in_reps_space: - # Update the `size` array to reflect the `DimShuffle`d dimensions. - # There should be no need to `DimShuffle` now. - new_size = [ - constant(1, dtype="int64") if o == "x" else size[o] for o in ds_new_order - ] - - new_node = rv_op.make_node(rng, new_size, dtype, *dist_params) - - if config.compute_test_value != "off": - compute_test_value(new_node) - - out = new_node.outputs[1] - if base_rv.name: - out.name = f"{base_rv.name}_lifted" - return [out] - - return False - - -@local_optimizer([Subtensor, AdvancedSubtensor1, AdvancedSubtensor]) -def local_subtensor_rv_lift(fgraph, node): - """Lift a ``*Subtensor`` through ``RandomVariable`` inputs. - - In a fashion similar to ``local_dimshuffle_rv_lift``, the indexed dimensions - need to be separated into distinct replication-space and (independent) - parameter-space ``*Subtensor``s. - - The replication-space ``*Subtensor`` can be used to determine a - sub/super-set of the replication-space and, thus, a "smaller"/"larger" - ``size`` tuple. The parameter-space ``*Subtensor`` is simply lifted and - applied to the distribution parameters. - - Consider the following example graph: - ``normal(mu, std, size=(d1, d2, d3))[idx1, idx2, idx3]``. The - ``*Subtensor`` ``Op`` requests indices ``idx1``, ``idx2``, and ``idx3``, - which correspond to all three ``size`` dimensions. Now, depending on the - broadcasted dimensions of ``mu`` and ``std``, this ``*Subtensor`` ``Op`` - could be reducing the ``size`` parameter and/or sub-setting the independent - ``mu`` and ``std`` parameters. Only once the dimensions are properly - separated into the two replication/parameter subspaces can we determine how - the ``*Subtensor`` indices are distributed. - For instance, ``normal(mu, std, size=(d1, d2, d3))[idx1, idx2, idx3]`` - could become - ``normal(mu[idx1], std[idx2], size=np.shape(idx1) + np.shape(idx2) + np.shape(idx3))`` - if ``mu.shape == std.shape == ()`` - - ``normal`` is a rather simple case, because it's univariate. Multivariate - cases require a mapping between the parameter space and the image of the - random variable. This may not always be possible, but for many common - distributions it is. For example, the dimensions of the multivariate - normal's image can be mapped directly to each dimension of its parameters. - We use these mappings to change a graph like ``multivariate_normal(mu, Sigma)[idx1]`` - into ``multivariate_normal(mu[idx1], Sigma[idx1, idx1])``. - - """ - - st_op = node.op - - if not isinstance(st_op, (AdvancedSubtensor, AdvancedSubtensor1, Subtensor)): - return False - - base_rv = node.inputs[0] - - rv_node = base_rv.owner - if not (rv_node and isinstance(rv_node.op, RandomVariable)): - return False - - # If no one else is using the underlying `RandomVariable`, then we can - # do this; otherwise, the graph would be internally inconsistent. - if is_rv_used_in_graph(base_rv, node, fgraph): - return False - - rv_op = rv_node.op - rng, size, dtype, *dist_params = rv_node.inputs - - # TODO: Remove this once the multi-dimensional changes described below are - # in place. - if rv_op.ndim_supp > 0: - return False - - rv_op = base_rv.owner.op - rng, size, dtype, *dist_params = base_rv.owner.inputs - - idx_list = getattr(st_op, "idx_list", None) - if idx_list: - cdata = get_idx_list(node.inputs, idx_list) - else: - cdata = node.inputs[1:] - - st_indices, st_is_bool = zip( - *tuple( - (as_index_variable(i), getattr(i, "dtype", None) == "bool") for i in cdata - ) - ) - - # We need to separate dimensions into replications and independents - num_ind_dims = None - if len(dist_params) == 1: - num_ind_dims = dist_params[0].ndim - else: - # When there is more than one distribution parameter, assume that all - # of them will broadcast to the maximum number of dimensions - num_ind_dims = max(d.ndim for d in dist_params) - - reps_ind_split_idx = base_rv.ndim - (num_ind_dims + rv_op.ndim_supp) - - if len(st_indices) > reps_ind_split_idx: - # These are the indices that need to be applied to the parameters - ind_indices = tuple(st_indices[reps_ind_split_idx:]) - - # We need to broadcast the parameters before applying the `*Subtensor*` - # with these indices, because the indices could be referencing broadcast - # dimensions that don't exist (yet) - bcast_dist_params = broadcast_params(dist_params, rv_op.ndims_params) - - # TODO: For multidimensional distributions, we need a map that tells us - # which dimensions of the parameters need to be indexed. - # - # For example, `multivariate_normal` would have the following: - # `RandomVariable.param_to_image_dims = ((0,), (0, 1))` - # - # I.e. the first parameter's (i.e. mean's) first dimension maps directly to - # the dimension of the RV's image, and its second parameter's - # (i.e. covariance's) first and second dimensions map directly to the - # dimension of the RV's image. - - args_lifted = tuple(p[ind_indices] for p in bcast_dist_params) - else: - # In this case, no indexing is applied to the parameters; only the - # `size` parameter is affected. - args_lifted = dist_params - - # TODO: Could use `ShapeFeature` info. We would need to be sure that - # `node` isn't in the results, though. - # if hasattr(fgraph, "shape_feature"): - # output_shape = fgraph.shape_feature.shape_of(node.outputs[0]) - # else: - output_shape = indexed_result_shape(base_rv.shape, st_indices) - - size_lifted = ( - output_shape if rv_op.ndim_supp == 0 else output_shape[: -rv_op.ndim_supp] - ) - - # Boolean indices can actually change the `size` value (compared to just - # *which* dimensions of `size` are used). - if any(st_is_bool): - size_lifted = tuple( - at_sum(idx) if is_bool else s - for s, is_bool, idx in zip( - size_lifted, st_is_bool, st_indices[: (reps_ind_split_idx + 1)] - ) - ) - - new_node = rv_op.make_node(rng, size_lifted, dtype, *args_lifted) - _, new_rv = new_node.outputs - - # Calling `Op.make_node` directly circumvents test value computations, so - # we need to compute the test values manually - if config.compute_test_value != "off": - compute_test_value(new_node) - - return [new_rv] +from aesara.tensor.random.rewriting import * # noqa: F401 E402 F403 diff --git a/aesara/tensor/random/rewriting.py b/aesara/tensor/random/rewriting.py new file mode 100644 index 0000000000..c247eb0c7f --- /dev/null +++ b/aesara/tensor/random/rewriting.py @@ -0,0 +1,411 @@ +from aesara.compile import optdb +from aesara.configdefaults import config +from aesara.graph.op import compute_test_value +from aesara.graph.rewriting.basic import in2out, node_rewriter +from aesara.tensor.basic import constant, get_vector_length +from aesara.tensor.elemwise import DimShuffle +from aesara.tensor.extra_ops import broadcast_to +from aesara.tensor.math import sum as at_sum +from aesara.tensor.random.op import RandomVariable +from aesara.tensor.random.utils import broadcast_params +from aesara.tensor.shape import Shape, Shape_i +from aesara.tensor.subtensor import ( + AdvancedSubtensor, + AdvancedSubtensor1, + Subtensor, + as_index_variable, + get_idx_list, + indexed_result_shape, +) + + +def is_rv_used_in_graph(base_rv, node, fgraph): + """Determine whether or not `base_rv` is used by a node other than `node` in `fgraph`. + + If a node uses `Shape` or `Shape_i` on the `base_rv`, we ignore it, because + those `Op`s don't rely on the actual sample values of `base_rv`. + + TODO: We should apply all the shape rewrites before these rewrites, since + that would properly remove the unnecessary dependencies on `base_rv` (when + possible). + + """ + + def _node_check(n, i): + if n == "output": + n = fgraph.outputs[i].owner + return n == node or isinstance(n.op, (Shape, Shape_i)) + + return not all(_node_check(n, i) for n, i in fgraph.clients.get(base_rv, ())) + + +@node_rewriter([RandomVariable], inplace=True) +def random_make_inplace(fgraph, node): + op = node.op + + if isinstance(op, RandomVariable) and not op.inplace: + props = op._props_dict() + props["inplace"] = True + new_op = type(op)(**props) + return new_op.make_node(*node.inputs).outputs + + return False + + +optdb.register( + "random_make_inplace", + in2out(random_make_inplace, ignore_newtrees=True), + "fast_run", + "inplace", + position=99, +) + + +@node_rewriter(tracks=None) +def local_rv_size_lift(fgraph, node): + """Lift the ``size`` parameter in a ``RandomVariable``. + + In other words, this will broadcast the distribution parameters by adding + the extra dimensions implied by the ``size`` parameter, and remove the + ``size`` parameter in the process. + + For example, ``normal(0, 1, size=(1, 2))`` becomes + ``normal([[0, 0]], [[1, 1]], size=())``. + + """ + + if not isinstance(node.op, RandomVariable): + return + + rng, size, dtype, *dist_params = node.inputs + + dist_params = broadcast_params(dist_params, node.op.ndims_params) + + if get_vector_length(size) > 0: + dist_params = [ + broadcast_to( + p, + ( + tuple(size) + + ( + tuple(p.shape)[-node.op.ndims_params[i] :] + if node.op.ndims_params[i] > 0 + else () + ) + ) + if node.op.ndim_supp > 0 + else size, + ) + for i, p in enumerate(dist_params) + ] + else: + return + + new_node = node.op.make_node(rng, None, dtype, *dist_params) + + if config.compute_test_value != "off": + compute_test_value(new_node) + + return new_node.outputs + + +@node_rewriter([DimShuffle]) +def local_dimshuffle_rv_lift(fgraph, node): + """Lift a ``DimShuffle`` through ``RandomVariable`` inputs. + + For example, ``normal(mu, std).T == normal(mu.T, std.T)``. + + The basic idea behind this rewrite is that we need to separate the + ``DimShuffle``-ing into distinct ``DimShuffle``s that each occur in two + distinct sub-spaces: the (set of independent) parameters and ``size`` + (i.e. replications) sub-spaces. + + If a ``DimShuffle`` exchanges dimensions across those two sub-spaces, then we + don't do anything. + + Otherwise, if the ``DimShuffle`` only exchanges dimensions within each of + those sub-spaces, we can break it apart and apply the parameter-space + ``DimShuffle`` to the distribution parameters, and then apply the + replications-space ``DimShuffle`` to the ``size`` tuple. The latter is a + particularly simple rearranging of a tuple, but the former requires a + little more work. + + TODO: Currently, multivariate support for this rewrite is disabled. + + """ + + ds_op = node.op + + if not isinstance(ds_op, DimShuffle): + return False + + base_rv = node.inputs[0] + rv_node = base_rv.owner + + if not ( + rv_node and isinstance(rv_node.op, RandomVariable) and rv_node.op.ndim_supp == 0 + ): + return False + + # If no one else is using the underlying `RandomVariable`, then we can + # do this; otherwise, the graph would be internally inconsistent. + if is_rv_used_in_graph(base_rv, node, fgraph): + return False + + rv_op = rv_node.op + rng, size, dtype, *dist_params = rv_node.inputs + + # We need to know the dimensions that were *not* added by the `size` + # parameter (i.e. the dimensions corresponding to independent variates with + # different parameter values) + num_ind_dims = None + if len(dist_params) == 1: + num_ind_dims = dist_params[0].ndim + else: + # When there is more than one distribution parameter, assume that all + # of them will broadcast to the maximum number of dimensions + num_ind_dims = max(d.ndim for d in dist_params) + + # If the indices in `ds_new_order` are entirely within the replication + # indices group or the independent variates indices group, then we can apply + # this rewrite. + + ds_new_order = ds_op.new_order + # Create a map from old index order to new/`DimShuffled` index order + dim_orders = [(n, d) for n, d in enumerate(ds_new_order) if isinstance(d, int)] + + # Find the index at which the replications/independents split occurs + reps_ind_split_idx = len(dim_orders) - (num_ind_dims + rv_op.ndim_supp) + + ds_reps_new_dims = dim_orders[:reps_ind_split_idx] + ds_ind_new_dims = dim_orders[reps_ind_split_idx:] + ds_in_ind_space = ds_ind_new_dims and all( + d >= reps_ind_split_idx for n, d in ds_ind_new_dims + ) + + if ds_in_ind_space or (not ds_ind_new_dims and not ds_reps_new_dims): + + # Update the `size` array to reflect the `DimShuffle`d dimensions, + # since the trailing dimensions in `size` represent the independent + # variates dimensions (for univariate distributions, at least) + has_size = get_vector_length(size) > 0 + new_size = ( + [constant(1, dtype="int64") if o == "x" else size[o] for o in ds_new_order] + if has_size + else size + ) + + # Compute the new axes parameter(s) for the `DimShuffle` that will be + # applied to the `RandomVariable` parameters (they need to be offset) + if ds_ind_new_dims: + rv_params_new_order = [ + d - reps_ind_split_idx if isinstance(d, int) else d + for d in ds_new_order[ds_ind_new_dims[0][0] :] + ] + + if not has_size and len(ds_new_order[: ds_ind_new_dims[0][0]]) > 0: + # Additional broadcast dimensions need to be added to the + # independent dimensions (i.e. parameters), since there's no + # `size` to which they can be added + rv_params_new_order = ( + list(ds_new_order[: ds_ind_new_dims[0][0]]) + rv_params_new_order + ) + else: + # This case is reached when, for example, `ds_new_order` only + # consists of new broadcastable dimensions (i.e. `"x"`s) + rv_params_new_order = ds_new_order + + # Lift the `DimShuffle`s into the parameters + # NOTE: The parameters might not be broadcasted against each other, so + # we can only apply the parts of the `DimShuffle` that are relevant. + new_dist_params = [] + for d in dist_params: + if d.ndim < len(ds_ind_new_dims): + _rv_params_new_order = [ + o + for o in rv_params_new_order + if (isinstance(o, int) and o < d.ndim) or o == "x" + ] + else: + _rv_params_new_order = rv_params_new_order + + new_dist_params.append( + type(ds_op)(d.type.broadcastable, _rv_params_new_order)(d) + ) + new_node = rv_op.make_node(rng, new_size, dtype, *new_dist_params) + + if config.compute_test_value != "off": + compute_test_value(new_node) + + out = new_node.outputs[1] + if base_rv.name: + out.name = f"{base_rv.name}_lifted" + return [out] + + ds_in_reps_space = ds_reps_new_dims and all( + d < reps_ind_split_idx for n, d in ds_reps_new_dims + ) + + if ds_in_reps_space: + # Update the `size` array to reflect the `DimShuffle`d dimensions. + # There should be no need to `DimShuffle` now. + new_size = [ + constant(1, dtype="int64") if o == "x" else size[o] for o in ds_new_order + ] + + new_node = rv_op.make_node(rng, new_size, dtype, *dist_params) + + if config.compute_test_value != "off": + compute_test_value(new_node) + + out = new_node.outputs[1] + if base_rv.name: + out.name = f"{base_rv.name}_lifted" + return [out] + + return False + + +@node_rewriter([Subtensor, AdvancedSubtensor1, AdvancedSubtensor]) +def local_subtensor_rv_lift(fgraph, node): + """Lift a ``*Subtensor`` through ``RandomVariable`` inputs. + + In a fashion similar to ``local_dimshuffle_rv_lift``, the indexed dimensions + need to be separated into distinct replication-space and (independent) + parameter-space ``*Subtensor``s. + + The replication-space ``*Subtensor`` can be used to determine a + sub/super-set of the replication-space and, thus, a "smaller"/"larger" + ``size`` tuple. The parameter-space ``*Subtensor`` is simply lifted and + applied to the distribution parameters. + + Consider the following example graph: + ``normal(mu, std, size=(d1, d2, d3))[idx1, idx2, idx3]``. The + ``*Subtensor`` ``Op`` requests indices ``idx1``, ``idx2``, and ``idx3``, + which correspond to all three ``size`` dimensions. Now, depending on the + broadcasted dimensions of ``mu`` and ``std``, this ``*Subtensor`` ``Op`` + could be reducing the ``size`` parameter and/or sub-setting the independent + ``mu`` and ``std`` parameters. Only once the dimensions are properly + separated into the two replication/parameter subspaces can we determine how + the ``*Subtensor`` indices are distributed. + For instance, ``normal(mu, std, size=(d1, d2, d3))[idx1, idx2, idx3]`` + could become + ``normal(mu[idx1], std[idx2], size=np.shape(idx1) + np.shape(idx2) + np.shape(idx3))`` + if ``mu.shape == std.shape == ()`` + + ``normal`` is a rather simple case, because it's univariate. Multivariate + cases require a mapping between the parameter space and the image of the + random variable. This may not always be possible, but for many common + distributions it is. For example, the dimensions of the multivariate + normal's image can be mapped directly to each dimension of its parameters. + We use these mappings to change a graph like ``multivariate_normal(mu, Sigma)[idx1]`` + into ``multivariate_normal(mu[idx1], Sigma[idx1, idx1])``. + + """ + + st_op = node.op + + if not isinstance(st_op, (AdvancedSubtensor, AdvancedSubtensor1, Subtensor)): + return False + + base_rv = node.inputs[0] + + rv_node = base_rv.owner + if not (rv_node and isinstance(rv_node.op, RandomVariable)): + return False + + # If no one else is using the underlying `RandomVariable`, then we can + # do this; otherwise, the graph would be internally inconsistent. + if is_rv_used_in_graph(base_rv, node, fgraph): + return False + + rv_op = rv_node.op + rng, size, dtype, *dist_params = rv_node.inputs + + # TODO: Remove this once the multi-dimensional changes described below are + # in place. + if rv_op.ndim_supp > 0: + return False + + rv_op = base_rv.owner.op + rng, size, dtype, *dist_params = base_rv.owner.inputs + + idx_list = getattr(st_op, "idx_list", None) + if idx_list: + cdata = get_idx_list(node.inputs, idx_list) + else: + cdata = node.inputs[1:] + + st_indices, st_is_bool = zip( + *tuple( + (as_index_variable(i), getattr(i, "dtype", None) == "bool") for i in cdata + ) + ) + + # We need to separate dimensions into replications and independents + num_ind_dims = None + if len(dist_params) == 1: + num_ind_dims = dist_params[0].ndim + else: + # When there is more than one distribution parameter, assume that all + # of them will broadcast to the maximum number of dimensions + num_ind_dims = max(d.ndim for d in dist_params) + + reps_ind_split_idx = base_rv.ndim - (num_ind_dims + rv_op.ndim_supp) + + if len(st_indices) > reps_ind_split_idx: + # These are the indices that need to be applied to the parameters + ind_indices = tuple(st_indices[reps_ind_split_idx:]) + + # We need to broadcast the parameters before applying the `*Subtensor*` + # with these indices, because the indices could be referencing broadcast + # dimensions that don't exist (yet) + bcast_dist_params = broadcast_params(dist_params, rv_op.ndims_params) + + # TODO: For multidimensional distributions, we need a map that tells us + # which dimensions of the parameters need to be indexed. + # + # For example, `multivariate_normal` would have the following: + # `RandomVariable.param_to_image_dims = ((0,), (0, 1))` + # + # I.e. the first parameter's (i.e. mean's) first dimension maps directly to + # the dimension of the RV's image, and its second parameter's + # (i.e. covariance's) first and second dimensions map directly to the + # dimension of the RV's image. + + args_lifted = tuple(p[ind_indices] for p in bcast_dist_params) + else: + # In this case, no indexing is applied to the parameters; only the + # `size` parameter is affected. + args_lifted = dist_params + + # TODO: Could use `ShapeFeature` info. We would need to be sure that + # `node` isn't in the results, though. + # if hasattr(fgraph, "shape_feature"): + # output_shape = fgraph.shape_feature.shape_of(node.outputs[0]) + # else: + output_shape = indexed_result_shape(base_rv.shape, st_indices) + + size_lifted = ( + output_shape if rv_op.ndim_supp == 0 else output_shape[: -rv_op.ndim_supp] + ) + + # Boolean indices can actually change the `size` value (compared to just + # *which* dimensions of `size` are used). + if any(st_is_bool): + size_lifted = tuple( + at_sum(idx) if is_bool else s + for s, is_bool, idx in zip( + size_lifted, st_is_bool, st_indices[: (reps_ind_split_idx + 1)] + ) + ) + + new_node = rv_op.make_node(rng, size_lifted, dtype, *args_lifted) + _, new_rv = new_node.outputs + + # Calling `Op.make_node` directly circumvents test value computations, so + # we need to compute the test values manually + if config.compute_test_value != "off": + compute_test_value(new_node) + + return [new_rv] diff --git a/aesara/tensor/rewriting/__init__.py b/aesara/tensor/rewriting/__init__.py new file mode 100644 index 0000000000..316fe6fd21 --- /dev/null +++ b/aesara/tensor/rewriting/__init__.py @@ -0,0 +1,7 @@ +import aesara.tensor.rewriting.basic +import aesara.tensor.rewriting.elemwise +import aesara.tensor.rewriting.extra_ops +import aesara.tensor.rewriting.math +import aesara.tensor.rewriting.shape +import aesara.tensor.rewriting.subtensor +import aesara.tensor.rewriting.uncanonicalize diff --git a/aesara/tensor/rewriting/basic.py b/aesara/tensor/rewriting/basic.py new file mode 100644 index 0000000000..7cb095a346 --- /dev/null +++ b/aesara/tensor/rewriting/basic.py @@ -0,0 +1,1301 @@ +""" Tensor optimizations addressing the ops in basic.py.""" + +import logging +from typing import TYPE_CHECKING, Optional, Union + +import numpy as np + +import aesara.scalar.basic as aes +from aesara import compile +from aesara.compile.ops import ViewOp +from aesara.graph.basic import Constant, Variable +from aesara.graph.rewriting.basic import ( + NodeRewriter, + RemovalNodeRewriter, + Rewriter, + copy_stack_trace, + in2out, + node_rewriter, +) +from aesara.graph.rewriting.db import RewriteDatabase +from aesara.raise_op import Assert, CheckAndRaise, assert_op +from aesara.tensor.basic import ( + Alloc, + AllocEmpty, + Join, + MakeVector, + ScalarFromTensor, + Split, + TensorFromScalar, + alloc, + as_tensor_variable, + cast, + extract_constant, + fill, + get_scalar_constant_value, + join, + ones_like, + switch, + tensor_copy, + zeros, + zeros_like, +) +from aesara.tensor.elemwise import DimShuffle, Elemwise +from aesara.tensor.exceptions import NotScalarConstantError +from aesara.tensor.extra_ops import broadcast_shape, broadcast_to +from aesara.tensor.math import all as at_all +from aesara.tensor.math import eq +from aesara.tensor.shape import Shape_i +from aesara.tensor.sort import TopKOp +from aesara.tensor.type import DenseTensorType, TensorType +from aesara.tensor.var import TensorConstant +from aesara.utils import NoDuplicateOptWarningFilter + + +if TYPE_CHECKING: + from aesara.tensor.rewriting.shape import ShapeFeature + + +_logger = logging.getLogger("aesara.tensor.rewriting.basic") +_logger.addFilter(NoDuplicateOptWarningFilter()) + + +def encompasses_broadcastable(b1, b2): + """ + + Parameters + ---------- + b1 + The broadcastable attribute of a tensor type. + b2 + The broadcastable attribute of a tensor type. + + Returns + ------- + bool + True if the broadcastable patterns b1 and b2 are such that b2 is + broadcasted to b1's shape and not the opposite. + + """ + if len(b1) < len(b2): + return False + b1 = b1[-len(b2) :] + return not any(v1 and not v2 for v1, v2 in zip(b1, b2)) + + +def merge_broadcastables(broadcastables): + return [all(bcast) for bcast in zip(*broadcastables)] + + +def broadcast_like(value, template, fgraph, dtype=None): + """ + Return a Variable with the same shape and dtype as the template, + filled by broadcasting value through it. `value` will be cast as + necessary. + + """ + value = as_tensor_variable(value) + if value.type.is_super(template.type): + return value + if template not in fgraph.variables: + raise NotImplementedError( + "broadcast_like currently requires the " + "template Variable to be in the fgraph already" + ) + if dtype is None: + dtype = template.dtype + value = cast(value, dtype) + if value.type.is_super(template.type): + return value + if hasattr(fgraph, "shape_feature"): + new_shape = fgraph.shape_feature.shape_of[template] + else: + new_shape = template.shape + rval = alloc(value, *new_shape) + assert rval.type.dtype == dtype + + return rval + + +def register_useless( + node_rewriter: Union[RewriteDatabase, NodeRewriter, str], *tags, **kwargs +): + if isinstance(node_rewriter, str): + + def register(inner_rewriter: Union[RewriteDatabase, Rewriter]): + return register_useless(inner_rewriter, node_rewriter, *tags, **kwargs) + + return register + else: + name = kwargs.pop("name", None) or node_rewriter.__name__ + + compile.mode.local_useless.register( + name, node_rewriter, "fast_run", *tags, position="last", **kwargs + ) + return node_rewriter + + +def register_canonicalize( + node_rewriter: Union[RewriteDatabase, NodeRewriter, str], *tags: str, **kwargs +): + if isinstance(node_rewriter, str): + + def register(inner_rewriter: Union[RewriteDatabase, Rewriter]): + return register_canonicalize(inner_rewriter, node_rewriter, *tags, **kwargs) + + return register + else: + name = kwargs.pop("name", None) or node_rewriter.__name__ + compile.optdb["canonicalize"].register( + name, node_rewriter, "fast_run", "fast_compile", *tags, **kwargs + ) + return node_rewriter + + +def register_stabilize( + node_rewriter: Union[RewriteDatabase, NodeRewriter, str], *tags: str, **kwargs +): + if isinstance(node_rewriter, str): + + def register(inner_rewriter: Union[RewriteDatabase, Rewriter]): + return register_stabilize(inner_rewriter, node_rewriter, *tags, **kwargs) + + return register + else: + name = kwargs.pop("name", None) or node_rewriter.__name__ + compile.optdb["stabilize"].register( + name, node_rewriter, "fast_run", *tags, **kwargs + ) + return node_rewriter + + +def register_specialize( + node_rewriter: Union[RewriteDatabase, NodeRewriter, str], *tags: str, **kwargs +): + if isinstance(node_rewriter, str): + + def register(inner_rewriter: Union[RewriteDatabase, Rewriter]): + return register_specialize(inner_rewriter, node_rewriter, *tags, **kwargs) + + return register + else: + name = kwargs.pop("name", None) or node_rewriter.__name__ + compile.optdb["specialize"].register( + name, node_rewriter, "fast_run", *tags, **kwargs + ) + return node_rewriter + + +def register_uncanonicalize( + node_rewriter: Union[RewriteDatabase, NodeRewriter, str], *tags: str, **kwargs +): + if isinstance(node_rewriter, str): + + def register(inner_rewriter: Union[RewriteDatabase, Rewriter]): + return register_uncanonicalize( + inner_rewriter, node_rewriter, *tags, **kwargs + ) + + return register + else: + name = (kwargs and kwargs.pop("name", None)) or node_rewriter.__name__ + compile.optdb["uncanonicalize"].register( + name, node_rewriter, "fast_run", *tags, **kwargs + ) + return node_rewriter + + +def register_specialize_device( + node_rewriter: Union[RewriteDatabase, Rewriter, str], *tags: str, **kwargs +): + if isinstance(node_rewriter, str): + + def register(inner_rewriter: Union[RewriteDatabase, Rewriter]): + return register_specialize_device( + inner_rewriter, node_rewriter, *tags, **kwargs + ) + + return register + else: + name = (kwargs and kwargs.pop("name", None)) or node_rewriter.__name__ + compile.optdb["specialize_device"].register( + name, node_rewriter, "fast_run", *tags, **kwargs + ) + return node_rewriter + + +@register_canonicalize +@register_specialize +@node_rewriter([TensorFromScalar]) +def local_tensor_scalar_tensor(fgraph, node): + """tensor_from_scalar(scalar_from_tensor(x)) -> x""" + if isinstance(node.op, TensorFromScalar): + s = node.inputs[0] + if s.owner and isinstance(s.owner.op, ScalarFromTensor): + t = s.owner.inputs[0] + + # We don't need to copy over any stack traces here + return [t] + + +@register_canonicalize +@register_specialize +@node_rewriter([ScalarFromTensor]) +def local_scalar_tensor_scalar(fgraph, node): + """scalar_from_tensor(tensor_from_scalar(x)) -> x""" + if isinstance(node.op, ScalarFromTensor): + t = node.inputs[0] + if t.owner and isinstance(t.owner.op, TensorFromScalar): + s = t.owner.inputs[0] + + # We don't need to copy over any stack traces here + return [s] + + +@register_specialize("local_alloc_elemwise") +@node_rewriter([Elemwise]) +def local_elemwise_alloc(fgraph, node): + r"""Remove unnecessary `Alloc`\s that occur as inputs of `Elemwise` `Op`\s. + + `Alloc`\s are effectively a type of `Elemwise` operation + (e.g. ``Elemwise{second}(y, x)`` is the same as ``Alloc(x, *y.shape)``), so + this rewrite uses that fact to reduce `Elemwise`\s on `Alloc`\s to + `Elemwise`\s of the `Alloc`\s first/value input (i.e. the value it + broadcasts). + + In other words, this rewrite causes `Elemwise` `Op`\s to "absorb" redundant + `Alloc`\s. + + The rewrite essentially performs the following replacement: + ``Elemwise{op}(..., Alloc(x, s), ..., y, ...) -> Elemwise{op}(..., x, ..., y, ...)``, + when ``y.shape`` for some input ``y`` (or the combined shapes of the + non-`Alloc`\s) is sufficient to maintain the same/correct output shape. + + In it's current form, it also explicitly accounts for `DimShuffle`\s of + `Alloc`\s. This is largely due to `local_alloc_sink_dimshuffle`, which + introduces them as a canonicalization of `Alloc`'s with leading + broadcastable dimensions. + """ + # Rewrite is only applicable when there are at least two inputs + if len(node.inputs) == 1: + return False + + if len(node.outputs) > 1: + return False + + def dimshuffled_alloc(i): + return ( + isinstance(i.owner.op, DimShuffle) + and i.owner.inputs[0].owner + and isinstance(i.owner.inputs[0].owner.op, Alloc) + ) + + # At least one input must have an owner that is either a `Alloc` or a + # `DimShuffle` with an owner that is a `Alloc` -- otherwise there is + # nothing to optimize. + alloc_idxs = [ + idx + for idx, i in enumerate(node.inputs) + if i.owner and (isinstance(i.owner.op, Alloc) or dimshuffled_alloc(i)) + ] + if len(alloc_idxs) == 0: + return False + + # Search for a non `Alloc` or `DimShuffle` of `Alloc` input that we can use as a + # baseline for the dimensions. + ref_var_idx = None + for idx, i in enumerate(node.inputs): + if i.type.broadcastable == node.outputs[0].type.broadcastable: + # Prefer an input that is not an `Alloc` nor a `DimShuffle` of an + # `Alloc`, so that all `Alloc`s can be rewritten. + if idx not in alloc_idxs: + ref_var_idx = idx + break + + # If only `Alloc` and `DimShuffle` of `Alloc` exist, we pick the first suitable one + if ref_var_idx is None: + for idx, i in enumerate(node.inputs): + # XXX: This broadcastable comparison doesn't work + if ( + i.type.broadcastable == node.outputs[0].type.broadcastable + ) and idx in alloc_idxs: + ref_var_idx = idx + break + + if not hasattr(fgraph, "shape_feature"): + return False + + input_shapes = [ + tuple(fgraph.shape_feature.get_shape(i, j) for j in range(i.type.ndim)) + for i in node.inputs + ] + bcasted_shape = broadcast_shape( + *input_shapes, + arrays_are_shapes=True, + ) + + new_inputs = list(node.inputs) + for idx in alloc_idxs: + i = node.inputs[idx] + + # Remove `Alloc` + if isinstance(i.owner.op, Alloc): + new_alloc = broadcast_to(i.owner.inputs[0], bcasted_shape) + + # TODO FIXME: This shouldn't be handled here. + # `DimShuffle`s should be lifted through `Alloc`s + # by other, more general rewrites. + # Remove `Alloc` in `DimShuffle` + elif isinstance(i.owner.op, DimShuffle): + old_alloc = i.owner.inputs[0] + new_alloc = old_alloc.owner.inputs[0] + # We need to keep the old `DimShuffle`. It could swap axes or + # add dimensions anywhere. + if new_alloc.ndim != old_alloc.ndim: + # The `Alloc` can add dimensions to the value. + # We replace those cases with a `DimShuffle` here. + nb_dim_to_add = old_alloc.ndim - new_alloc.ndim + new_alloc = new_alloc.dimshuffle( + ["x"] * nb_dim_to_add + list(range(new_alloc.ndim)) + ) + new_alloc = broadcast_to(i.owner.op(new_alloc), bcasted_shape) + + copy_stack_trace(i, new_alloc) + new_inputs[idx] = new_alloc + + # If this assert is triggered, it means we are recreating an equivalent graph + # which would result in cyclical merge rewrites. + if all(new is old for new, old in zip(new_inputs, node.inputs)): + return + + ret = node.op(*new_inputs, return_list=True) + copy_stack_trace(node.outputs, ret) + return ret + + +@register_canonicalize +@node_rewriter([Elemwise]) +def local_fill_sink(fgraph, node): + """ + f(fill(a, b), fill(c, d), e) -> fill(c, fill(a, f(b, d, e))) + f need to be an elemwise that isn't a fill. + """ + if not hasattr(node, "op") or not isinstance(node.op, Elemwise) or node.op == fill: + return False + models = [] + inputs = [] + for inp in node.inputs: + if inp.owner and inp.owner.op == fill: + models.append(inp.owner.inputs[0]) + inputs.append(inp.owner.inputs[1]) + else: + inputs.append(inp) + if not models: + return False + c = node.op(*inputs) + for model in models: + if ( + model.type.dtype != c.type.dtype + or model.type.broadcastable != c.type.broadcastable + ): + c = fill(model, c) + + # The newly created node c doesn't has 'clients', + # so this iteration is took place with node.outputs[0] + replacements = {node.outputs[0]: c} + for client, cl_idx in fgraph.clients[node.outputs[0]]: + if ( + hasattr(client, "op") + and isinstance(client.op, Elemwise) + and client.op != fill + ): + client_inputs = client.inputs[:] + client_inputs[cl_idx] = c + new_client = client.op(*client_inputs) + + # Add clients to new_client + fgraph.clients[new_client.owner.outputs[0]] = fgraph.clients[ + client.outputs[0] + ] + r = local_fill_sink.transform(fgraph, new_client.owner) + if not r: + continue + replacements.update(r) + return replacements + + +@register_specialize +@register_stabilize +@node_rewriter([fill]) +def local_fill_to_alloc(fgraph, node): + r"""Remove `fill`\s or replace them with `Alloc`\s. + + `Alloc`\s are preferable because they replace explicit tensor dependencies + with their dependencies on those tensors' shapes, and sometimes those + shapes can be computed without needing to compute the tensors themselves. + + XXX: This rewrite can produce inconsistent results, so do *not* consider + making it a canonicalization until those inconsistencies are + resolved/justified. + """ + shape_ref, values_ref = node.inputs + out_type = node.outputs[0].type + + if values_ref.type.broadcastable == out_type.broadcastable: + # The assumption here is that `values_ref` already has the same shape + # as `shape_ref`, so a `fill`/`Alloc` is unnecessary. + + # XXX FIXME TODO: The only way this can be determined is if one + # absolutely knows that the shapes of `shape_ref` and `values_ref` are + # equal. + # This is an old rewrite, and it's only a + # "specialization/stabilization", so we're going to leave it be for + # now. + return [values_ref] + + if shape_ref.type.broadcastable == out_type.broadcastable: + # In this case, we assume that some broadcasting is needed (otherwise + # the condition above would've been true), so we replace the `fill` + # with an `Alloc`. + o = broadcast_like(values_ref, shape_ref, fgraph, dtype=values_ref.dtype) + copy_stack_trace(node.outputs[0], o) + return [o] + + return + + +# Register this after stabilize at 1.5 to make sure stabilize don't +# get affected by less canonicalized graph due to alloc. +compile.optdb.register( + "local_fill_to_alloc", in2out(local_fill_to_alloc), "fast_run", position=1.51 +) +# Needed to clean some extra alloc added by local_fill_to_alloc +compile.optdb.register( + "local_elemwise_alloc", in2out(local_elemwise_alloc), "fast_run", position=1.52 +) + + +@register_canonicalize("fast_compile") +@register_useless +@node_rewriter([fill]) +def local_useless_fill(fgraph, node): + """fill(s,v) -> v + + This rewrite is only needed in FAST_COMPILE mode to make the code + more readable. Normally, it is done by the `local_fill_to_alloc` + rewrite. + + """ + r, v = node.inputs + out_type = node.outputs[0].type + + if ( + v.type.dtype == out_type.dtype + and v.type.broadcastable == out_type.broadcastable + ): + return [v] + + +@register_specialize +@register_stabilize +@register_canonicalize +@register_useless +@node_rewriter([Alloc]) +def local_useless_alloc(fgraph, node): + """ + If the input type is the same as the output type (dtype and broadcast) + there is no change in the shape of the input. So this is just a simple copy + of the input. This is not needed. + """ + if not isinstance(node.op, Alloc): + return False + + inp = node.inputs[0] + output = node.outputs[0] + + if ( + inp.type.dtype == output.type.dtype + and inp.type.broadcastable == output.type.broadcastable + ): + if inp.ndim == 0: + return [inp] + else: + return [ + Assert("Shapes must be equal")( + inp, at_all(eq(inp.shape, node.inputs[1:])) + ) + ] + + +@register_specialize +@register_stabilize +@register_canonicalize +@node_rewriter([Alloc]) +def local_alloc_sink_dimshuffle(fgraph, node): + r"""Convert broadcastable leading dimensions in an `Alloc` to `DimShuffle`\s.""" + op = node.op + if not isinstance(op, Alloc): + return False + + inp = node.inputs[0] + output = node.outputs[0] + + # Check if alloc adds a broadcastable dimension with shape 1. + output_shape = node.inputs[1:] + num_dims_with_size_1_added_to_left = 0 + for i in range(len(output_shape) - inp.ndim): + if extract_constant(output_shape[i], only_process_constants=True) == 1: + num_dims_with_size_1_added_to_left += 1 + else: + break + + new_output_shape = output_shape[num_dims_with_size_1_added_to_left:] + if num_dims_with_size_1_added_to_left > 0 and len(new_output_shape) >= inp.ndim: + if ( + output.broadcastable[num_dims_with_size_1_added_to_left:] + == inp.broadcastable + ): + inner = inp + else: + inner = op(*([inp] + new_output_shape)) + dimshuffle_new_order = ["x"] * num_dims_with_size_1_added_to_left + list( + range(len(new_output_shape)) + ) + return [DimShuffle(inner.type.broadcastable, dimshuffle_new_order)(inner)] + + +@node_rewriter([AllocEmpty]) +def local_alloc_empty_to_zeros(fgraph, node): + """This convert AllocEmpty to Alloc of 0. + + This helps one investigate NaNs in `NanGuardMode`. Not registered by + default. To activate it, use the setting + ``optimizer_including == alloc_empty_to_zeros``. + """ + if isinstance(node.op, AllocEmpty): + return [zeros(node.inputs, dtype=node.outputs[0].dtype)] + + +compile.optdb.register( + "local_alloc_empty_to_zeros", + in2out(local_alloc_empty_to_zeros), + # After move to gpu and merge2, before inplace. + "alloc_empty_to_zeros", + position=49.3, +) + + +@register_useless +@register_canonicalize("fast_compile") +@register_specialize +@node_rewriter([Elemwise]) +def local_useless_elemwise(fgraph, node): + """ + eq(x, x) -> 1 + neq(x, x) -> 0 + mul(x) -> x + add(x) -> x + identity(x) -> x + and(x, 1) -> x (if x.dtype == 'bool') + and(x, 0) -> zeros_like(x) + or(x, 0) -> x + or(x, 1) -> ones_like(x) (if x.dtype == 'bool') + xor(x, x) -> zeros_like(x) + + TODO: This implementation is painfully redundant. + + """ + if isinstance(node.op, Elemwise): + # We call zeros_like and one_like with opt=True to generate a + # cleaner graph. + dtype = node.outputs[0].dtype + + if node.op.scalar_op == aes.eq and len(node.inputs) == 2: + if node.inputs[0] == node.inputs[1]: + # it is the same var in the graph. That will always be true + ret = ones_like(node.inputs[0], dtype=dtype, opt=True) + + # Copy stack trace from input to constant output + copy_stack_trace(node.outputs[0], ret) + return [ret] + elif node.op.scalar_op == aes.neq and len(node.inputs) == 2: + if node.inputs[0] == node.inputs[1]: + # it is the same var in the graph. That will always be false + ret = zeros_like(node.inputs[0], dtype=dtype, opt=True) + + # Copy stack trace from input to constant output + copy_stack_trace(node.outputs[0], ret) + return [ret] + + elif node.op.scalar_op == aes.mul and len(node.inputs) == 1: + # No need to copy over any stack trace + return [node.inputs[0]] + + elif node.op.scalar_op == aes.add and len(node.inputs) == 1: + # No need to copy over any stack trace + return [node.inputs[0]] + elif node.op.scalar_op == aes.identity and len(node.inputs) == 1: + return [node.inputs[0]] + + elif isinstance(node.op.scalar_op, aes.AND) and len(node.inputs) == 2: + + if isinstance(node.inputs[0], TensorConstant): + const_val = extract_constant( + node.inputs[0], only_process_constants=True + ) + if not isinstance(const_val, Variable): + if const_val == 0: + return [zeros_like(node.inputs[1], dtype=dtype, opt=True)] + elif node.outputs[0].dtype == "bool": + # If the output is not Boolean, it is the bitwise AND, + # and this rewrite would be wrong + return [node.inputs[1].astype(node.outputs[0].dtype)] + + if isinstance(node.inputs[1], TensorConstant): + const_val = extract_constant( + node.inputs[1], only_process_constants=True + ) + if not isinstance(const_val, Variable): + if const_val == 0: + return [zeros_like(node.inputs[0], dtype=dtype, opt=True)] + elif node.outputs[0].dtype == "bool": + # If the output is not Boolean, it is the bitwise AND, + # and this rewrite would be wrong + return [node.inputs[0].astype(node.outputs[0].dtype)] + + elif isinstance(node.op.scalar_op, aes.OR) and len(node.inputs) == 2: + + if isinstance(node.inputs[0], TensorConstant): + const_val = extract_constant( + node.inputs[0], only_process_constants=True + ) + if not isinstance(const_val, Variable): + if const_val == 0: + return [node.inputs[1].astype(node.outputs[0].dtype)] + elif node.outputs[0].dtype == "bool": + # If the output is not Boolean, it is the bitwise OR, + # and this rewrite would be wrong + return [ones_like(node.inputs[1], dtype=dtype, opt=True)] + + if isinstance(node.inputs[1], TensorConstant): + const_val = extract_constant( + node.inputs[1], only_process_constants=True + ) + if not isinstance(const_val, Variable): + if const_val == 0: + return [node.inputs[0].astype(node.outputs[0].dtype)] + elif node.outputs[0].dtype == "bool": + # If the output is not Boolean, it is the bitwise OR, + # and this rewrite would be wrong + return [ones_like(node.inputs[0], dtype=dtype, opt=True)] + + elif isinstance(node.op.scalar_op, aes.XOR) and len(node.inputs) == 2: + if node.inputs[0] is node.inputs[1]: + return [zeros_like(node.inputs[0], dtype=dtype, opt=True)] + + +@register_specialize +@node_rewriter([Elemwise]) +def local_alloc_unary(fgraph, node): + """unary(alloc(x, shp)) -> alloc(unary(x), shp)""" + if isinstance(node.op, Elemwise) and len(node.inputs) == 1: + a = node.inputs[0] + if a.owner and isinstance(a.owner.op, Alloc): + x = a.owner.inputs[0] + shp = a.owner.inputs[1:] + v = node.op(x) + # at.alloc does not preserve the stacktrace of v, + # so we need to copy it over from x. + copy_stack_trace(node.outputs[0], v) + ret = alloc(cast(v, node.outputs[0].dtype), *shp) + + # at.cast does not preserve the stacktrace of x, + # so we need to copy it over to the output. + copy_stack_trace([node.outputs[0], a], ret) + return [ret] + + +@register_canonicalize +@register_specialize +@node_rewriter([Elemwise]) +def local_cast_cast(fgraph, node): + """cast(cast(x, dtype1), dtype2) + + when those contrain: + dtype1 == dtype2 + OR the base dtype is the same (int, uint, float, complex) + and the first cast cause an upcast. + + """ + if not isinstance(node.op, Elemwise) or not isinstance(node.op.scalar_op, aes.Cast): + return + x = node.inputs[0] + if ( + not x.owner + or not isinstance(x.owner.op, Elemwise) + or not isinstance(x.owner.op.scalar_op, aes.Cast) + ): + return + + type1 = x.owner.op.scalar_op.o_type + type2 = node.op.scalar_op.o_type + base = x.owner.inputs[0] + + if type1 == type2: + # We don't need to copy over any stack traces here + return [x] + + if is_an_upcast(base.dtype, type1.dtype): + # Checking for further redundancy. Eg: int8 -> int32 -> int8 + if type2.dtype == base.dtype: + return x.owner.inputs + else: + # Apply the second cast only + v = node.op(base) + # Copy stack trace from the output of the original cast + copy_stack_trace(node.outputs[0], v) + return [v] + + +def is_an_upcast(type1, type2): + """Given two data types (as strings), check if converting to + type2 from type1 constitutes an upcast. + Differs from aesara.scalar.upcast + + """ + category = { + # The first number in the pair is the dtype (bool, uint, int, float, + # complex). Conversion from higher to lower is never an upcast. + # The second number roughly indicates the precision. Again, conversion + # from higher to lower is never an upcast. + "bool": (0, 0), + "uint8": (1, 1), + "uint16": (1, 2), + "uint32": (1, 3), + "uint64": (1, 4), + "int8": (2, 1), + "int16": (2, 2), + "int32": (2, 3), + "int64": (2, 4), + "float16": (3, 1.5), + "float32": (3, 2.5), + "float64": (3, 3.5), + "complex64": (4, 3), + "complex128": (4, 4), + } + + cat1 = category[type1] + cat2 = category[type2] + + if cat2[0] >= cat1[0] and cat2[1] > cat1[1]: + return True + else: + return False + + +@register_useless +@register_specialize +@node_rewriter(None) +def local_remove_useless_assert(fgraph, node): + if not isinstance(node.op, CheckAndRaise): + return False + + new_conds = [] + n_conds = len(node.inputs[1:]) + for c in node.inputs[1:]: + try: + const = get_scalar_constant_value(c) + + if 0 != const.ndim or const == 0: + # Should we raise an error here? How to be sure it + # is not caught? + new_conds.append(c) + except NotScalarConstantError: + new_conds.append(c) + + if len(new_conds) == 0: + return [node.inputs[0]] + + if len(new_conds) < n_conds: + new_var = node.op(*(node.inputs[:1] + new_conds)) + copy_stack_trace(node.outputs[0], new_var) + return [new_var] + + +@node_rewriter([Assert]) +def local_remove_all_assert(fgraph, node): + r"""A rewrite that removes all `Assert`\s from a graph. + + Notes + ----- + See the :ref:`unsafe` section. + + """ + if not isinstance(node.op, Assert): + return + + return [node.inputs[0]] + + +compile.optdb["canonicalize"].register( + "local_remove_all_assert", + local_remove_all_assert, + "unsafe", + use_db_name_as_tag=False, +) +compile.optdb["stabilize"].register( + "local_remove_all_assert", + local_remove_all_assert, + "unsafe", + use_db_name_as_tag=False, +) +compile.optdb["specialize"].register( + "local_remove_all_assert", + local_remove_all_assert, + "unsafe", + use_db_name_as_tag=False, +) +compile.optdb["useless"].register( + "local_remove_all_assert", + local_remove_all_assert, + "unsafe", + use_db_name_as_tag=False, +) + + +@register_specialize +@register_canonicalize +@register_useless +@node_rewriter([Join]) +def local_join_1(fgraph, node): + """Join(i, x) => x + + Remove Join() when only one element is joined. + + """ + if not isinstance(node.op, Join): + return + tensors = node.inputs[1:] + if len(tensors) == 1: + # We don't need to copy over any stacktrace here, because the + # input variable should already have its own stacktrace. + return [tensors[0]] + + +# TODO: merge in local_useless_join +@register_useless +@register_specialize +@register_canonicalize +@node_rewriter([Join]) +def local_join_empty(fgraph, node): + """Join(i, x, y, empty) => Join(i, x, y) + + Remove empty inputs to joins. The empty inputs can be anywhere. + + """ + if not isinstance(node.op, Join): + return + new_inputs = [] + try: + join_idx = get_scalar_constant_value( + node.inputs[0], only_process_constants=True + ) + except NotScalarConstantError: + return + for idx in range(1, len(node.inputs)): + inp = node.inputs[idx] + # We can not use size == 0,, as this can change shape from 3,0 + # to 2,0. This trigger DebugMode error. This happen with + # stack(...,[]) as this add a dimshuffle on [], that add a + # dimensions with shape 1. + if isinstance(inp, Constant) and inp.data.shape[join_idx] == 0: + continue + new_inputs.append(inp) + if len(new_inputs) < len(node.inputs) - 1: + if len(new_inputs) == 0: + # at.join do not work in that case. + # constant folding will take care of this case. + return + ret = join(node.inputs[0], *new_inputs) + o = node.outputs[0] + if ret.dtype != o.dtype: + # Join can upcast some inputs + return + + # Copy over stacktrace from previous output (after join op) + # to new output, because an error in the new op must be caused + # by an error in the old join op. + copy_stack_trace(node.outputs, ret) + + return [ret] + + +@register_specialize +@register_canonicalize +@register_useless +@node_rewriter([Join]) +def local_join_make_vector(fgraph, node): + r"""Merge `MakeVector` inputs within a `Join`. + + For example: + + Join(0, make_vector1, make_vector2, ...) => Join(0, make_vector12, ...) + + This, in combination with the `local_join_1` rewrite, can make `Join`\s + completely disappear. + """ + if not isinstance(node.op, Join) or node.outputs[0].ndim != 1: + return + new_inputs = [node.inputs[1]] + for idx in range(2, len(node.inputs)): + inp = node.inputs[idx] + if ( + inp.owner + and isinstance(inp.owner.op, MakeVector) + and new_inputs[-1].owner + and isinstance(new_inputs[-1].owner.op, MakeVector) + and + # MakeVector have a dtype parameter + inp.owner.op == new_inputs[-1].owner.op + ): + inps = new_inputs[-1].owner.inputs + inp.owner.inputs + new_inputs[-1] = inp.owner.op(*inps) + + # Copy over stacktrace from previous output (after join op) + # to new intermediate output, because an error in the intermediate + # op must be caused by an error in the old join op. + copy_stack_trace(node.outputs, new_inputs[-1]) + else: + new_inputs.append(inp) + if len(new_inputs) < len(node.inputs) - 1: + ret = join(node.inputs[0], *new_inputs) + + # Copy over stacktrace from previous output (after join op) + # to new output, because an error in the new op must be caused + # by an error in the old join op. + copy_stack_trace(node.outputs, ret) + return [ret] + + +@register_useless("local_remove_switch_const_cond") +@register_canonicalize("fast_compile", "local_remove_switch_const_cond") +@register_specialize +@node_rewriter([Elemwise]) +def local_useless_switch(fgraph, node): + """ + This rewrite makes the following changes in a graph: + + at.switch(cond, left, right) -> + if cond is constant and cond == 0: right + if cond is constant and cond != 0: left + if left is right -> left + + and + + at.switch(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X) + + """ + if not isinstance(node.op.scalar_op, aes.Switch): + return False + + shape_feature: Optional["ShapeFeature"] = getattr(fgraph, "shape_feature", None) + + if shape_feature is None: + return False + + left = node.inputs[1] + right = node.inputs[2] + cond_var = node.inputs[0] + cond = extract_constant(cond_var, only_process_constants=True) + + if (isinstance(cond, np.ndarray) and cond.ndim == 0) or isinstance( + cond, (np.number, np.bool_) + ): + if cond == 0: + correct_out = right + else: + correct_out = left + + if correct_out.dtype != node.outputs[0].dtype: + out = cast(correct_out, node.outputs[0].dtype) + else: + out = correct_out + + input_shapes = [ + tuple(shape_feature.get_shape(inp, i) for i in range(inp.type.ndim)) + for inp in node.inputs + ] + + out_shape = broadcast_shape(*input_shapes, arrays_are_shapes=True) + + out = alloc(out, *out_shape) + + # Copy over stacktrace from selected output to new output + copy_stack_trace(node.outputs + correct_out, out) + return [out] + + # if left is right -> left + if left == right: + # Note: No need to copy over stacktrace, because the input node + # already has its own stacktrace + if cond.type.is_super(left.type): + return [left] + + ret = fill(cond, left) + + # Copy over stacktrace from switch output and correct branch + copy_stack_trace(node.outputs + left, ret) + return [ret] + + # This case happens with scan. + # Elemwise{switch}(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X) + if ( + cond_var.owner + and isinstance(cond_var.owner.op, Elemwise) + and isinstance(cond_var.owner.op.scalar_op, aes.LE) + and cond_var.owner.inputs[0].owner + and isinstance(cond_var.owner.inputs[0].owner.op, Shape_i) + and extract_constant(cond_var.owner.inputs[1], only_process_constants=True) == 0 + and extract_constant(left, only_process_constants=True) == 0 + and right == cond_var.owner.inputs[0] + ): + assert node.outputs[0].type.is_super(right.type) + # No need to copy over stacktrace, because the right input node + # already has its own stacktrace + return [right] + + +@register_canonicalize +@node_rewriter([Elemwise]) +def local_merge_switch_same_cond(fgraph, node): + """ + Merge add/sub/mul/div/minimum/maximum/... of switches sharing the same + condition, to enable further simplification of their branches + Example: switch(c, a, b) + switch(c, x, y) -> switch(c, a+x, b+y) + """ + # node must be binary elemwise or add or mul + if not isinstance(node.op, Elemwise) or not isinstance( + node.op.scalar_op, (aes.BinaryScalarOp, aes.Add, aes.Mul) + ): + return + # all inputs must be switch + if not all( + s.owner + and isinstance(s.owner.op, Elemwise) + and isinstance(s.owner.op.scalar_op, aes.Switch) + for s in node.inputs + ): + return + # all switch conditions must be the same + cond = node.inputs[0].owner.inputs[0] + if not all(s.owner.inputs[0] is cond for s in node.inputs[1:]): + return + # pull out switch + return [ + switch( + cond, + node.op(*[s.owner.inputs[1] for s in node.inputs]), + node.op(*[s.owner.inputs[2] for s in node.inputs]), + ) + ] + + +@register_useless +@register_canonicalize +@register_specialize +@node_rewriter([Split]) +def local_useless_split(fgraph, node): + """Split{n_splits=1}(x, y) -> x + + Remove Split with only 1 split. + + """ + if isinstance(node.op, Split): + if node.op.len_splits == 1: + x, axis, splits = node.inputs + out = assert_op(x, eq(splits.shape[0], 1)) + # Copy over stacktrace from previous output node. + copy_stack_trace(node.outputs, out) + out2 = assert_op(out, eq(x.shape[axis], splits[0])) + # Copy over stacktrace from previous output node. + copy_stack_trace(out, out2) + + return [out2] + + +@node_rewriter(None) +def constant_folding(fgraph, node): + + if not node.op.do_constant_folding(fgraph, node): + return False + + if not all(isinstance(inp, Constant) for inp in node.inputs): + return False + + storage_map = {i: [i.data] for i in node.inputs} + compute_map = {i: [True] for i in node.inputs} + for o in node.outputs: + storage_map[o] = [None] + compute_map[o] = [False] + + thunk = node.op.make_thunk(node, storage_map, compute_map, no_recycling=[]) + required = thunk() + + # A node whose inputs are all provided should always return successfully + assert not required + + rval = [] + for output in node.outputs: + data = storage_map[output][0] + assert compute_map[output][0], (output, data) + + # TODO: `Type` itself should provide an interface for constructing + # instances appropriate for a given constant. + # TODO: Add handling for sparse types. + if isinstance(output.type, DenseTensorType): + output_type = TensorType( + output.type.dtype, + tuple(s == 1 for s in data.shape), + name=output.type.name, + ) + else: + output_type = output.type + + v = output_type.make_constant(data) + + # We need to "narrow" types when we have additional information, + # and not "broaden" them. This is a case in which types are + # unnecessarily "broadened" + # assert not hasattr(output.type, "broadcastable") or output.type.broadcastable == tuple(s == 1 for s in data.shape) + + copy_stack_trace(output, v) + + rval.append(v) + + return rval + + +topo_constant_folding = in2out( + constant_folding, ignore_newtrees=True, name="topo_constant_folding" +) +register_canonicalize(topo_constant_folding, "fast_compile", final_rewriter=True) +register_uncanonicalize(topo_constant_folding, "fast_compile", final_rewriter=True) +register_stabilize(topo_constant_folding, "fast_compile", final_rewriter=True) +register_specialize(topo_constant_folding, "fast_compile", final_rewriter=True) + + +@register_canonicalize("fast_compile") +@register_useless("fast_compile") +@node_rewriter(None) +def local_view_op(fgraph, node): + if isinstance(node.op, ViewOp): + return node.inputs + + +@register_useless +@register_canonicalize +@register_stabilize +@register_specialize +@node_rewriter([Alloc]) +def local_merge_alloc(fgraph, node): + """ + This rewriter takes care of the following cases: + + Alloc(Alloc(m, x, 1, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w) + Alloc(Alloc(m, y, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w) + Alloc(Alloc(m, y1, 1, 1), x, y2, z, w) -> Alloc(m, x, assert(y1, y1==y2), z, w) + + """ + if not isinstance(node.op, Alloc): + return False + if not node.inputs[0].owner or not isinstance(node.inputs[0].owner.op, Alloc): + return False + inputs_outer = node.inputs + inputs_inner = node.inputs[0].owner.inputs + dims_outer = inputs_outer[1:] + dims_inner = inputs_inner[1:] + dims_outer_rev = dims_outer[::-1] + dims_inner_rev = dims_inner[::-1] + # check if the pattern of broadcasting is matched, in the reversed ordering. + # The reverse ordering is needed when an Alloc add an implicit new + # broadcasted dimensions to its inputs[0]. Eg: + # Alloc(Alloc(m, y, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w) + i = 0 + for dim_inner, dim_outer in zip(dims_inner_rev, dims_outer_rev): + if dim_inner != dim_outer: + if isinstance(dim_inner, Constant) and dim_inner.data == 1: + pass + else: + dims_outer[-1 - i] = Assert( + "You have a shape error in your graph. To see a better" + " error message and a stack trace of where in your code" + " the error is created, use the Aesara flags" + " optimizer=None or optimizer=fast_compile." + )(dim_outer, eq(dim_outer, dim_inner)) + i += 1 + return [alloc(inputs_inner[0], *dims_outer)] + + +@register_useless("fast_compile") +@node_rewriter([TopKOp]) +def local_useless_topk(fgraph, node): + """Remove unused `TopKOp` outputs.""" + op = node.op + if not isinstance(op, TopKOp): + return + if not (op.return_values and op.return_indices): + return False + + x, k = node.inputs + ret_val = bool(fgraph.clients[node.outputs[0]]) + ret_idx = bool(fgraph.clients[node.outputs[1]]) + + if not (ret_val ^ ret_idx): + # both true -> nothing to remove + # both false -> let pruner handle + return False + + old_output = node.outputs[ret_idx] + new_output = TopKOp( + axis=op.axis, + sorted=op.sorted, + idx_dtype=op.idx_dtype, + return_values=ret_val, + return_indices=ret_idx, + )(x, k) + copy_stack_trace(node.outputs[0], new_output) + return {old_output: new_output} + + +def import_ShapeFeature(): + from aesara.tensor.rewriting.shape import ShapeFeature + + return ShapeFeature + + +DEPRECATED_NAMES = { + "ShapeFeature": ( + "`ShapeFeature` is now located in `aesara.tensor.rewriting.shape`.", + import_ShapeFeature, + ), +} + + +def __getattr__(name): + """Intercept module-level attribute access of deprecated symbols. + + Adapted from https://stackoverflow.com/a/55139609/3006474. + + """ + from warnings import warn + + res = DEPRECATED_NAMES.get(name) + if res: + msg, fn = res + warn(msg, DeprecationWarning, stacklevel=2) + return fn() + + raise AttributeError(f"module {__name__} has no attribute {name}") + + +register_canonicalize(RemovalNodeRewriter(tensor_copy), name="remove_tensor_copy") diff --git a/aesara/tensor/rewriting/elemwise.py b/aesara/tensor/rewriting/elemwise.py new file mode 100644 index 0000000000..e80e871370 --- /dev/null +++ b/aesara/tensor/rewriting/elemwise.py @@ -0,0 +1,946 @@ +import sys +import time +from collections import defaultdict +from typing import Optional +from warnings import warn + +import aesara +import aesara.scalar.basic as aes +from aesara import compile +from aesara.configdefaults import config +from aesara.graph.basic import Apply, Constant, io_toposort +from aesara.graph.features import ReplaceValidate +from aesara.graph.op import compute_test_value, get_test_value +from aesara.graph.rewriting.basic import GraphRewriter, copy_stack_trace, node_rewriter +from aesara.graph.rewriting.db import SequenceDB +from aesara.graph.utils import InconsistencyError, MethodNotDefined, TestValueError +from aesara.tensor.basic import MakeVector, alloc, cast, get_scalar_constant_value +from aesara.tensor.elemwise import DimShuffle, Elemwise +from aesara.tensor.exceptions import NotScalarConstantError +from aesara.tensor.rewriting.basic import register_canonicalize, register_specialize +from aesara.tensor.shape import shape_padleft +from aesara.tensor.var import TensorConstant + + +class InplaceElemwiseOptimizer(GraphRewriter): + r""" + This is parameterized so that it works for `Elemwise` `Op`\s. + """ + + def __init__(self, OP): + self.op = OP + + def add_requirements(self, fgraph): + from aesara.graph.destroyhandler import DestroyHandler + + fgraph.attach_feature(DestroyHandler()) + + @classmethod + def print_profile(cls, stream, prof, level=0): + blanc = " " * level + print(blanc, cls.__name__, prof["opt"].op, file=stream) + for k in [ + "node_before", + "nb_call_replace", + "nb_call_validate", + "nb_inconsistent", + ]: + print(blanc, k, prof[k], file=stream) + ndim = prof["ndim"] + if ndim: + print(blanc, "ndim", "nb", file=stream) + for n in sorted(ndim.keys()): + print(blanc, n, ndim[n], file=stream) + + def apply(self, fgraph): + r""" + + Attempts to replace all `Elemwise`\s by versions of them that operate + inplace. It operates greedily: for each `Elemwise` that is encountered, + for each output, it tries each input to see if it can operate inplace + on that input. If so, it makes the change and goes to the next output + or `Elemwise`. + + Examples + -------- + + x + y + z -> x += y += z + (x + y) * (x * y) -> (x += y) *= (x * y) or (x + y) *= (x *= y) + + """ + # We should not validate too often as this takes too much time to + # execute! + # It is the _dfs_toposort() fct in aesara/graph/destroyhandler.py + # that takes so much time. + # Should we try to use another lib that does toposort? + # igraph: http://igraph.sourceforge.net/ + # networkx: https://networkx.lanl.gov/ + # Should we try to use cython? + # Compiling only that fct is not enough, should we try to add the + # deque class too? + # And init the deque and other list to an upper bound number of + # elements? + # Maybe Aesara should do online toposort as in + # http://code.google.com/p/acyclic + # + # The next longest rewriter is the canonizer phase. + # Then I think it is the [io_?]toposort (need to validate) so check if + # the solution is also applicable there. + + # We execute `validate` after this number of change. + prof = { + "opt": self, + "node_before": len(fgraph.apply_nodes), + "nb_call_replace": 0, + "nb_call_validate": 0, + "nb_inconsistent": 0, + "ndim": defaultdict(lambda: 0), + } + + check_each_change = config.tensor__insert_inplace_optimizer_validate_nb + if check_each_change == -1: + if len(fgraph.apply_nodes) > 500: + check_each_change = 10 + else: + check_each_change = 1 + + nb_change_no_validate = 0 + chk = fgraph.checkpoint() + + if fgraph.update_mapping: + update_outs = [fgraph.outputs[i] for i in fgraph.update_mapping] + else: + update_outs = [] + + protected_inputs = [ + f.protected + for f in fgraph._features + if isinstance(f, aesara.compile.function.types.Supervisor) + ] + protected_inputs = sum(protected_inputs, []) # flatten the list + protected_inputs.extend(fgraph.outputs) + for node in list(io_toposort(fgraph.inputs, fgraph.outputs)): + op = node.op + if not isinstance(op, self.op): + continue + # If big graph and the outputs are scalar, do not make it + # inplace. + if ( + check_each_change != 1 + and + # If multiple outputs, they must all have the same size, + # so only check the first. + getattr(node.outputs[0].type, "ndim", -1) == 0 + ): + continue + + if op.inplace_pattern: + # Maybe this isn't needed anymore, but I don't want to + # rish regression now. This case only happen if the + # original node add already some inplace patter and we + # still try to add more pattern. + + baseline = op.inplace_pattern + candidate_outputs = [ + i for i in range(len(node.outputs)) if i not in baseline + ] + # node inputs that are Constant, already destroyed, + # or fgraph protected inputs and fgraph outputs can't be used as + # inplace target. + # Remove here as faster. + candidate_inputs = [ + i + for i in range(len(node.inputs)) + if i not in baseline.values() + and not isinstance(node.inputs[i], Constant) + and + # the next line should not be costly most of the time. + not fgraph.has_destroyers([node.inputs[i]]) + and node.inputs[i] not in protected_inputs + ] + else: + baseline = [] + candidate_outputs = list(range(len(node.outputs))) + # node inputs that are Constant, already destroyed, + # fgraph protected inputs and fgraph outputs can't be used as inplace + # target. + # Remove here as faster. + candidate_inputs = [ + i + for i in range(len(node.inputs)) + if not isinstance(node.inputs[i], Constant) + and not fgraph.has_destroyers([node.inputs[i]]) + and node.inputs[i] not in protected_inputs + ] + + verbose = False + + raised_warning = not verbose + + for candidate_output in candidate_outputs: + + # If the output of the node can be established as an update + # output of the fgraph, visit the candidate_inputs in an order + # that will improve the chances of making the node operate + # inplace on the input it's meant to update + candidate_out_var = node.outputs[candidate_output] + sorted_candidate_inputs = candidate_inputs + + if candidate_out_var in update_outs: + + # The candidate output is an update. Sort the + # variables in candidate_inputs in the following order: + # - Vars corresponding to the actual updated input + # (best case scenario is for the node that procudes + # an update to operate inplace on the variable to + # update) + # - Vars computed inplace on the updates input (second + # best scenario if for the node to work inplace on + # a variable obtained by a chain of inplace on the + # variable to update. In some cases, this will be + # equivalent to operating inplace on the variable to + # update) + # - Remaining variables + updated_inputs = [] + for i, f_out in enumerate(fgraph.outputs): + if f_out is candidate_out_var and i in fgraph.update_mapping: + updated_inp_idx = fgraph.update_mapping[i] + updated_inputs.append(fgraph.inputs[updated_inp_idx]) + + updated_vars = [] + vars_from_inplace = [] + other_vars = [] + for inp_idx in candidate_inputs: + inp = node.inputs[inp_idx] + if inp in updated_inputs: + # the candidate input is the actual updated input + updated_vars.append(inp_idx) + elif ( + hasattr(fgraph, "destroy_handler") + and inp.owner + and any( + fgraph.destroy_handler.root_destroyer.get(up_inp, None) + is inp.owner + for up_inp in updated_inputs + ) + ): + + # the candidate input is a variable computed + # inplace on the updated input via a sequence of + # one or more inplace operations + vars_from_inplace.append(inp_idx) + else: + other_vars.append(inp_idx) + + sorted_candidate_inputs = ( + updated_vars + vars_from_inplace + other_vars + ) + + for candidate_input in sorted_candidate_inputs: + # remove inputs that don't have the same dtype as the output + if ( + node.inputs[candidate_input].type + != node.outputs[candidate_output].type + ): + continue + + inplace_pattern = dict(baseline) + inplace_pattern[candidate_output] = candidate_input + try: + if hasattr(op.scalar_op, "make_new_inplace"): + new_scal = op.scalar_op.make_new_inplace( + aes.transfer_type( + *[ + inplace_pattern.get(i, o.dtype) + for i, o in enumerate(node.outputs) + ] + ) + ) + else: + new_scal = op.scalar_op.__class__( + aes.transfer_type( + *[ + inplace_pattern.get(i, None) + for i in range(len(node.outputs)) + ] + ) + ) + new_outputs = self.op(new_scal, inplace_pattern)( + *node.inputs, return_list=True + ) + new_node = new_outputs[0].owner + + for r, new_r in zip(node.outputs, new_outputs): + prof["nb_call_replace"] += 1 + fgraph.replace( + r, new_r, reason="inplace_elemwise_optimizer" + ) + nb_change_no_validate += 1 + prof["ndim"][candidate_out_var.ndim] += 1 + if nb_change_no_validate >= check_each_change: + prof["nb_call_validate"] += 1 + fgraph.validate() + chk = fgraph.checkpoint() + nb_change_no_validate = 0 + except (ValueError, InconsistencyError) as e: + prof["nb_inconsistent"] += 1 + if check_each_change != 1 and not raised_warning: + print( + ( + "Some inplace rewriting was not " + "performed due to an unexpected error:" + ), + file=sys.stderr, + ) + print(e, file=sys.stderr) + raised_warning = True + fgraph.revert(chk) + continue + candidate_inputs.remove(candidate_input) + node = new_node + baseline = inplace_pattern + break + + if nb_change_no_validate > 0: + try: + fgraph.validate() + except Exception: + if not raised_warning: + print( + ( + "Some inplace rewriting was not " + "performed due to an unexpected error" + ), + file=sys.stderr, + ) + fgraph.revert(chk) + return prof + + def print_summary(self, stream=sys.stdout, level=0, depth=-1): + print( + f"{' ' * level}{self.__class__.__name__} ({self.op})", + file=stream, + ) + return inplace_elemwise_optimizer + + +inplace_elemwise_optimizer = InplaceElemwiseOptimizer(Elemwise) +compile.optdb.register( # type: ignore + "inplace_elemwise_opt", + inplace_elemwise_optimizer, + "inplace_opt", # for historic reason + "inplace_elemwise_optimizer", + "fast_run", + "inplace", + position=75, +) + + +def apply_local_dimshuffle_lift(fgraph, var): + """ + lift recursively + """ + if not var.owner: + return var + new = local_dimshuffle_lift.transform(fgraph, var.owner) + if new: + return new[0] + return var + + +def is_dimshuffle_useless(new_order, input): + """ + Checks for two types of useless dimshuffles: + 1 - dimshuffle all dimensions in order. + 2 - dimshuffle a broadcastable dimension. + """ + is_useless = True + if len(new_order) == input.type.ndim: + all_broadcastable_dims = [ + i + for (i, is_broadcastable) in enumerate(input.type.broadcastable) + if is_broadcastable + ] + ["x"] + for i in range(input.type.ndim): + if new_order[i] == i or ( + i in all_broadcastable_dims and new_order[i] in all_broadcastable_dims + ): + is_useless = True + else: + is_useless = False + break + else: + is_useless = False + return is_useless + + +@register_canonicalize +@register_specialize +@node_rewriter([DimShuffle]) +def local_dimshuffle_lift(fgraph, node): + """ + "Lifts" DimShuffle through Elemwise operations and merges + consecutive DimShuffles. Basically, applies the following + transformations on the whole graph: + + DimShuffle(Elemwise(x, y)) => Elemwise(DimShuffle(x), DimShuffle(y)) + DimShuffle(DimShuffle(x)) => DimShuffle(x) + DimShuffle{0,1,...}(x) => x (when the dimshuffle do nothing) + + After this transform, clusters of Elemwise operations are + void of DimShuffle operations. + + """ + op = node.op + if not isinstance(op, DimShuffle): + return False + + inp = node.inputs[0] + inode = inp.owner + new_order = op.new_order + if inode and isinstance(inode.op, Elemwise) and (len(fgraph.clients[inp]) == 1): + # Don't use make_node to have tag.test_value set. + new_inputs = [] + for inp in inode.inputs: + new_inp = op.__class__(inp.type.broadcastable, op.new_order)(inp) + new_inputs.append(apply_local_dimshuffle_lift(fgraph, new_inp)) + copy_stack_trace(node.outputs[0], new_inputs) + ret = inode.op(*new_inputs, return_list=True) + return ret + if inode and isinstance(inode.op, DimShuffle): + new_order = [x == "x" and "x" or inode.op.new_order[x] for x in new_order] + inp = inode.inputs[0] + + if is_dimshuffle_useless(new_order, inp): + return [inp] + elif inode and isinstance(inode.op, DimShuffle): + ret = op.__class__(inp.type.broadcastable, new_order)(inp) + ret = apply_local_dimshuffle_lift(fgraph, ret) + copy_stack_trace(node.outputs[0], ret) + return [ret] + + +@register_canonicalize +@register_specialize +@node_rewriter([DimShuffle]) +def local_useless_dimshuffle_makevector(fgraph, node): + r"""Remove `DimShuffle`\s that drop one dimensional broadcastable `MakeVector`s. + + This rewrite is needed in order to clean up after + `local_subtensor_remove_broadcastable_index`, which produces a + not-so-intuitive canonical form for `x[0]` when `x.shape == (1,)` + (i.e. one broadcastable dimension): i.e. `x.dimshuffle(())`. + """ + + # The `DimShuffle` should be removing the single broadcastable dimension + if node.op.new_order != (): + return + + makevector_out = node.inputs[0] + + if ( + not makevector_out.owner + or not isinstance(makevector_out.owner.op, MakeVector) + or not makevector_out.broadcastable == (True,) + ): + return + + assert len(makevector_out.owner.inputs) == 1 + + return [makevector_out.owner.inputs[0]] + + +@register_canonicalize +@node_rewriter([Elemwise]) +def local_upcast_elemwise_constant_inputs(fgraph, node): + """This explicitly upcasts constant inputs to elemwise Ops, when + those Ops do implicit upcasting anyway. + + Rationale: it helps merge things like (1-x) and (1.0 - x). + + """ + if len(node.outputs) > 1: + return + try: + shape_i = fgraph.shape_feature.shape_i + except AttributeError: + shape_i = None + if isinstance(node.op, Elemwise): + scalar_op = node.op.scalar_op + # print "aa", scalar_op.output_types_preference + if getattr(scalar_op, "output_types_preference", None) in ( + aes.upgrade_to_float, + aes.upcast_out, + ): + # this is the kind of op that we can screw with the input + # dtypes by upcasting explicitly + output_dtype = node.outputs[0].type.dtype + new_inputs = [] + for i in node.inputs: + if i.type.dtype == output_dtype: + new_inputs.append(i) + else: + try: + # works only for scalars + cval_i = get_scalar_constant_value( + i, only_process_constants=True + ) + if all(i.broadcastable): + new_inputs.append( + shape_padleft(cast(cval_i, output_dtype), i.ndim) + ) + else: + if shape_i is None: + return + new_inputs.append( + alloc( + cast(cval_i, output_dtype), + *[shape_i(d)(i) for d in range(i.ndim)], + ) + ) + # print >> sys.stderr, "AAA", + # *[Shape_i(d)(i) for d in range(i.ndim)] + except NotScalarConstantError: + # for the case of a non-scalar + if isinstance(i, TensorConstant): + new_inputs.append(cast(i, output_dtype)) + else: + new_inputs.append(i) + + if new_inputs != node.inputs: + rval = [node.op(*new_inputs)] + if not node.outputs[0].type.is_super(rval[0].type): + # This can happen for example when floatX=float32 + # and we do the true division between and int64 + # and a constant that will get typed as int8. + + # As this is just to allow merging more case, if + # the upcast don't work, we can just skip it. + return + + # Copy over output stacktrace from before upcasting + copy_stack_trace(node.outputs[0], rval) + return rval + + +def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None): + r"""Create a recursive function that fuses `Elemwise` `Op`\s. + + The basic idea is that we loop through an `Elemwise` node's inputs, find + other `Elemwise` nodes, determine the scalars input types for all of the + `Elemwise` `Op`\s, construct a new scalar `Op` using the scalar input types + and each `Elemwise`'s scalar `Op`, and use the composite scalar `Op` in a + new "fused" `Elemwise`. + + It's parameterized in order to work for `Elemwise` `Op`\s. + + Parameters + ---------- + op_class : type + `Elemwise` class (the one that we want to fuse) + max_input_fct : callable + A function that returns the maximum number of inputs that this `Elemwise` + can take. + On the CPU we limit to 32 input variables since that is the maximum + NumPy support. + + maker: callable + A function with the signature ``(node, *args)`` that constructs an + `op_class` instance (e.g. ``op_class(*args)``). + + """ + if maker is None: + + def maker(node, scalar_op): + return op_class(scalar_op) + + def local_fuse(fgraph, node): + r"""Fuse `Elemwise` `Op`\s in a node. + + As part of specialization, we fuse two consecutive `Elemwise` `Op`\s of the + same shape. + + For mixed dtype, we let the `Composite` `Op` do the cast. It lets the C + compiler do the cast. + + The number of dimensions is validated at call time by Aesara itself. + + """ + # TODO: use broadcast flag? + + # TODO: don't do this rewrite as a `NodeRewriter`. + # Analyze the graph in terms of elemwise subgraphs, and then + # replace each subgraph with a Composite version. + + # TODO: use malloc and copy to transfer arguments that don't + # fit within the parameter space of 256 bytes + # + # TODO: Merge with multiple output to merge when an inputs + # have multiple clients. This can't be done with a `NodeRewriter` + + # TODO: Related: Support composites with multiple outputs + + # TODO: Use Composite to combine Elemwise and Reduce + # operations. We have to loop over the data anyway... might + # as well sum it up while we're at it (this can be trickier + # than i'm making it seound here. The data-traversal should be + # done contiguously, and the summing-up might not be easy or + # worthwhile if the summation axis doesn't line up with a + # contiguous dimension) + + if type(node.op) is not op_class: + return False + + if len(node.outputs) > 1: + # We don't support fusion for nodes with multiple outputs. + return + + inputs = [] # inputs of the new Elemwise op. + s_inputs = [] # inputs of the new scalar op used by the Composite. + # Inputs of the new scalar op that represents the current node. + s_g = [] + + # There is a hard limit of 256 bytes for the formal argument list to a + # GPU kernel function. + max_nb_input = max_input_fct(node) + # The number of inputs to the new fused op if we do not fuse more + # inputs. + new_nb_input = len(node.inputs) + # Did we fuse something? + # Needed as we can fuse unary op that don't change the number of + # inputs. + # And there is a case where the inputs are the same as the current + # node. That won't change the number of inputs of the new op. + fused = False + + for i in node.inputs: + scalar_node: Optional[Apply] = None + # Will store inputs of the fused node that are not currently inputs + # of the node we want to create (to avoid duplicating inputs). + tmp_input = [] + # Same as tmp_input, but for scalars. + tmp_scalar = [] + + # We should not check the number of inputs here + # As fusing op don't always change the number of input. + # If a variable is used as multiple into to the same node, + # we still want to fusion. So we take the set. + if ( + i.owner + and isinstance(i.owner.op, op_class) + and len({n for n, idx in fgraph.clients[i]}) == 1 + and + # Do not merge elemwise that don't have the same + # broadcastable pattern to don't redo duplicate + # computation due to broadcast. + i.owner.outputs[0].broadcastable == node.outputs[0].broadcastable + ): + try: + tmp_s_input = [] + # we should not put duplicate input into s_inputs and inputs + for ii in i.owner.inputs: + if ii in inputs: + tmp_s_input.append(s_inputs[inputs.index(ii)]) + elif ii in tmp_input: + tmp_s_input.append(tmp_scalar[tmp_input.index(ii)]) + else: + tmp = aes.get_scalar_type(ii.type.dtype).make_variable() + + try: + tv = get_test_value(ii) + # Sometimes the original inputs have + # zero-valued shapes in some dimensions, which + # implies that this whole scalar thing doesn't + # make sense (i.e. we're asking for the scalar + # value of an entry in a zero-dimensional + # array). + # This will eventually lead to an error in the + # `compute_test_value` call below when/if + # `config.compute_test_value_opt` is enabled + # (for debugging, more or less) + tmp.tag.test_value = tv.item() + except (TestValueError, ValueError): + pass + + tmp_s_input.append(tmp) + tmp_input.append(ii) + tmp_scalar.append(tmp_s_input[-1]) + + # Use the `Op.make_node` interface in case `Op.__call__` + # has been customized + scalar_node = i.owner.op.scalar_op.make_node(*tmp_s_input) + + if config.compute_test_value_opt != "off": + # This is required because `Op.make_node` won't do it + compute_test_value(scalar_node) + + # If the scalar_op doesn't have a C implementation, we skip + # its fusion to allow fusion of the other ops + i.owner.op.scalar_op.c_code( + scalar_node, + "test_presence_of_c_code", + ["x" for x in i.owner.inputs], + ["z" for z in i.owner.outputs], + {"fail": "%(fail)s"}, + ) + + except (NotImplementedError, MethodNotDefined): + warn( + ( + "Rewrite warning: " + f"The Op {i.owner.op.scalar_op} does not provide a C implementation." + " As well as being potentially slow, this also disables " + "loop fusion." + ) + ) + scalar_node = None + + # Compute the number of inputs in case we fuse this input. + # We subtract 1 because we replace the existing input with the new + # inputs from `tmp_input`. + new_nb_input_ = new_nb_input + len(tmp_input) - 1 + + # If the new input is already an input of the current node, it was + # already counted when `new_nb_input` was initialized to + # len(node.inputs). + # This can happen when a variable is used both by the Elemwise to + # fuse and the current node. + for x in tmp_input: + if x in node.inputs: + new_nb_input_ -= 1 + + if scalar_node and (new_nb_input_ <= max_nb_input): + fused = True + new_nb_input = new_nb_input_ + inputs.extend(tmp_input) + s_inputs.extend(tmp_scalar) + s_g.extend(scalar_node.outputs) + else: + # We must support the case where the same variable appears many + # times within the inputs + if inputs.count(i) == node.inputs.count(i): + s = s_inputs[inputs.index(i)] + else: + s = aes.get_scalar_type(i.type.dtype).make_variable() + if config.compute_test_value_opt != "off": + try: + v = get_test_value(i) + # See the zero-dimensional test value situation + # described above. + s.tag.test_value = v.item() + except (TestValueError, ValueError): + pass + + inputs.append(i) + s_inputs.append(s) + s_g.append(s) + + if not fused: + return False + + if new_nb_input != len(inputs) or len(s_inputs) != len(inputs): + # TODO FIXME: This shouldn't be a generic `Exception` + raise Exception( + "Something has gone wrong with the elemwise fusion rewrite; skipping." + ) + + s_new_out = node.op.scalar_op(*s_g, return_list=True) + try: + s_new_out[0].owner.op.c_code( + s_new_out[0].owner, + "test_presence_of_c_code", + ["x" for x in s_g], + ["z" for x in s_new_out], + {"fail": "%(fail)s"}, + ) + except (NotImplementedError, MethodNotDefined): + name = str(s_new_out[0].owner.op) + warn( + ( + "Rewrite warning: " + f"The Op {name} does not provide a C implementation." + " As well as being potentially slow, this also disables " + "loop fusion." + ) + ) + return False + + # create the composite op. + composite_op = aes.Composite(s_inputs, s_new_out) + + # create the new node. + # Do not call make_node to have test_value + new_node = maker(node, composite_op)(*inputs).owner + + assert len(new_node.outputs) == 1 + assert node.outputs[0].type.dtype == new_node.outputs[0].type.dtype + + if len(new_node.inputs) > max_nb_input: + warn( + "Loop fusion failed because the resulting node " + "would exceed the kernel argument limit." + ) + return False + + # we fuse as many that we can at the same time to make debug mode faster + # debug mode will be faster as it won't test all intermediate step. + while True: + ret = local_fuse(fgraph, new_node) + if ret is not False and ret is not None: + assert len(ret) == len(new_node.outputs) + assert len(ret) == 1 + new_node = ret[0].owner + else: + break + + return new_node.outputs + + return local_fuse + + +def elemwise_max_input_fct(node): + # `Elemwise.perform` uses NumPy ufuncs and they are limited to 31 inputs. + if not config.cxx: + return 31 + return 1024 + + +local_elemwise_fusion = local_elemwise_fusion_op(Elemwise, elemwise_max_input_fct) + + +class FusionOptimizer(GraphRewriter): + """Graph rewriter that simply runs node fusion operations. + + TODO: This is basically an `EquilibriumGraphRewriter`; we should just use that. + + """ + + def __init__(self, node_rewriter): + super().__init__() + self.node_rewriter = node_rewriter + + def add_requirements(self, fgraph): + fgraph.attach_feature(ReplaceValidate()) + + def apply(self, fgraph): + did_something = True + nb_iter = 0 + nb_replacement = 0 + nb_inconsistency_replace = 0 + time_toposort = 0 + if fgraph.profile: + validate_before = fgraph.profile.validate_time + callbacks_before = fgraph.execute_callbacks_times.copy() + callback_before = fgraph.execute_callbacks_time + while did_something: + t0 = time.time() + nodelist = list(fgraph.toposort()) + time_toposort += time.time() - t0 + nodelist.reverse() + did_something = False + for node in nodelist: + # Don't try to fuse node that have already been fused. + if node in fgraph.apply_nodes: + new_outputs = self.node_rewriter(fgraph, node) + if new_outputs: + assert len(new_outputs) == len(node.outputs) + try: + fgraph.replace_all_validate( + list(zip(node.outputs, new_outputs)), + reason=self.__class__.__name__, + ) + did_something = True + nb_replacement += 1 + except InconsistencyError: + nb_inconsistency_replace += 1 + nb_iter += 1 + + if fgraph.profile: + validate_time = fgraph.profile.validate_time - validate_before + callback_time = fgraph.execute_callbacks_time - callback_before + callbacks_time = {} + for k, v in fgraph.execute_callbacks_times.items(): + if k in callbacks_before: + callbacks_time[k] = v - callbacks_before[k] + else: + callbacks_time[k] = v + else: + validate_time = None + callback_time = None + callbacks_time = {} + return ( + self, + nb_iter, + nb_replacement, + nb_inconsistency_replace, + validate_time, + callback_time, + callbacks_time, + time_toposort, + ) + + @classmethod + def print_profile(cls, stream, prof, level=0): + blanc = " " * level + print(blanc, cls.__name__, file=stream) + print(blanc, " nb_iter", prof[1], file=stream) + print(blanc, " nb_replacement", prof[2], file=stream) + print(blanc, " nb_inconsistency_replace", prof[3], file=stream) + print(blanc, " validate_time", prof[4], file=stream) + print(blanc, " callback_time", prof[5], file=stream) + if prof[5] is not None and prof[5] > 1: + print(blanc, " callbacks_time", file=stream) + for i in sorted(prof[6].items(), key=lambda a: a[1])[::-1]: + if i[1] > 0: + print(blanc, " ", i) + print(blanc, " time_toposort", prof[7], file=stream) + + +if config.tensor__local_elemwise_fusion: + # Must be after gpu(48.5) and before AddDestroyHandler(49.5) + fuse_seqopt = SequenceDB() + fuse_seqopt.register( + "composite_elemwise_fusion", + FusionOptimizer(local_elemwise_fusion), + "fast_run", + "fusion", + position=1, + ) + compile.optdb.register( # type: ignore + "elemwise_fusion", + fuse_seqopt, + "fast_run", + "fusion", + "local_elemwise_fusion", + "FusionOptimizer", + position=49, + ) +else: + compile.optdb.register( # type: ignore + "elemwise_fusion", + FusionOptimizer(local_elemwise_fusion), + "fusion", + "local_elemwise_fusion", + "FusionOptimizer", + position=49, + ) + + +@register_canonicalize +@node_rewriter([Elemwise]) +def local_useless_composite(fgraph, node): + """For elemwise Composite that have multiple outputs, remove the + outputs that are not used. + + """ + if not isinstance(node.op, Elemwise) or not isinstance( + node.op.scalar_op, aes.Composite + ): + return + comp = node.op.scalar_op + idx = [i for i, o_extern in enumerate(node.outputs) if fgraph.clients[o_extern]] + if len(idx) < len(node.outputs): + new_outputs = [comp.outputs[i] for i in idx] + c = aes.Composite(inputs=comp.inputs, outputs=new_outputs) + e = Elemwise(scalar_op=c)(*node.inputs, return_list=True) + return dict(zip([node.outputs[i] for i in idx], e)) diff --git a/aesara/tensor/rewriting/extra_ops.py b/aesara/tensor/rewriting/extra_ops.py new file mode 100644 index 0000000000..d21e65beae --- /dev/null +++ b/aesara/tensor/rewriting/extra_ops.py @@ -0,0 +1,177 @@ +import aesara.scalar.basic as aes +from aesara.graph.rewriting.basic import node_rewriter +from aesara.tensor.basic import Alloc, as_tensor_variable +from aesara.tensor.elemwise import Elemwise +from aesara.tensor.extra_ops import BroadcastTo, Repeat, Unique +from aesara.tensor.rewriting.basic import register_canonicalize, register_useless + + +@register_useless +@register_canonicalize +@node_rewriter([Unique]) +def local_Unique_scalar(fgraph, node): + """Convert ``unique(x)`` to ``x`` when ``x`` is a scalar.""" + if not isinstance(node.op, Unique): + return False + + if node.op.return_index or node.op.return_inverse or node.op.return_counts: + return False + + uniqued_var = node.inputs[0] + + if uniqued_var.ndim != 0: + return False + + old_out = node.outputs[0] + res = as_tensor_variable(uniqued_var, ndim=old_out.ndim, dtype=old_out.dtype) + return [res] + + +@register_useless +@register_canonicalize +@node_rewriter([Unique]) +def local_Unique_Alloc_lift(fgraph, node): + """Convert ``unique(alloc(x, ...), axis=None)`` to ``unique(x, axis=None)``. + + This isn't really so much a lift as a "reduction/consumption". + """ + if not isinstance(node.op, Unique): + return False + + if ( + node.op.return_index + or node.op.return_inverse + or node.op.return_counts + or node.op.axis is not None + ): + return False + + alloc_var = node.inputs[0] + + if not (alloc_var.owner and isinstance(alloc_var.owner.op, Alloc)): + return False + + alloced_var, *alloc_shape = alloc_var.owner.inputs + + new_unique, *_ = node.op.make_node(alloced_var).outputs + + old_out = node.outputs[0] + new_x = as_tensor_variable(new_unique, ndim=old_out.ndim, dtype=old_out.dtype) + return [new_x] + + +@register_useless +@register_canonicalize +@node_rewriter([Unique]) +def local_Unique_BroadcastTo_lift(fgraph, node): + """Convert ``unique(broadcast_to(x, ...), axis=None)`` to ``unique(x, axis=None)``. + + This isn't really so much a lift as a "reduction/consumption". + """ + if not isinstance(node.op, Unique): + return False + + if ( + node.op.return_index + or node.op.return_inverse + or node.op.return_counts + or node.op.axis is not None + ): + return False + + bcast_var = node.inputs[0] + + if not (bcast_var.owner and isinstance(bcast_var.owner.op, BroadcastTo)): + return False + + bcasted_var, *bcast_shape = bcast_var.owner.inputs + + new_unique, *_ = node.op.make_node(bcasted_var).outputs + + old_out = node.outputs[0] + new_x = as_tensor_variable(new_unique, ndim=old_out.ndim, dtype=old_out.dtype) + return [new_x] + + +@register_useless +@register_canonicalize +@node_rewriter([Unique]) +def local_Unique_Repeat_lift(fgraph, node): + """Convert ``unique(repeat(x, ...), axis=None)`` to ``unique(x, axis=None)``. + + This isn't really so much a lift as a "reduction/consumption". + """ + if not isinstance(node.op, Unique): + return False + + if ( + node.op.return_index + or node.op.return_inverse + or node.op.return_counts + or node.op.axis is not None + ): + return False + + repeat_var = node.inputs[0] + + if not (repeat_var.owner and isinstance(repeat_var.owner.op, Repeat)): + return False + + repeated_var, *repeat_shape = repeat_var.owner.inputs + + new_unique, *_ = node.op.make_node(repeated_var).outputs + + old_out = node.outputs[0] + new_x = as_tensor_variable(new_unique, ndim=old_out.ndim, dtype=old_out.dtype) + return [new_x] + + +@register_useless +@register_canonicalize +@node_rewriter([Unique]) +def local_Unique_second(fgraph, node): + """Convert ``unique(second(x, ...), axis=None)`` to ``second(x, axis=None)``. + + This isn't really so much a lift as a "reduction/consumption". + """ + if not isinstance(node.op, Unique): + return False + + if ( + node.op.return_index + or node.op.return_inverse + or node.op.return_counts + or node.op.axis is not None + ): + return False + + second_var = node.inputs[0] + + if not ( + second_var.owner + and isinstance(second_var.owner.op, Elemwise) + and isinstance(second_var.owner.op.scalar_op, aes.Second) + ): + return False + + shape_var, seconded_var = second_var.owner.inputs + + new_unique, *_ = node.op.make_node(seconded_var).outputs + + old_out = node.outputs[0] + new_x = as_tensor_variable(new_unique, ndim=old_out.ndim, dtype=old_out.dtype) + return [new_x] + + +@register_useless +@register_canonicalize +@node_rewriter([BroadcastTo]) +def local_remove_scalar_BroadcastTo(fgraph, node): + + bcast_shape = node.inputs[1:] + + if not bcast_shape: + bcasted_var = node.inputs[0] + # If this isn't true, the graph is invalid + assert bcasted_var.ndim == 0 + return [bcasted_var] diff --git a/aesara/tensor/rewriting/math.py b/aesara/tensor/rewriting/math.py new file mode 100644 index 0000000000..3c38a6086b --- /dev/null +++ b/aesara/tensor/rewriting/math.py @@ -0,0 +1,3577 @@ +r"""Rewrites for the `Op`\s in `aesara.tensor.math`.""" + +import itertools +import operator +from functools import partial, reduce + +import numpy as np + +import aesara.scalar.basic as aes +import aesara.scalar.math as aes_math +from aesara.graph.basic import Constant, Variable +from aesara.graph.rewriting.basic import ( + NodeRewriter, + PatternNodeRewriter, + SequentialNodeRewriter, + copy_stack_trace, + in2out, + node_rewriter, +) +from aesara.graph.rewriting.utils import get_clients_at_depth +from aesara.misc.safe_asarray import _asarray +from aesara.raise_op import assert_op +from aesara.tensor.basic import ( + Alloc, + Join, + MakeVector, + alloc, + as_tensor_variable, + cast, + constant, + extract_constant, + fill, + get_scalar_constant_value, + ones_like, + switch, + zeros_like, +) +from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise +from aesara.tensor.exceptions import NotScalarConstantError +from aesara.tensor.math import ( + All, + Any, + Dot, + NonZeroCAReduce, + Prod, + ProdWithoutZeros, + Sum, + _conj, +) +from aesara.tensor.math import abs as at_abs +from aesara.tensor.math import ( + add, + dot, + eq, + erf, + erfc, + exp, + expm1, + ge, + int_div, + isinf, + le, + log, + log1mexp, + log1p, + makeKeepDims, +) +from aesara.tensor.math import max as at_max +from aesara.tensor.math import maximum, mul, neg +from aesara.tensor.math import pow as at_pow +from aesara.tensor.math import prod, reciprocal, sgn, sigmoid, softplus, sqr, sqrt, sub +from aesara.tensor.math import sum as at_sum +from aesara.tensor.math import true_div +from aesara.tensor.rewriting.basic import ( + broadcast_like, + encompasses_broadcastable, + local_fill_sink, + register_canonicalize, + register_specialize, + register_specialize_device, + register_stabilize, + register_uncanonicalize, + register_useless, +) +from aesara.tensor.rewriting.elemwise import FusionOptimizer, fuse_seqopt +from aesara.tensor.shape import Shape, Shape_i +from aesara.tensor.subtensor import Subtensor +from aesara.tensor.type import ( + complex_dtypes, + uint_dtypes, + values_eq_approx_remove_inf, + values_eq_approx_remove_inf_nan, + values_eq_approx_remove_nan, +) +from aesara.tensor.var import TensorConstant, get_unique_value + + +def scalarconsts_rest(inputs, elemwise=True, only_process_constants=False): + """Partition a list of variables into two kinds: + scalar constants, and the rest.""" + consts = [] + origconsts = [] + nonconsts = [] + for i in inputs: + try: + v = get_scalar_constant_value( + i, elemwise=elemwise, only_process_constants=only_process_constants + ) + consts.append(v) + origconsts.append(i) + except NotScalarConstantError: + nonconsts.append(i) + return consts, origconsts, nonconsts + + +def get_constant(v): + """ + + Returns + ------- + object + A numeric constant if v is a Constant or, well, a + numeric constant. If v is a plain Variable, returns None. + + """ + if isinstance(v, Constant): + unique_value = get_unique_value(v) + if unique_value is not None: + data = unique_value + else: + data = v.data + if data.ndim == 0: + return data + else: + return None + elif isinstance(v, Variable): + return None + else: + return v + + +def fill_chain(new_out, orig_inputs): + for i in orig_inputs: + new_out = fill(i, new_out) + return [new_out] + + +@register_canonicalize +@register_stabilize +@node_rewriter([Dot]) +def local_0_dot_x(fgraph, node): + if not isinstance(node.op, Dot): + return False + + x = node.inputs[0] + y = node.inputs[1] + replace = False + try: + if get_scalar_constant_value(x, only_process_constants=True) == 0: + replace = True + except NotScalarConstantError: + pass + + try: + if get_scalar_constant_value(y, only_process_constants=True) == 0: + replace = True + except NotScalarConstantError: + pass + + if replace: + constant_zero = constant(0, dtype=node.outputs[0].type.dtype) + if x.ndim == 2 and y.ndim == 2: + constant_zero = assert_op(constant_zero, eq(x.shape[1], y.shape[0])) + return [alloc(constant_zero, x.shape[0], y.shape[1])] + elif x.ndim == 1 and y.ndim == 2: + constant_zero = assert_op(constant_zero, eq(x.shape[0], y.shape[0])) + return [alloc(constant_zero, y.shape[1])] + elif x.ndim == 2 and y.ndim == 1: + constant_zero = assert_op(constant_zero, eq(x.shape[1], y.shape[0])) + return [alloc(constant_zero, x.shape[0])] + elif x.ndim == 1 and y.ndim == 1: + constant_zero = assert_op(constant_zero, eq(x.shape[0], y.shape[0])) + return [constant_zero] + + +@register_canonicalize +@node_rewriter([DimShuffle]) +def local_lift_transpose_through_dot(fgraph, node): + r"""Perform the rewrite ``dot(x,y).T -> dot(y.T, x.T)``. + + These rewrites "lift" (propagate towards the inputs) `DimShuffle` + through dot product. It allows to put the graph in a more standard shape, + and to later merge consecutive `DimShuffle`\s. + + The transformation should be apply whether or not the transpose is + inplace. The newly-introduced transpositions are not inplace, this will + be taken care of in a later rewrite phase. + + """ + if not (isinstance(node.op, DimShuffle) and node.op.new_order == (1, 0)): + return False + if not (node.inputs[0].owner and isinstance(node.inputs[0].owner.op, Dot)): + return False + x, y = node.inputs[0].owner.inputs + + if x.ndim == y.ndim == 2: + # Output is dot product of transposed inputs in reverse order + ret = [dot(y.T, x.T)] + + # Copy over stack trace to output from result of dot-product + copy_stack_trace(node.inputs[0], ret) + return ret + + +def is_inverse_pair(node_op, prev_op, inv_pair): + """ + Given two consecutive operations, check if they are the + provided pair of inverse functions. + + """ + node_is_op0 = isinstance(node_op, inv_pair[0]) + node_is_op1 = isinstance(node_op, inv_pair[1]) + prev_is_op0 = isinstance(prev_op, inv_pair[0]) + prev_is_op1 = isinstance(prev_op, inv_pair[1]) + + return (node_is_op0 and prev_is_op1) or (node_is_op1 and prev_is_op0) + + +@register_canonicalize +@register_specialize +@node_rewriter([Elemwise]) +def local_func_inv(fgraph, node): + """ + Check for two consecutive operations that are functional inverses + and remove them from the function graph. + + """ + inv_pairs = ( + (aes.Deg2Rad, aes.Rad2Deg), + (aes.Cosh, aes.ArcCosh), + (aes.Tanh, aes.ArcTanh), + (aes.Sinh, aes.ArcSinh), + (aes.Conj, aes.Conj), + (aes.Neg, aes.Neg), + (aes.Reciprocal, aes.Reciprocal), + ) + x = node.inputs[0] + + if not isinstance(node.op, Elemwise): + return + if not x.owner or not isinstance(x.owner.op, Elemwise): + return + + prev_op = x.owner.op.scalar_op + node_op = node.op.scalar_op + + for inv_pair in inv_pairs: + if is_inverse_pair(node_op, prev_op, inv_pair): + # We don't need to copy stack trace, because the rewrite + # is trivial and maintains the earlier stack trace + ottype = node.out.dtype + inp = x.owner.inputs[0] + # Functions may have casted integer input to float + if inp.dtype != ottype: + inp = cast(inp, ottype) + return [inp] + + return + + +@register_canonicalize +@register_specialize +@node_rewriter([Elemwise]) +def local_exp_log(fgraph, node): + x = node.inputs[0] + + if not isinstance(node.op, Elemwise): + return + if not x.owner or not isinstance(x.owner.op, Elemwise): + return + + prev_op = x.owner.op.scalar_op + node_op = node.op.scalar_op + + # Case for log(exp(x)) -> x + if isinstance(prev_op, aes.Exp) and isinstance(node_op, aes.Log): + new_out = x.owner.inputs[0] + old_out = node.outputs[0] + # Exp may have cast integer input to float + if new_out.dtype != old_out.dtype: + new_out = cast(new_out, old_out.dtype) + return [new_out] + + # Case for log1p(expm1(x)) -> x + if isinstance(prev_op, aes.Expm1) and isinstance(node_op, aes.Log1p): + new_out = x.owner.inputs[0] + old_out = node.outputs[0] + # Expm1 may have cast integer input to float + if new_out.dtype != old_out.dtype: + new_out = cast(new_out, old_out.dtype) + return [new_out] + + # Case for exp(softplus(x)) aka exp(log1pexp) -> 1 + exp(x) + if isinstance(prev_op, aes_math.Softplus) and isinstance(node_op, aes.Exp): + x = x.owner.inputs[0] + return [add(1, exp(x))] + + # Case for expm1(softplus(x)) aka expm1(log1pexp) -> exp(x) + if isinstance(prev_op, aes_math.Softplus) and isinstance(node_op, aes.Expm1): + x = x.owner.inputs[0] + return [exp(x)] + + +@register_specialize +@node_rewriter([Elemwise]) +def local_exp_log_nan_switch(fgraph, node): + # Rewrites of the kind exp(log...(x)) that require a `nan` switch + x = node.inputs[0] + + if not isinstance(node.op, Elemwise): + return + if not x.owner or not isinstance(x.owner.op, Elemwise): + return + + prev_op = x.owner.op.scalar_op + node_op = node.op.scalar_op + + # Case for exp(log(x)) -> x + if isinstance(prev_op, aes.Log) and isinstance(node_op, aes.Exp): + x = x.owner.inputs[0] + old_out = node.outputs[0] + new_out = switch(ge(x, 0), x, np.asarray(np.nan, old_out.dtype)) + return [new_out] + + # Case for exp(log1p(x)) -> x + 1 + if isinstance(prev_op, aes.Log1p) and isinstance(node_op, aes.Exp): + x = x.owner.inputs[0] + old_out = node.outputs[0] + new_out = switch(ge(x, -1), add(1, x), np.asarray(np.nan, old_out.dtype)) + return [new_out] + + # Case for expm1(log(x)) -> x - 1 + if isinstance(prev_op, aes.Log) and isinstance(node_op, aes.Expm1): + x = x.owner.inputs[0] + old_out = node.outputs[0] + new_out = switch(ge(x, 0), sub(x, 1), np.asarray(np.nan, old_out.dtype)) + return [new_out] + + # Case for expm1(log1p(x)) -> x + if isinstance(prev_op, aes.Log1p) and isinstance(node_op, aes.Expm1): + x = x.owner.inputs[0] + old_out = node.outputs[0] + new_out = switch(ge(x, -1), x, np.asarray(np.nan, old_out.dtype)) + return [new_out] + + # Case for exp(log1mexp(x)) -> 1 - exp(x) + if isinstance(prev_op, aes_math.Log1mexp) and isinstance(node_op, aes.Exp): + x = x.owner.inputs[0] + old_out = node.outputs[0] + new_out = switch(le(x, 0), sub(1, exp(x)), np.asarray(np.nan, old_out.dtype)) + return [new_out] + + # Case for expm1(log1mexp(x)) -> -exp(x) + if isinstance(prev_op, aes_math.Log1mexp) and isinstance(node_op, aes.Expm1): + x = x.owner.inputs[0] + old_out = node.outputs[0] + new_out = switch(le(x, 0), neg(exp(x)), np.asarray(np.nan, old_out.dtype)) + return [new_out] + + +@register_canonicalize +@register_specialize +@node_rewriter([Sum]) +def local_sumsqr2dot(fgraph, node): + """ + This rewrite detects + ``at.sqr(W.dimshuffle("x", 0, 1) * G.dimshuffle(0, "x", 1) ).sum(axis=(1, 2))`` + and converts it to ``at.dot(at.sqr(G), at.sqr(W).sum(axis=0))``. + """ + if ( + isinstance(node.op, Sum) + and isinstance(node.op.scalar_op, aes.Add) + and node.op.axis == (1, 2) + ): + in1 = node.inputs[0] + out = node.outputs[0] + + if ( + in1.owner + and isinstance(in1.owner.op, Elemwise) + and isinstance(in1.owner.op.scalar_op, aes.Sqr) + ): + in_sqr = in1.owner.inputs[0] + if ( + in_sqr.owner + and isinstance(in_sqr.owner.op, Elemwise) + and isinstance(in_sqr.owner.op.scalar_op, aes.Mul) + and len(in_sqr.owner.inputs) == 2 + ): + in_mul1, in_mul2 = in_sqr.owner.inputs + + if ( + isinstance(in_mul1.owner.op, DimShuffle) + and in_mul1.owner.op.new_order == ("x", 0, 1) + and isinstance(in_mul2.owner.op, DimShuffle) + and in_mul2.owner.op.new_order == (0, "x", 1) + ): + W = in_mul1.owner.inputs[0] + G = in_mul2.owner.inputs[0] + + new_out = dot(sqr(G), sqr(W).sum(axis=0)) + if new_out.dtype != out.dtype: + new_out = cast(new_out, dtype=out.dtype) + return [new_out] + + +@register_stabilize +@register_specialize +@register_canonicalize +@node_rewriter([Elemwise]) +def local_expm1(fgraph, node): + """Detect ``exp(a) - 1`` and convert them to ``expm1(a)``.""" + if isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, aes.Sub): + in1, in2 = node.inputs + out = node.outputs[0] + + if ( + in1.owner + and isinstance(in1.owner.op, Elemwise) + and isinstance(in1.owner.op.scalar_op, aes.Exp) + and extract_constant(in2, only_process_constants=False) == 1 + ): + in11 = in1.owner.inputs[0] + new_out = expm1(in11) + + if new_out.dtype != out.dtype: + new_out = cast(new_out, dtype=out.dtype) + + if not out.type.is_super(new_out.type): + return + return [new_out] + + +@register_specialize +@register_canonicalize +@node_rewriter([mul]) +def local_mul_switch_sink(fgraph, node): + """ + This rewrite makes the following changes in the graph: + + at.mul(A, at.switch(cond, 0, iff), B) -> at.switch(cond, 0, at.mul(A, B, iff)) + at.mul(A, at.switch(cond, ift, 0), B) -> at.switch(cond, at.mul(A, B, ift), 0) + + ``A`` and ``B`` being several (or none) symbolic variables. + This is useful because ``A`` and ``B`` may not be numerically stable and give + NaN or inf values for cases where the switch returns 0. + With this rewrite ``at.grad(at.switch(...))`` has the right behavior. + + Examples + -------- + + x -> f(x) + x -> g(x) + y = at.switch(cond, f(x), g(x)) + + without the rewrite: + + at.grad(y, x) -> grad(f(x), x) * grad(y, f(x)) + grad(g(x), x) * grad(y, g(x)) + + with the rewrite + + at.grad(y, x) -> switch(cond, grad(f(x), x), 0) + switch(cond, 0, grad(g(x), x)) + + This will be particularly useful for the lazy ``if`` because we skip an entire + part of the graph. + + """ + if node.op != mul: + return False + for idx, i in enumerate(node.inputs): + if i.owner and i.owner.op == switch: + switch_node = i.owner + try: + if ( + get_scalar_constant_value( + switch_node.inputs[1], only_process_constants=True + ) + == 0.0 + ): + listmul = node.inputs[:idx] + node.inputs[idx + 1 :] + fmul = mul(*(listmul + [switch_node.inputs[2]])) + + # Copy over stacktrace for elementwise multiplication op + # from previous elementwise multiplication op. + # An error in the multiplication (e.g. errors due to + # inconsistent shapes), will point to the + # multiplication op. + copy_stack_trace(node.outputs, fmul) + + fct = [switch(switch_node.inputs[0], 0, fmul)] + fct[0].tag.values_eq_approx = values_eq_approx_remove_nan + + # Copy over stacktrace for switch op from both previous + # elementwise multiplication op and previous switch op, + # because an error in this part can be caused by either + # of the two previous ops. + copy_stack_trace(node.outputs + switch_node.outputs, fct) + return fct + except NotScalarConstantError: + pass + try: + if ( + get_scalar_constant_value( + switch_node.inputs[2], only_process_constants=True + ) + == 0.0 + ): + listmul = node.inputs[:idx] + node.inputs[idx + 1 :] + fmul = mul(*(listmul + [switch_node.inputs[1]])) + # Copy over stacktrace for elementwise multiplication op + # from previous elementwise multiplication op. + # An error in the multiplication (e.g. errors due to + # inconsistent shapes), will point to the + # multiplication op. + copy_stack_trace(node.outputs, fmul) + + fct = [switch(switch_node.inputs[0], fmul, 0)] + fct[0].tag.values_eq_approx = values_eq_approx_remove_nan + + # Copy over stacktrace for switch op from both previous + # elementwise multiplication op and previous switch op, + # because an error in this part can be caused by either + # of the two previous ops. + copy_stack_trace(node.outputs + switch_node.outputs, fct) + return fct + except NotScalarConstantError: + pass + return False + + +@register_canonicalize +@node_rewriter([true_div, int_div]) +def local_div_switch_sink(fgraph, node): + """ + This rewrite makes the following changes in the graph: + + at.div(at.switch(cond, 0, iff), A) -> at.switch(cond, 0, at.div(iff, A)) + at.div(at.switch(cond, ift, 0), A) -> at.switch(cond, at.div(ift, A), 0) + + where ``A`` is a symbolic variable. + + This is useful because ``A`` may not be numerically stable and give + ``nan`` or ``inf`` values for cases where the switch returns 0. + + See `local_mul_switch_sink` for more details. + + """ + if node.op != true_div and node.op != int_div: + return False + op = node.op + if node.inputs[0].owner and node.inputs[0].owner.op == switch: + switch_node = node.inputs[0].owner + try: + if ( + get_scalar_constant_value( + switch_node.inputs[1], only_process_constants=True + ) + == 0.0 + ): + fdiv = op(switch_node.inputs[2], node.inputs[1]) + # Copy over stacktrace for elementwise division op + # from previous elementwise multiplication op. + # An error in the division (e.g. errors due to + # inconsistent shapes or division by zero), + # will point to the new division op. + copy_stack_trace(node.outputs, fdiv) + + fct = [switch(switch_node.inputs[0], 0, fdiv)] + fct[0].tag.values_eq_approx = values_eq_approx_remove_nan + + # Copy over stacktrace for switch op from both previous + # elementwise division op and previous switch op, + # because an error in this part can be caused by either + # of the two previous ops. + copy_stack_trace(node.outputs + switch_node.outputs, fct) + return fct + except NotScalarConstantError: + pass + try: + if ( + get_scalar_constant_value( + switch_node.inputs[2], only_process_constants=True + ) + == 0.0 + ): + fdiv = op(switch_node.inputs[1], node.inputs[1]) + # Copy over stacktrace for elementwise division op + # from previous elementwise multiplication op. + # An error in the division (e.g. errors due to + # inconsistent shapes or division by zero), + # will point to the new division op. + copy_stack_trace(node.outputs, fdiv) + + fct = [switch(switch_node.inputs[0], fdiv, 0)] + fct[0].tag.values_eq_approx = values_eq_approx_remove_nan + + # Copy over stacktrace for switch op from both previous + # elementwise division op and previous switch op, + # because an error in this part can be caused by either + # of the two previous ops. + copy_stack_trace(node.outputs + switch_node.outputs, fct) + return fct + except NotScalarConstantError: + pass + return False + + +class AlgebraicCanonizer(NodeRewriter): + r"""A `Rewriter` that rewrites algebraic expressions. + + The variable is a `node_rewriter`. It is best used + with a `WalkingGraphRewriter` in in-to-out order. + + Usage: ``AlgebraicCanonizer(main, inverse, reciprocal, calculate)`` + + Parameters + ---------- + main + A suitable `Op` class that is commutative, associative and + takes one to an arbitrary number of inputs, e.g. add or + mul + inverse + An `Op` class such that ``inverse(main(x, y), y) == x`` + (e.g. `sub` or `true_div`). + reciprocal + A function such that ``main(x, reciprocal(y)) == inverse(x, y)`` + (e.g. `neg` or `reciprocal`). + calculate + Function that takes a list of `numpy.ndarray` instances + for the numerator, another list for the denumerator, + and calculates ``inverse(main(\*num), main(\*denum))``. It + takes a keyword argument, `aslist`. If ``True``, the value + should be returned as a list of one element, unless + the value is such that ``value = main()``. In that case, + the return value should be an empty list. + + Examples + -------- + >>> import aesara.tensor as at + >>> from aesara.tensor.rewriting.math import AlgebraicCanonizer + >>> add_canonizer = AlgebraicCanonizer(add, sub, neg, \\ + ... lambda n, d: sum(n) - sum(d)) + >>> mul_canonizer = AlgebraicCanonizer(mul, true_div, inv, \\ + ... lambda n, d: prod(n) / prod(d)) + + Examples of rewrites `mul_canonizer` can perform: + + | x / x -> 1 + | (x * y) / x -> y + | x / y / x -> 1 / y + | x / y / z -> x / (y * z) + | x / (y / z) -> (x * z) / y + | (a / b) * (b / c) * (c / d) -> a / d + | (2.0 * x) / (4.0 * y) -> (0.5 * x) / y + | 2 * x / 2 -> x + | x * y * z -> Elemwise(mul){x,y,z} #only one pass over the memory. + | !-> Elemwise(mul){x,Elemwise(mul){y,z}} + + """ + + def __init__(self, main, inverse_fn, reciprocal_fn, calculate, use_reciprocal=True): + self.main = main + self.inverse = inverse_fn + self.reciprocal = reciprocal_fn + self.calculate = calculate + self.use_reciprocal = use_reciprocal + + self.external_simplifiers = [] + + def add_simplifier(self, simplifier, reason): + self.external_simplifiers.append((reason, simplifier)) + + def tracks(self): + return [self.main, self.inverse, self.reciprocal] + + def get_num_denum(self, inp): + r""" + This extract two lists, ``num`` and ``denum``, such that the input is: + ``self.inverse(self.main(\*num), self.main(\*denum))``. It returns + the two lists in a ``(num, denum)`` pair. + + For example, for main, inverse and ``reciprocal = \*, / and inv()``, + + | input -> returned value (num, denum) + + | x*y -> ([x, y], []) + | inv(x) -> ([], [x]) + | inv(x) * inv(y) -> ([], [x, y]) + | x*y/z -> ([x, y], [z]) + | log(x) / y * (z + x) / y -> ([log(x), z + x], [y, y]) + | (((a / b) * c) / d) -> ([a, c], [b, d]) + | a / (b / c) -> ([a, c], [b]) + | log(x) -> ([log(x)], []) + | x**y -> ([x**y], []) + | x * y * z -> ([x, y, z], []) + + """ + # This function is recursive. The idea is that there is a + # get_num_denum recursion in which the internal ops are all + # one of (main, inverse, reciprocal, DimShuffle) and the + # internal data nodes all have the dtype of the 'input' + # argument. The leaf-Variables of the graph covered by the + # recursion may be of any Variable type. + + if inp.owner is None or inp.owner.op not in [ + self.main, + self.inverse, + self.reciprocal, + ]: + if inp.owner and isinstance(inp.owner.op, DimShuffle): + # If input is a DimShuffle of some input which does + # something like this: + + # * change a vector of length N into a 1xN row matrix + # * change a scalar into a 1x1x1 tensor + # * in general, complete the shape of a tensor + # with broadcastable 1s to the *left* + # Then we will simply discard the DimShuffle and return + # the num/denum of its input + dsn = inp.owner # dimshuffle node + dsop = dsn.op # dimshuffle op + + # the first input of the dimshuffle i.e. the ndarray to redim + dsi0 = dsn.inputs[0] + + # The compatible order is a DimShuffle "new_order" of the form: + # ('x', ..., 'x', 0, 1, 2, ..., dimshuffle_input.type.ndim) + + # That kind of DimShuffle only adds broadcastable + # dimensions on the left, without discarding any + # existing broadcastable dimension and is inserted + # automatically by Elemwise when the inputs have + # different numbers of dimensions (hence why we can + # discard its information - we know we can retrieve it + # later on). + compatible_order = ("x",) * (inp.type.ndim - dsi0.type.ndim) + tuple( + range(dsi0.type.ndim) + ) + if dsop.new_order == compatible_order: + # If the "new_order" is the one we recognize, + # we return the num_denum of the dimshuffled input. + return self.get_num_denum(inp.owner.inputs[0]) + else: + # This is when the input isn't produced by main, + # inverse or reciprocal. + return [inp], [] + else: + return [inp], [] + num = [] + denum = [] + parent = inp.owner + + # We get the (num, denum) pairs for each input + # pairs = [self.get_num_denum(input2) if input2.type.dtype == + # input.type.dtype else ([input2], []) for input2 in + # parent.inputs] + pairs = [self.get_num_denum(input2) for input2 in parent.inputs] + + if parent.op == self.main: + # If we have main(x, y, ...), numx, denumx, numy, denumy, ... + # then num is concat(numx, numy, num...) and denum is + # concat(denumx, denumy, denum...) note that main() can have any + # number of arguments >= 0 concat is list concatenation + num = reduce(list.__iadd__, map(operator.itemgetter(0), pairs)) + denum = reduce(list.__iadd__, map(operator.itemgetter(1), pairs)) + elif parent.op == self.inverse: + # If we have inverse(x, y), numx, denumx, numy and denumy + # then num is concat(numx, denumy) and denum is + # concat(denumx, numy) note that inverse() is binary + num = pairs[0][0] + pairs[1][1] + denum = pairs[0][1] + pairs[1][0] + elif parent.op == self.reciprocal: + # If we have reciprocal(x), numx, denumx + # then num is denumx and denum is numx + # note that reciprocal() is unary + num = pairs[0][1] + denum = pairs[0][0] + return num, denum + + def merge_num_denum(self, num, denum): + r""" + Utility function which takes two lists, num and denum, and + returns something which is equivalent to inverse(main(\*num), + main(\*denum)), but depends on the length of num and the length + of denum (in order to minimize the number of operations). + + Let n = len(num) and d = len(denum): + + | n=0, d=0: neutral element (given by self.calculate([], [])) + | (for example, this would be 0 if main is addition + | and 1 if main is multiplication) + | n=1, d=0: num[0] + | n=0, d=1: reciprocal(denum[0]) + | n=1, d=1: inverse(num[0], denum[0]) + | n=0, d>1: reciprocal(main(\*denum)) + | n>1, d=0: main(\*num) + | n=1, d>1: inverse(num[0], main(\*denum)) + | n>1, d=1: inverse(main(\*num), denum[0]) + | n>1, d>1: inverse(main(\*num), main(\*denum)) + + Given the values of n and d to which they are associated, all + of the above are equivalent to: + inverse(main(\*num), main(\*denum)) + + """ + + ln, ld = len(num), len(denum) + if not ln and not ld: + return as_tensor_variable(self.calculate([], [])) + if not ln: + if self.use_reciprocal: + return self.reciprocal(self.merge_num_denum(denum, [])) + else: + ln = [self.calculate([], [], aslist=False)] + if not ld: + if ln == 1: + # num[0] should always be a variable + assert isinstance(num[0], Variable) + return num[0] + else: + return self.main(*num) + return self.inverse( + self.merge_num_denum(num, []), self.merge_num_denum(denum, []) + ) + + def simplify(self, num, denum, out_type): + """ + Shorthand for: + + .. code-block:: python + + self.simplify_constants(*self.simplify_factors(num, denum)) + + """ + rval = self.simplify_constants( + *self.simplify_factors(num, denum), out_type=out_type + ) + for reason, simplifier in self.external_simplifiers: + # TODO: document that 'reason' is associated with this + # simplification to help auditing when things go + # wrong + rval = simplifier(*rval) + return rval + + def simplify_factors(self, num, denum): + """ + For any Variable r which is both in num and denum, removes it + from both lists. Modifies the lists inplace. Returns the + modified lists. For example: + + | [x], [x] -> [], [] + | [x, y], [x] -> [y], [] + | [a, b], [c, d] -> [a, b], [c, d] + + """ + ln = len(num) + ld = len(denum) + if ld > 2 and ln > 2: + # Faster version for "big" inputs. + while True: + s = set(num) + # Inputs can appear multiple times + redo = len(s) != len(num) + inter = s.intersection(denum) + for v in inter: + num.remove(v) + denum.remove(v) + if not redo or not inter: + break + else: + for v in list(num): + if v in denum: + num.remove(v) + denum.remove(v) + return num, denum + + def simplify_constants(self, orig_num, orig_denum, out_type=None): + """ + Find all constants and put them together into a single constant. + + Finds all constants in orig_num and orig_denum (using + get_constant) and puts them together into a single + constant. The constant is inserted as the first element of the + numerator. If the constant is the neutral element, it is + removed from the numerator. + + Examples + -------- + Let main be multiplication: + + | [2, 3, x], [] -> [6, x], [] + | [x, y, 2], [4, z] -> [0.5, x, y], [z] + | [x, 2, y], [z, 2] -> [x, y], [z] + + """ + # Lists representing the numerator and denumerator + num, denum = [], [] + + # Lists representing the *constant* elements of num and denum + numct, denumct = [], [] + + for v in orig_num: + ct = get_constant(v) + if ct is not None: + # We found a constant in the numerator! + # We add it to numct + numct.append(ct) + else: + num.append(v) + for v in orig_denum: + ct = get_constant(v) + if ct is not None: + denumct.append(ct) + else: + denum.append(v) + + if self.use_reciprocal or num: + # This will calculate either: + # [inverse(main(*numct), main(*denumct))] + # [] - if inverse(main(*numct), main(*denumct)) is the + # neutral element + ct = self.calculate(numct, denumct, aslist=True, out_type=out_type) + else: + # This happens if we don't allow the reciprocal and the + # numerator is empty. That means we will need to represent + # reciprocal(x) like inverse(neutral_element, x) so + # we can't allow ct == [] + # TODO: why is this branch needed when merge_num_denum + # does it for us? + ct = [self.calculate(numct, denumct, aslist=False, out_type=out_type)] + + # Wrapping ct in a Constant with the right dtype + ct = [constant(c, dtype=out_type.dtype) for c in ct] + + if orig_num and len(numct) == 1 and len(denumct) == 0 and ct: + # In that case we should only have one constant in `ct`. + assert len(ct) == 1 + first_num_ct = get_constant(orig_num[0]) + if first_num_ct is not None and ct[0].type.values_eq( + ct[0].data, first_num_ct + ): + # This is an important trick :( if it so happens that: + # * there's exactly one constant on the numerator and none on + # the denominator + # * it's not the neutral element (ct is an empty list in that + # case) + # * the constant is the same as the first argument in the + # numerator (we only check the first argument because the + # canonizer puts the computed constants first) + # -> then we return very exactly the original num/denum. + # If we don't do that the rewrite will just loop + # infinitely because it will not catch on that there are + # no changes to be made and every time it will want to + # replace something by the same thing... + # Note that it is important to use `values_eq` instead of + # the == operator, to handle NaN values correctly. + return orig_num, orig_denum + + return ct + num, denum + + def transform(self, fgraph, node): + op = node.op + if op not in [self.main, self.inverse, self.reciprocal]: + return False + + assert len(node.outputs) == 1 + out = node.outputs[0] + + out_clients = fgraph.clients.get(out) + + if not out_clients: + return False + + # check if any of the clients of this node would be part of + # this canonized graph... if so, we do nothing and wait for + # them to be transformed. + for c, c_idx in out_clients: + if c == "output": + continue + while ( + isinstance(getattr(c, "op", None), DimShuffle) + and len(fgraph.clients[c.outputs[0]]) <= 1 + ): + c = fgraph.clients[c.outputs[0]][0][0] + if getattr(c, "op", "") in [self.main, self.inverse, self.reciprocal]: + return False + + # Here we make the canonical version of the graph around this node + # See the documentation of get_num_denum and simplify + orig_num, orig_denum = self.get_num_denum(node.outputs[0]) + num, denum = self.simplify(list(orig_num), list(orig_denum), out.type) + + def same(x, y): + return len(x) == len(y) and all(np.all(xe == ye) for xe, ye in zip(x, y)) + + if ( + same(orig_num, num) + and same(orig_denum, denum) + and + # Check to see if we've collapsed some nested ops. + not ( + len(orig_denum) == 0 + and + # Make sure this change would increase the number of vector + # arguments--decreasing the number of unnecessary `self.main` + # nodes. + len(node.inputs) < len(orig_num) + ) + and + # Do a similar check for the reciprocal op. + not ( + self.use_reciprocal + and node.op == self.reciprocal + and len(orig_num) == 0 + and node.inputs[0].owner + and len(node.inputs[0].owner.inputs) < len(orig_denum) + ) + ): + return False + + new = self.merge_num_denum(num, denum) + if new.type.dtype != out.type.dtype: + new = cast(new, out.type.dtype) + + if new.type != out.type: + new = fill_chain(new, node.inputs)[0] + + if new.type == out.type: + new.tag.values_eq_approx = values_eq_approx_remove_inf_nan + copy_stack_trace(out, new) + return [new] + else: + return False + + def __str__(self): + return getattr( + self, + "name", + f"AlgebraicCanonizer({self.main}, {self.inverse}, {self.reciprocal})", + ) + + +def mul_calculate(num, denum, aslist=False, out_type=None): + if not num and not denum: + # Smallest 1 possible. + if aslist: + return [] + else: + return np.int8(1) + + # Make sure we do not accidentally upcast data types. + if out_type is None: + out_dtype = aes.upcast(*[v.dtype for v in (num + denum)]) + else: + out_dtype = out_type.dtype + one = _asarray(1, dtype=out_dtype) + + v = reduce(np.multiply, num, one) / reduce(np.multiply, denum, one) + if aslist: + if np.all(v == 1): + return [] + else: + return [v] + return v + + +local_mul_canonizer = AlgebraicCanonizer( + mul, true_div, reciprocal, mul_calculate, False +) +register_canonicalize(local_mul_canonizer, name="local_mul_canonizer") + + +@register_canonicalize +@node_rewriter([neg]) +def local_neg_to_mul(fgraph, node): + if node.op == neg: + return [mul(np.array(-1, dtype=node.inputs[0].dtype), node.inputs[0])] + + +@register_specialize +@node_rewriter([Sum, Prod]) +def local_sum_prod_mul_by_scalar(fgraph, node): + """ + sum(scalar * smth) -> scalar * sum(smth) + sum(-smth) -> -sum(smth) + + or + + prod(scalar * smth) -> scalar ** size(smth) * prod(smth) + prod(-smth) -> -1 ** size(smth) * prod(smth) + + """ + # TODO: if the the thing inside the Sum is a division, + # we should get at the numerator.... + if isinstance(node.op, (Sum, Prod)): + (node_inps,) = node.inputs + if node_inps.owner and node_inps.owner.op == mul: + terms = node_inps.owner.inputs + scalars = [t.dimshuffle() for t in terms if all(t.type.broadcastable)] + + if len(scalars) == 0: + return + + non_scalars = [t for t in terms if not all(t.broadcastable)] + + # Perform the op only on the non-scalar inputs, if applicable + if len(non_scalars) == 0: + new_op_input_nb_elements = 1 + new_op_output = 1 + elif len(non_scalars) == 1: + new_op_input_nb_elements = non_scalars[0].size + new_op_output = node.op(non_scalars[0]) + else: + new_op_input = mul(*non_scalars) + # We assume that errors always come from the prod/mul op in the + # original computational graph, and therefore need to only + # copy over its output stacktrace. + copy_stack_trace(node.outputs, new_op_input) + + new_op_input_nb_elements = new_op_input.size + new_op_output = node.op(new_op_input) + + if len(non_scalars) != 0: + # Copy over stacktrace from previous output to new mul op, + # for same reason as above. + copy_stack_trace(node.outputs, new_op_output) + + # If `node.op` is a `Prod`, then the scalars need to be raised to + # the power of the number of elements in the input to the `Prod` + if isinstance(node.op, Prod) and new_op_input_nb_elements != 1: + + scalars = [s**new_op_input_nb_elements for s in scalars] + + # Scale the output of the op by the scalars and return as + # replacement for the original output + mul_inputs = scalars + if new_op_input_nb_elements != 1: + mul_inputs.append(new_op_output) + + if len(mul_inputs) == 1: + # Copy over stacktrace from previous output to new mul op, + # for same reason as above. + copy_stack_trace(node.outputs, mul_inputs) + + return mul_inputs + else: + ret = mul(*mul_inputs) + # Copy over stacktrace from previous output to new mul op, + # for same reason as above. + copy_stack_trace(node.outputs, [ret] + mul_inputs) + + return [ret] + + if isinstance(node.op, Sum) and node_inps.owner and node_inps.owner.op == neg: + s = node.op(node_inps.owner.inputs[0]) + ret = neg(s) + # There are never errors in the negative op, thus + # we need only to copy over stacktrace from previous output node to + # the two new ops. + copy_stack_trace(node.outputs, [s, ret]) + + return [ret] + + +@register_specialize +@node_rewriter([Elemwise]) +def local_elemwise_sub_zeros(fgraph, node): + """ + Elemwise{sub}(X,X) -> zeros_like(X) + """ + if ( + isinstance(node.op, Elemwise) + and node.op.scalar_op.nin == 2 + and node.op.scalar_op == aes.sub + and node.inputs[0] == node.inputs[1] + ): + res = zeros_like(node.inputs[0]) + # Copy over stacktrace from previous output. + # This could help for failures due to out-of-memory. + copy_stack_trace(node.outputs, res) + return [res] + + +@register_useless +@register_specialize +@register_stabilize +@register_canonicalize +@node_rewriter([Elemwise]) +def local_useless_elemwise_comparison(fgraph, node): + """... + + # Comparing to itself is constant + Elemwise[{LT,GT}](X, X) -> Elemwise[zeros](X) + Elemwise[{LE,GE}](X, X) -> Elemwise[ones](X) + Elemwise[{minimum,maximum}](X, X) -> X + + # Comparing shape to 0 can be constant + Elemwise[LT](X.shape[i], 0) -> Elemwise[zeros](X) + Elemwise[GE](X.shape[i], 0) -> Elemwise[ones](X) + Elemwise[maximum](X.shape[i], 0) -> X.shape[i] + Elemwise[maximum](0, X.shape[i]) -> X.shape[i] + Elemwise[minimum](X.shape[i], 0) -> 0 + Elemwise[minimum](0, X.shape[i]) -> 0 + + # The shape can be replaced with sum of shapes + Elemwise[LT](add([anything that is shapes]), 0) -> Elemwise[zeros](X) + Elemwise[GE](add([anything that is shapes]), 0) -> Elemwise[ones](X) + + # Shapes are never negative + # Needed by Reshape.infer_shape + Elemwise[EQ](Subtensor(Shape(x)), -N) -> Elemwise[zeros](X) + + Notes + ----- + These cases appear in the graph generated by scan. These rewrites will make + the graph easier to read. + + """ + if not isinstance(node.op, Elemwise): + return + if node.op.scalar_op.nin != 2: + return + + # We call zeros_like and one_like with opt=True to generate a + # cleaner graph. + dtype = node.outputs[0].dtype + + # Elemwise[{LT,GT}](X, X) -> Elemwise[zeros](X) + if ( + isinstance(node.op.scalar_op, (aes.LT, aes.GT)) + and node.inputs[0] is node.inputs[1] + ): + res = zeros_like(node.inputs[0], dtype=dtype, opt=True) + # Copy over stacktrace from previous output. + copy_stack_trace(node.outputs, res) + return [res] + # Elemwise[{LE,GE}](X, X) -> Elemwise[ones](X) + if ( + isinstance(node.op.scalar_op, (aes.LE, aes.GE)) + and node.inputs[0] is node.inputs[1] + ): + res = ones_like(node.inputs[0], dtype=dtype, opt=True) + + # Copy over stacktrace from previous output. + copy_stack_trace(node.outputs, res) + return [res] + # Elemwise[{minimum,maximum}](X, X) -> X + if ( + isinstance(node.op.scalar_op, (aes.ScalarMinimum, aes.ScalarMaximum)) + and node.inputs[0] is node.inputs[1] + ): + res = node.inputs[0] + # Copy over stacktrace from previous output. + copy_stack_trace(node.outputs, res) + return [res] + + # Elemwise[LT](X.shape[i], 0) -> Elemwise[zeros](X) + if ( + isinstance(node.op.scalar_op, aes.LT) + and node.inputs[0].owner + and isinstance(node.inputs[0].owner.op, Shape_i) + and extract_constant(node.inputs[1], only_process_constants=True) == 0 + ): + res = zeros_like(node.inputs[0], dtype=dtype, opt=True) + # Copy over stacktrace from previous output. + copy_stack_trace(node.outputs, res) + return [res] + # Elemwise[GE](X.shape[i], 0) -> Elemwise[ones](X) + if ( + isinstance(node.op.scalar_op, aes.GE) + and node.inputs[0].owner + and isinstance(node.inputs[0].owner.op, Shape_i) + and extract_constant(node.inputs[1], only_process_constants=True) == 0 + ): + res = ones_like(node.inputs[0], dtype=dtype, opt=True) + # Copy over stacktrace from previous output. + copy_stack_trace(node.outputs, res) + return [res] + # Elemwise[maximum](X.shape[i], 0) -> X.shape[i] + if ( + isinstance(node.op.scalar_op, aes.ScalarMaximum) + and node.inputs[0].owner + and isinstance(node.inputs[0].owner.op, Shape_i) + and extract_constant(node.inputs[1], only_process_constants=True) == 0 + ): + # No need to copy over stacktrace. + return [node.inputs[0]] + # Elemwise[maximum](0, X.shape[i]) -> X.shape[i] + if ( + isinstance(node.op.scalar_op, aes.ScalarMaximum) + and extract_constant(node.inputs[0], only_process_constants=True) == 0 + and node.inputs[1].owner + and isinstance(node.inputs[1].owner.op, Shape_i) + ): + # No need to copy over stacktrace. + return [node.inputs[1]] + # Elemwise[minimum](X.shape[i], 0) -> 0 + if ( + isinstance(node.op.scalar_op, aes.ScalarMinimum) + and node.inputs[0].owner + and isinstance(node.inputs[0].owner.op, Shape_i) + and extract_constant(node.inputs[1], only_process_constants=True) == 0 + ): + res = zeros_like(node.inputs[0], dtype=dtype, opt=True) + # Copy over stacktrace from previous output. + copy_stack_trace(node.outputs, res) + return [res] + + # Elemwise[minimum](0, X.shape[i]) -> 0 + if ( + isinstance(node.op.scalar_op, aes.ScalarMinimum) + and extract_constant(node.inputs[0], only_process_constants=True) == 0 + and node.inputs[1].owner + and isinstance(node.inputs[1].owner.op, Shape_i) + ): + res = zeros_like(node.inputs[1], dtype=dtype, opt=True) + # Copy over stacktrace from previous output. + copy_stack_trace(node.outputs, res) + return [res] + + # Elemwise[LT](add([anything that is shapes]), 0) -> Elemwise[zeros](X) + if ( + isinstance(node.op.scalar_op, aes.LT) + and node.inputs[0].owner + and isinstance(node.inputs[0].owner.op, Elemwise) + and isinstance(node.inputs[0].owner.op.scalar_op, aes.Add) + and all( + isinstance(var.owner and var.owner.op, Shape_i) + for var in node.inputs[0].owner.inputs + ) + and extract_constant(node.inputs[1], only_process_constants=True) == 0 + ): + res = zeros_like(node.inputs[0], dtype=dtype, opt=True) + # Copy over stacktrace from previous output. + copy_stack_trace(node.outputs, res) + return [res] + # Elemwise[GE](add([anything that is shapes]), 0) -> Elemwise[ones](X) + if ( + isinstance(node.op.scalar_op, aes.GE) + and node.inputs[0].owner + and isinstance(node.inputs[0].owner.op, Elemwise) + and isinstance(node.inputs[0].owner.op.scalar_op, aes.Add) + and all( + isinstance(var.owner and var.owner.op, Shape_i) + for var in node.inputs[0].owner.inputs + ) + and extract_constant(node.inputs[1], only_process_constants=True) == 0 + ): + res = ones_like(node.inputs[0], dtype=dtype, opt=True) + + # Copy over stacktrace from previous output. + copy_stack_trace(node.outputs, res) + return [res] + + # Elemwise[EQ](Subtensor(Shape(x)), -N) + # Elemwise[EQ](somegraph that only depend of shape, -N) + # TODO: handle the case where the -N is on either side + """ + |Elemwise{eq,no_inplace} [id B] '' + | |Subtensor{int64} [id C] '' + | | |Join [id D] '' + | | | |TensorConstant{0} [id E] + | | | |Subtensor{int64:int64:} [id F] '' + | | | | |Shape [id G] '' + """ + + def investigate(node): + "Return True if values will be shapes, so >= 0" + if isinstance(node.op, (Shape, Shape_i)): + return True + elif isinstance(node.op, Subtensor) and node.inputs[0].owner: + return investigate(node.inputs[0].owner) + elif isinstance(node.op, Join): + return all(v.owner and investigate(v.owner) for v in node.inputs[1:]) + elif isinstance(node.op, MakeVector): + return all(v.owner and investigate(v.owner) for v in node.inputs) + + if ( + isinstance(node.op.scalar_op, aes.EQ) + and node.inputs[0].owner + and investigate(node.inputs[0].owner) + ): + try: + cst = get_scalar_constant_value(node.inputs[1], only_process_constants=True) + + res = zeros_like(node.inputs[0], dtype=dtype, opt=True) + + if cst < 0: + # Copy over stacktrace from previous output. + copy_stack_trace(node.outputs, res) + + return [res] + + except NotScalarConstantError: + pass + return + + +@register_canonicalize +@register_specialize +@node_rewriter([Sum, Prod]) +def local_sum_prod_div_dimshuffle(fgraph, node): + """ + sum(a / dimshuffle{...}(b), axis=l) -> sum(a, axis={...}) / b, + if dimension l of the DimShuffle is 'x' + + or + + prod(a / dimshuffle{...}(b), axis=l) -> + prod(a, axis={...}) / b ** a.shape[l], + if dimension l of the DimShuffle is 'x' + """ + + # It does not make much sense now to extend it to the case where the + # dimshuffle is in the numerator, since elemwise inversion of the + # denominator would still be needed before the summation or production. + + if isinstance(node.op, (Sum, Prod)): + axis = node.op.axis + if axis is None: + axis = list(range(node.inputs[0].ndim)) + node_input = node.inputs[0] + if node_input.owner and node_input.owner.op == true_div: + numerator, denominator = node_input.owner.inputs + + if denominator.owner and isinstance(denominator.owner.op, DimShuffle): + dimshuffle_input = denominator.owner.inputs[0] + dimshuffle_order = denominator.owner.op.new_order + + compatible_dims = [] + incompatible_dims = [] + for ax in axis: + if ax < len(dimshuffle_order) and dimshuffle_order[ax] == "x": + compatible_dims.append(ax) + else: + incompatible_dims.append(ax) + reordered_incompatible_dims = [] + for ic_ax in incompatible_dims: + reordered_incompatible_dims.append( + ic_ax - sum(1 for c_ax in compatible_dims if c_ax < ic_ax) + ) + + if len(compatible_dims) > 0: + optimized_dimshuffle_order = list( + ax + for i, ax in enumerate(dimshuffle_order) + if (i not in axis) or (ax != "x") + ) + + # Removing leading 'x' (since it will be done automatically) + while ( + len(optimized_dimshuffle_order) > 0 + and optimized_dimshuffle_order[0] == "x" + ): + del optimized_dimshuffle_order[0] + + # if optimized_dimshuffle_order is sorted with + # not 'x', then dimshuffle is useless. + if all(i == e for i, e in enumerate(optimized_dimshuffle_order)): + optimized_dimshuffle = dimshuffle_input + else: + optimized_dimshuffle = DimShuffle( + dimshuffle_input.type.broadcastable, + optimized_dimshuffle_order, + )(dimshuffle_input) + + if isinstance(node.op, Sum): + op_on_compatible_dims = at_sum(numerator, axis=compatible_dims) + rval = true_div(op_on_compatible_dims, optimized_dimshuffle) + if len(reordered_incompatible_dims) > 0: + rval = at_sum(rval, axis=reordered_incompatible_dims) + elif isinstance(node.op, Prod): + op_on_compatible_dims = prod(numerator, axis=compatible_dims) + dtype = numerator.dtype + rval = true_div( + op_on_compatible_dims, + ( + optimized_dimshuffle + ** prod( + [ + numerator.shape[ax].astype(dtype) + for ax in compatible_dims + ] + ) + ), + ) + if len(reordered_incompatible_dims) > 0: + rval = prod(rval, axis=reordered_incompatible_dims) + return [rval] + + +@register_canonicalize +@node_rewriter([Sum, Prod]) +def local_sum_prod_all_to_none(fgraph, node): + """ + Sum{0,1,...N} -> Sum{} or + Prod{0,1,...N} -> Prod{} + + """ + if isinstance(node.op, Sum) or isinstance(node.op, Prod): + op_type = Sum if isinstance(node.op, Sum) else Prod + # if all the axes are named, then use None as a shorthand + # this permits more merging + if node.op.axis is None: + return + if set(node.op.axis) == set(range(node.inputs[0].type.ndim)): + return [op_type(axis=None, dtype=node.op.dtype)(node.inputs[0])] + + +@register_canonicalize +@node_rewriter([Sum, Prod]) +def local_op_of_op(fgraph, node): + """ + Prod(Prod()) -> single Prod() + or + Sum(Sum()) -> single Sum() + + """ + if isinstance(node.op, Prod) or isinstance(node.op, Sum): + op_type = Sum if isinstance(node.op, Sum) else Prod + (node_inps,) = node.inputs + out_dtype = node.op.dtype + # This is done to make sure the rewrite doesn't affect other + # computations. + if len(fgraph.clients[node_inps]) == 1: + if node_inps.owner and (isinstance(node_inps.owner.op, node.op.__class__)): + + # check to see either the inner or outer prod is doing a + # product over all axis, in which case we can remove it + if node_inps.owner.op.axis is None or node.op.axis is None: + return [op_type(None, dtype=out_dtype)(node_inps.owner.inputs[0])] + + # figure out which axes were in the original sum + newaxis = list(tuple(node_inps.owner.op.axis)) + for i in node.op.axis: + new_i = i + for ii in node_inps.owner.op.axis: + if new_i >= ii: + new_i += 1 + assert new_i not in newaxis + newaxis.append(new_i) + + assert len(newaxis) == len( + list(node_inps.owner.op.axis) + list(node.op.axis) + ) + + combined = op_type(newaxis, dtype=out_dtype) + return [combined(node_inps.owner.inputs[0])] + + +ALL_REDUCE = ( + [ + CAReduce, + All, + Any, + Sum, + Prod, + ProdWithoutZeros, + ] + + CAReduce.__subclasses__() + + NonZeroCAReduce.__subclasses__() +) + + +@register_canonicalize +@register_uncanonicalize # Needed for MaxAndArgmax -> CAReduce +@node_rewriter(ALL_REDUCE) +def local_reduce_join(fgraph, node): + """ + CAReduce{scalar.op}(Join(axis=0, a, b), axis=0) -> Elemwise{scalar.op}(a, b) + + Notes + ----- + Supported scalar.op are Maximum, Minimum in some cases and Add and Mul in + all cases. + + Currently we must reduce on axis 0. It is probably extensible to the case + where we join and reduce on the same set of axis. + + """ + if ( + isinstance(node.op, CAReduce) + and node.inputs[0].owner + and isinstance(node.inputs[0].owner.op, Join) + ): + join_node = node.inputs[0].owner + if extract_constant(join_node.inputs[0], only_process_constants=True) != 0: + return + + if isinstance(node.op.scalar_op, (aes.ScalarMaximum, aes.ScalarMinimum)): + # Support only 2 inputs for now + if len(join_node.inputs) != 3: + return + elif not isinstance(node.op.scalar_op, (aes.Add, aes.Mul)): + return + elif len(join_node.inputs) <= 2: + # This is a useless join that should get removed by another rewrite? + return + + new_inp = [] + for inp in join_node.inputs[1:]: + inp = inp.owner + if not inp: + return + if not isinstance(inp.op, DimShuffle) or inp.op.new_order != ("x",) + tuple( + range(inp.inputs[0].ndim) + ): + return + new_inp.append(inp.inputs[0]) + ret = Elemwise(node.op.scalar_op)(*new_inp) + + if ret.dtype != node.outputs[0].dtype: + # The reduction do something about the dtype. + return + + reduce_axis = node.op.axis + if reduce_axis is None: + reduce_axis = tuple(range(node.inputs[0].ndim)) + + if len(reduce_axis) != 1 or 0 not in reduce_axis: + return + + # We add the new check late to don't add extra warning. + try: + join_axis = get_scalar_constant_value( + join_node.inputs[0], only_process_constants=True + ) + + if join_axis != reduce_axis[0]: + return + except NotScalarConstantError: + return + + return [ret] + + +@register_canonicalize("fast_compile", "local_cut_useless_reduce") +@register_useless("local_cut_useless_reduce") +@node_rewriter(ALL_REDUCE) +def local_useless_reduce(fgraph, node): + """Sum(a, axis=[]) -> a""" + if isinstance(node.op, CAReduce): + (summed,) = node.inputs + # if reduce were doing anything, the output ndim would be reduced + if summed.type == node.outputs[0].type: + return [summed] + + +@register_canonicalize +@register_uncanonicalize +@register_specialize +@node_rewriter(ALL_REDUCE) +def local_reduce_broadcastable(fgraph, node): + """Remove reduction over broadcastable dimensions.""" + if isinstance(node.op, CAReduce): + (reduced,) = node.inputs + odtype = node.outputs[0].dtype + if node.op.axis is None: + if all(reduced.broadcastable): + return [reduced.dimshuffle().astype(odtype)] + else: + axis = list(node.op.axis) + cuttable = [a for a in axis if reduced.broadcastable[a]] + if cuttable: + # -- we can remove some axes of summation. + new_axis = [] + pattern = [] + ii = 0 + for p in range(reduced.ndim): + if p not in cuttable: + if p in axis: + new_axis.append(ii) + pattern.append(p) + ii += 1 + new_reduced = reduced.dimshuffle(*pattern) + if new_axis: + if type(node.op) == CAReduce: + # This case handles `CAReduce` instances + # (e.g. generated by `scalar_elemwise`), and not the + # scalar `Op`-specific subclasses + # TODO FIXME: This highlights a major design flaw in + # `CAReduce` (or at least our use of it), and it needs + # to be fixed + new_op = node.op.__class__(node.op.scalar_op, axis=new_axis) + else: + new_op = node.op.__class__(axis=new_axis) + return [new_op(new_reduced)] + else: + # -- in this case we can remove the reduction completely + return [new_reduced.astype(odtype)] + + +@register_specialize +@node_rewriter([Sum, Prod]) +def local_opt_alloc(fgraph, node): + """ + sum(alloc(constant,shapes...)) => constant*prod(shapes) + or + prod(alloc(constant,shapes...)) => constant**prod(shapes) + + """ + if isinstance(node.op, Sum) or isinstance(node.op, Prod): + (node_inps,) = node.inputs + if node_inps.owner and isinstance(node_inps.owner.op, Alloc): + inp = node_inps.owner.inputs[0] + shapes = node_inps.owner.inputs[1:] + try: + val = get_scalar_constant_value(inp, only_process_constants=True) + assert val.size == 1 + val = val.reshape(1)[0] + # check which type of op + size = mul(*shapes) + if inp.dtype in ("float16", "float32"): + # shapes are ints and normally int64. + # We don't want to have a float64 upcast + # We don't want to downcast to float16 + # as we fear it could loose too much precision + # that will be amplified by the mul/pow below. + size = size.astype("float32") + if node.op.axis is None or node.op.axis == tuple(range(inp.ndim)): + if isinstance(node.op, Sum): + val = val * size + else: + val = val**size + # Sum can change the input dtype (upcast or bool + # -> float32) by default or by user request. + # We can ignore the acc_dtype, as there is only 1 + # elemwise we will do and not a sequence, so there is no + # accumulation of errors. + # So mostly, we just need to cast the output to the old + # dtype. + val = val.astype(node.outputs[0].dtype) + return [val] + to_prod = [shapes[i] for i in range(len(shapes)) if i in node.op.axis] + if to_prod: + size = mul(*to_prod) + if isinstance(node.op, Sum): + val *= size + else: + val = val**size + # See comments above. + val = val.astype(node.outputs[0].dtype) + return [ + alloc( + val, + *[ + shapes[i] + for i in range(len(shapes)) + if i not in node.op.axis + ], + ) + ] + except NotScalarConstantError: + pass + + +@register_specialize +@node_rewriter([neg]) +def local_neg_div_neg(fgraph, node): + """ + - (-a / b) -> a / b + + Also performs - (c / b) -> ((-c) / b) when c is a scalar constant. + + """ + if node.op == neg: + if node.inputs[0].owner and node.inputs[0].owner.op == true_div: + frac = node.inputs[0] + num, denom = frac.owner.inputs + if num.owner and num.owner.op == neg: + if len(fgraph.clients[frac]) == 1: + # No other clients of the original division + new_num = num.owner.inputs[0] + return [true_div(new_num, denom)] + elif all(num.broadcastable) and isinstance(num, Constant): + if len(fgraph.clients[frac]) == 1: + new_num = -num.data + return [true_div(new_num, denom)] + + +@register_canonicalize +@node_rewriter([mul]) +def local_mul_zero(fgraph, node): + """ + As part of canonicalization, we replace multiplication by zero + with zero. + + """ + if node.op == mul: + otype = node.outputs[0].type + + for i in node.inputs: + try: + value = get_scalar_constant_value(i) + except NotScalarConstantError: + continue + # print 'MUL by value', value, node.inputs + if value == 0: + # print '... returning zeros' + return fill_chain(_asarray(0, dtype=otype.dtype), node.inputs) + + +# TODO: Add this to the canonicalization to reduce redundancy. +@register_specialize +@node_rewriter([true_div]) +def local_div_to_reciprocal(fgraph, node): + if node.op == true_div and np.all(get_constant(node.inputs[0]) == 1.0): + out = node.outputs[0] + new_out = reciprocal(local_mul_canonizer.merge_num_denum(node.inputs[1:], [])) + # The ones could have forced upcasting + if new_out.dtype != out.dtype: + new_out = cast(new_out, dtype=out.dtype) + # The ones could have forced a specific length + if not out.type.is_super(new_out.type): + new_out = broadcast_like(new_out, out, fgraph) + return [new_out] + else: + return False + + +@register_canonicalize +@node_rewriter([reciprocal]) +def local_reciprocal_canon(fgraph, node): + if node.op == reciprocal: + return [at_pow(node.inputs[0], -1.0)] + else: + return False + + +@register_canonicalize +@node_rewriter([at_pow]) +def local_pow_canonicalize(fgraph, node): + if node.op == at_pow: + cst = get_constant(node.inputs[1]) + if cst == 0: + return [broadcast_like(1, node.outputs[0], fgraph)] + if cst == 1: + return [broadcast_like(node.inputs[0], node.outputs[0], fgraph)] + else: + return False + + +@register_specialize +@node_rewriter([mul]) +def local_mul_to_sqr(fgraph, node): + """ + x*x -> sqr(x) + """ + if node.op == mul: + if len(node.inputs) == 2: + if node.inputs[0] is node.inputs[1]: + return [sqr(node.inputs[0])] + + +@register_canonicalize +@node_rewriter([int_div]) +def local_intdiv_by_one(fgraph, node): + """x // 1 -> x""" + if node.op in [int_div]: + if isinstance(node.inputs[1], TensorConstant) and np.all( + node.inputs[1].value == 1 + ): + return [node.inputs[0].astype(node.outputs[0].dtype)] + + +@register_canonicalize +@register_specialize +@node_rewriter([int_div, true_div]) +def local_zero_div(fgraph, node): + """0 / x -> 0""" + if isinstance(node.op, Elemwise) and isinstance( + node.op.scalar_op, (aes.IntDiv, aes.TrueDiv) + ): + if get_constant(node.inputs[0]) == 0: + ret = broadcast_like(0, node.outputs[0], fgraph) + ret.tag.values_eq_approx = values_eq_approx_remove_nan + return [ret] + + +@register_specialize +@node_rewriter([at_pow]) +def local_pow_specialize(fgraph, node): + # here, we are past the point of canonicalization, so we don't want + # to put in un-necessary fills. + if node.op == at_pow: + # the idea here is that we have pow(x, y) + odtype = node.outputs[0].dtype + xsym = node.inputs[0] + ysym = node.inputs[1] + y = get_constant(ysym) + if (y is not None) and encompasses_broadcastable( + xsym.type.broadcastable, ysym.type.broadcastable + ): + rval = None + + if np.all(y == 2): + rval = [sqr(xsym)] + if np.all(y == 1): + rval = [xsym] + if np.all(y == 0): + rval = [fill(xsym, np.asarray(1, dtype=odtype))] + if np.all(y == 0.5): + rval = [sqrt(xsym)] + if np.all(y == -0.5): + rval = [reciprocal(sqrt(xsym))] + if np.all(y == -1): + rval = [reciprocal(xsym)] + if np.all(y == -2): + rval = [reciprocal(sqr(xsym))] + if rval: + rval[0] = cast(rval[0], odtype) + assert rval[0].type == node.outputs[0].type, (rval, node.outputs) + return rval + else: + return False + + +@register_specialize_device +@node_rewriter([at_pow]) +def local_pow_specialize_device(fgraph, node): + """ + This rewrite is not the same on all device. We do it only on cpu here. + """ + if node.op == at_pow: + # the idea here is that we have pow(x, y) + odtype = node.outputs[0].dtype + xsym = node.inputs[0] + ysym = node.inputs[1] + y = get_constant(ysym) + + # the next line is needed to fix a strange case that I don't + # know how to make a separate test. + # That happen in the `test_log_erfc` test. + # y is a ndarray with dtype int8 and value 2,4 or 6. This make + # the abs(y) <= 512 fail! + # taking the value outside ndarray solve the problem. + # it could be that in that case, numpy make the comparison + # into the wrong type(do in int8 that overflow.) + if isinstance(y, np.ndarray): + assert y.size == 1 + try: + y = y[0] + except IndexError: + pass + if (y is not None) and encompasses_broadcastable( + xsym.type.broadcastable, ysym.type.broadcastable + ): + rval = None + # 512 is too small for the cpu and too big for some gpu! + if abs(y) == int(abs(y)) and abs(y) <= 512: + pow2 = [xsym] + pow2_scal = [aes.get_scalar_type(xsym.dtype)()] + y_to_do = abs(y) + for i in range(int(np.log2(y_to_do))): + pow2.append(sqr(pow2[i])) + pow2_scal.append(aes.sqr(pow2_scal[i])) + rval1 = None + rval1_scal = None + while y_to_do > 0: + log_to_do = int(np.log2(y_to_do)) + if rval1: + rval1 *= pow2[log_to_do] + rval1_scal *= pow2_scal[log_to_do] + else: + rval1 = pow2[log_to_do] + rval1_scal = pow2_scal[log_to_do] + y_to_do -= 2**log_to_do + + if abs(y) > 2: + # We fuse all the pow together here to make + # compilation faster + rval1 = Elemwise( + aes.Composite([pow2_scal[0]], [rval1_scal]) + ).make_node(xsym) + if y < 0: + rval = [reciprocal(rval1)] + else: + rval = [rval1] + if rval: + rval[0] = cast(rval[0], odtype) + assert rval[0].type == node.outputs[0].type, (rval, node.outputs) + return rval + + +@register_specialize +@node_rewriter([mul]) +def local_mul_specialize(fgraph, node): + """ + Remove special-case constants from mul arguments and useless neg in inputs. + + mul(-1, x) -> neg(x) + mul(1, x, y) -> mul(x, y) + mul(0, ...) -> alloc(0, shapes...) + + This is not done if we would add more nodes in the graph, like with: + + mul(-1, x, y) -/-> neg(mul(x, y)) + + """ + # here, we are past the point of canonicalization, so we don't + # want to put in un-necessary fills. + # + # at this point [post canonicalize], mul() may have many inputs. + if node.op == mul: + # the idea here is that we have pow(x, y) + has_neg = False + new_inputs = [] + nb_neg_node = 0 + nb_cst = 0 + for inp in node.inputs: + # remove any neg arguments + while inp.owner and inp.owner.op == neg: + has_neg ^= True + inp = inp.owner.inputs[0] + nb_neg_node += 1 + + # remove special case arguments of 1, -1 or 0 + y = get_constant(inp) + if y == 1.0: + nb_cst += 1 + elif y == -1.0: + nb_cst += 1 + has_neg ^= True # toggles + elif y == 0.0: + # if we find any zero, we just return right away + return [broadcast_like(0, node.outputs[0], fgraph)] + else: + new_inputs.append(inp) + + if new_inputs != node.inputs: + if new_inputs: + if len(new_inputs) == 1: + if has_neg: + if new_inputs[0].dtype in (uint_dtypes + ["bool"]): + return + else: + rval = -new_inputs[0] + else: + rval = new_inputs[0] + else: + # The next case would cause a replace by an equivalent case. + if has_neg and nb_neg_node == 0 and nb_cst == 1: + return + elif has_neg: + # Don't add an extra neg node as we can't + # fully replace this mul by a neg. + m1 = np.asarray(-1, dtype=node.outputs[0].dtype) + new_inputs = [m1] + new_inputs + rval = mul(*new_inputs) + + return [broadcast_like(rval, node.outputs[0], fgraph)] + else: + # there are no variable inputs to mul + # N.B. this could have been constant-folded... + if has_neg: + return [broadcast_like(-1, node.outputs[0], fgraph)] + else: + return [broadcast_like(1, node.outputs[0], fgraph)] + + +@register_specialize +@node_rewriter([add]) +def local_add_specialize(fgraph, node): + """Remove zeros from ``add``s. + + TODO: This should be a canonicalization, no? + """ + # here, we are past the point of canonicalization, so we don't want + # to put in un-necessary fills. + if node.op != add: + return False + + new_inputs = [] + for inp in node.inputs: + try: + y = get_scalar_constant_value(inp) + except NotScalarConstantError: + y = inp + if np.all(y == 0.0): + continue + new_inputs.append(inp) + + if len(new_inputs) == len(node.inputs): + return False + + node_output = node.outputs[0] + dtype = node_output.type.dtype + + if len(new_inputs) == 0: + # we got rid of the entire expression! + ndim = node_output.type.ndim + # Reuse call to constant for cache() + cst = constant(np.zeros((1,) * ndim, dtype=dtype)) + assert cst.type.broadcastable == (True,) * ndim + return fill_chain(cst, node.inputs) + + if len(new_inputs) == 1: + ret = fill_chain(new_inputs[0], node.inputs) + else: + ret = fill_chain(add(*new_inputs), node.inputs) + + # The dtype should not be changed. It can happen if the input + # that was forcing upcasting was equal to 0. + if ret[0].dtype != dtype: + ret = [cast(ret[0], dtype)] + + return ret + + +mul_canonizer = in2out( + SequentialNodeRewriter( + local_mul_canonizer, local_fill_sink, apply_all_rewrites=True + ), + name="mul_canonizer_groups", +) + + +def check_for_x_over_absX(numerators, denominators): + """Convert x/abs(x) into sign(x).""" + # TODO: this function should dig/search through dimshuffles + # This won't catch a dimshuffled absolute value + for den in list(denominators): + if den.owner and den.owner.op == at_abs and den.owner.inputs[0] in numerators: + if den.owner.inputs[0].type.dtype.startswith("complex"): + # TODO: Make an Op that projects a complex number to + # have unit length but projects 0 to 0. That + # would be a weird Op, but consistent with the + # special case below. I heard there's some + # convention in Matlab that is similar to + # this... but not sure. + pass + else: + denominators.remove(den) + numerators.remove(den.owner.inputs[0]) + numerators.append(sgn(den.owner.inputs[0])) + return numerators, denominators + + +local_mul_canonizer.add_simplifier(check_for_x_over_absX, "X_over_absX") + + +@register_canonicalize +@node_rewriter([at_abs]) +def local_abs_lift(fgraph, node): + """ + Move the abs toward the input. + + This is needed for check_for_x_over_absX to apply in more case. + + """ + if node.op == at_abs and node.inputs[0].owner: + assert node.nin == 1 + if node.inputs[0].owner.op == mul: + return [mul(*[at_abs(i) for i in node.inputs[0].owner.inputs])] + if node.inputs[0].owner.op == true_div: + i = node.inputs[0].owner.inputs + return [true_div(at_abs(i[0]), at_abs(i[1]))] + + +@register_specialize +@node_rewriter([mul, true_div]) +def local_abs_merge(fgraph, node): + """ + Merge abs generated by local_abs_lift when the canonizer don't + need it anymore + + """ + if node.op == mul and sum(i.owner.op == at_abs for i in node.inputs if i.owner) > 1: + inputs = [] + for i in node.inputs: + if i.owner and i.owner.op == at_abs: + inputs.append(i.owner.inputs[0]) + elif isinstance(i, Constant): + try: + const = get_scalar_constant_value(i, only_process_constants=True) + except NotScalarConstantError: + return False + if not (const >= 0).all(): + return False + inputs.append(i) + else: + return False + return [at_abs(mul(*inputs))] + if ( + node.op == true_div + and sum(i.owner.op == at_abs for i in node.inputs if i.owner) == 2 + ): + return [ + at_abs( + true_div(node.inputs[0].owner.inputs[0], node.inputs[1].owner.inputs[0]) + ) + ] + + +@register_stabilize +@register_specialize +@node_rewriter([log]) +def local_log1p(fgraph, node): + # log(1+x) -> log1p(x) + # log(1-x) -> log1p(-x) + if node.op == log: + (log_arg,) = node.inputs + if log_arg.owner and log_arg.owner.op == add: + scalars, scalar_inputs, nonconsts = scalarconsts_rest( + log_arg.owner.inputs, only_process_constants=True + ) + # scalar_inputs are potentially dimshuffled and fill'd scalars + if scalars and np.allclose(np.sum(scalars), 1): + if nonconsts: + if len(nonconsts) > 1: + ninp = add(*nonconsts) + else: + ninp = nonconsts[0] + if ninp.dtype != log_arg.type.dtype: + ninp = ninp.astype(node.outputs[0].dtype) + return fill_chain(log1p(ninp), scalar_inputs) + + elif log_arg.owner and log_arg.owner.op == sub: + one = extract_constant(log_arg.owner.inputs[0], only_process_constants=True) + if one != 1: + return + other = log_arg.owner.inputs[1] + if other.dtype != log_arg.dtype: + other = other.astype(log_arg.dtype) + return [log1p(neg(other))] + + +@register_stabilize +@register_specialize +@node_rewriter([log]) +def local_log_add_exp(fgraph, node): + """ + ``log(exp(x)+exp(y)+exp(z)) = max + log(x-max, y-max, z-max)`` + + TODO: in canonicalize, change log10 and log2 -> log + """ + + if node.op == log: + z = node.inputs[0] + if z.owner and z.owner.op == add: + zi = z.owner.inputs + pre_exp = [x.owner.inputs[0] for x in zi if x.owner and x.owner.op == exp] + # all arguments to add are exp() + if len(pre_exp) == len(zi): + # Do not offset when max_pre = -np.inf, to avoid nan in the output + # Switch statement is placed directly inside add to break the self-symmetry + # of the returned output (otherwise the rewrite would not stabilize) + max_pre = reduce(maximum, pre_exp) + ret = max_pre + log( + add( + *[ + switch(isinf(max_pre), exp(max_pre), exp(p - max_pre)) + for p in pre_exp + ] + ) + ) + return [ret] + + +@register_stabilize +@register_specialize +@node_rewriter([log]) +def local_log_sum_exp(fgraph, node): + # log(sum_i(exp(x_i))) = x_max + log(sum_i(exp(x_i - x_max))) + + if node.op != log: + return + + sum_node = node.inputs[0].owner + # If the sum has keepdims=True, there might be a dimshuffle + if sum_node and isinstance(sum_node.op, DimShuffle): + dimshuffle_op = sum_node.op + sum_node = sum_node.inputs[0].owner + else: + dimshuffle_op = None + + if not sum_node or not isinstance(sum_node.op, Sum): + return + + exp_node, axis = sum_node.inputs[0].owner, sum_node.op.axis + if not exp_node or not ( + isinstance(exp_node.op, Elemwise) and isinstance(exp_node.op.scalar_op, aes.Exp) + ): + return + + pre_exp = exp_node.inputs[0] + max_pre_exp = at_max(pre_exp, axis=axis) + max_pre_exp_keepdims = makeKeepDims(pre_exp, max_pre_exp, axis) + + # Do not offset when max_pre = -np.inf, to avoid nan in the output + # Switch statement is placed directly inside sum to break the self-symmetry + # of the returned output (otherwise the rewrite would not stabilize) + ret = max_pre_exp + log( + at_sum( + switch( + isinf(max_pre_exp_keepdims), + exp(max_pre_exp_keepdims), + exp(pre_exp - max_pre_exp_keepdims), + ), + axis=axis, + ), + ) + + # Restore the dimshuffle op, if any. + if dimshuffle_op: + ret = dimshuffle_op(ret) + + return [ret] + + +def add_calculate(num, denum, aslist=False, out_type=None): + # TODO: make sure that this function and mul_calculate are similar + if out_type is None: + zero = 0.0 + else: + zero = _asarray(0, dtype=out_type.dtype) + # zero = 0.0 if out_type is None else _asarray(0, + # dtype=out_type.dtype) + if out_type and out_type.dtype == "bool": + if len(denum) == 0: + # NumPy 1.14 do not accept to do "bool - bool" + v = reduce(np.add, num, zero) + else: + raise Exception( + "bool subtraction not supported. This should not happen as" + " an earlier error should have been raised" + ) + else: + v = reduce(np.add, num, zero) - reduce(np.add, denum, zero) + if aslist: + if np.all(v == 0): + return [] + else: + return [v] + return v + + +local_add_canonizer = AlgebraicCanonizer(add, sub, neg, add_calculate) +add_canonizer = in2out( + SequentialNodeRewriter( + local_add_canonizer, local_fill_sink, apply_all_rewrites=True + ), + name="add_canonizer_group", +) + + +register_canonicalize(local_add_canonizer, name="local_add_canonizer") + + +def distribute_greedy(pos_pairs, neg_pairs, num, denum, out_type, minscore=0): + # each pair in pos_pairs and neg_pairs is a num/denum pair. this + # function attempts to add num and denum to the corresponding parts + # of each pair, and counts how many multiplications/divisions can + # be saved in that way. + + # each division is counted like div_cost multiplications + # (typically, division costs more so we are willing to multiply more + # in order to divide less) + # 1.5 was obtained through an informal test and may very well be + # platform dependent + div_cost = 1.5 + + # score is number of operations saved, higher is better + score = len(num) + div_cost * len(denum) + new_pos_pairs = list( + itertools.starmap( + local_mul_canonizer.simplify, + [(n + num, d + denum, out_type) for (n, d) in pos_pairs], + ) + ) + new_neg_pairs = list( + itertools.starmap( + local_mul_canonizer.simplify, + [(n + num, d + denum, out_type) for (n, d) in neg_pairs], + ) + ) + for (n, d), (nn, dd) in zip(pos_pairs + neg_pairs, new_pos_pairs + new_neg_pairs): + # We calculate how many operations we are saving with the new + # num and denum + score += len(n) + div_cost * len(d) - len(nn) - div_cost * len(dd) + if score <= minscore: + # the change is not applied because it adds too many operations + return False, pos_pairs, neg_pairs + return True, new_pos_pairs, new_neg_pairs + + +def attempt_distribution(factor, num, denum, out_type): + """Try to insert each `num` and each `denum` in the factor? + + Returns + ------- + changes?, new_factor, new_num, new_denum + If there are changes, `new_num` and `new_denum` contain all the + numerators and denominators that could not be distributed in the factor + + """ + pos_terms, neg_terms = local_add_canonizer.get_num_denum(factor) + if len(pos_terms) == 1 and not neg_terms: + return False, factor, num, denum + pos_pairs = list(map(local_mul_canonizer.get_num_denum, pos_terms)) + neg_pairs = list(map(local_mul_canonizer.get_num_denum, neg_terms)) + change = False + for n in list(num): + success, pos_pairs, neg_pairs = distribute_greedy( + pos_pairs, neg_pairs, [n], [], out_type + ) + if success: + change = True + num.remove(n) + for d in list(denum): + success, pos_pairs, neg_pairs = distribute_greedy( + pos_pairs, neg_pairs, [], [d], out_type + ) + if success: + change = True + denum.remove(d) + if not change: + return change, factor, num, denum + else: + return ( + change, + local_add_canonizer.merge_num_denum( + list(itertools.starmap(local_mul_canonizer.merge_num_denum, pos_pairs)), + list(itertools.starmap(local_mul_canonizer.merge_num_denum, neg_pairs)), + ), + num, + denum, + ) + + +@register_canonicalize +@register_stabilize +@node_rewriter([mul, true_div, reciprocal]) +def local_greedy_distributor(fgraph, node): + """Reduce the number of multiplications and/or divisions. + + This rewrite tries to apply distributivity of multiplication + to addition in order to reduce the number of multiplications + and/or divisions that must be done. The algorithm weighs division + more than multiplication to account for the former's slightly + greater computational cost. + + The following expressions are simplified: + 1. ``((a/x + b/y) * x * y) -> a*y + b*x`` + 2. ``((a/x + b) * x) -> a + b*x`` + 3. There are other forms too where node is a true_div. + + The following expressions are not simplified: + 4. ``((a + b) * x) /> a*x + b*x`` + + This rewrite aims to reduce computational cost. It may also + increase numerical stability, e.g. when ``x`` and/or ``y`` tend to ``0`` in + Example 1. + + """ + + out = node.outputs[0] + num, denum = local_mul_canonizer.get_num_denum(out) + if len(num) == 1 and not denum: + return False + + new_num, new_denum = [], [] + + change = False + + out_type = out.type + for candidate in list(num): + if candidate not in num: + continue + num.remove(candidate) + _change, candidate, num, denum = attempt_distribution( + candidate, + num, + denum, + out_type, + ) + + change |= _change + new_num.append(candidate) + + for candidate in list(denum): + if candidate not in denum: + continue + denum.remove(candidate) + _change, candidate, denum, num = attempt_distribution( + candidate, denum, num, out_type + ) + change |= _change + new_denum.append(candidate) + if not change: + return False + + new_num += num + new_denum += denum + + rval = local_mul_canonizer.merge_num_denum(new_num, new_denum) + + if rval.type != out.type: + # WHY DOES THIS HAPPEN? + return False + + return [rval] + + +get_clients_at_depth1 = partial(get_clients_at_depth, depth=1) +get_clients_at_depth2 = partial(get_clients_at_depth, depth=2) + +# 1+erf(x)=>erfc(-x) +local_one_plus_erf = PatternNodeRewriter( + (add, 1, (erf, "x")), + (erfc, (neg, "x")), + allow_multiple_clients=True, + name="local_one_plus_erf", + tracks=[erf], + get_nodes=get_clients_at_depth1, +) +register_canonicalize(local_one_plus_erf) +register_stabilize(local_one_plus_erf) +register_specialize(local_one_plus_erf) + +# Only one of the two rewrites below is needed if a canonicalization is added +# for sub(x, y) -> add(x, -y) or a specialization for add(x, -y) -> sub(x, y) +# 1-erf(x)=>erfc(x) +local_one_minus_erf = PatternNodeRewriter( + (sub, 1, (erf, "x")), + (erfc, "x"), + allow_multiple_clients=True, + name="local_one_minus_erf", + tracks=[erf], + get_nodes=get_clients_at_depth1, +) +register_canonicalize(local_one_minus_erf) +register_stabilize(local_one_minus_erf) +register_specialize(local_one_minus_erf) + +local_one_minus_erf2 = PatternNodeRewriter( + (add, 1, (neg, (erf, "x"))), + (erfc, "x"), + allow_multiple_clients=True, + name="local_one_minus_erf2", + tracks=[erf], + get_nodes=get_clients_at_depth2, +) +register_canonicalize(local_one_minus_erf2) +register_stabilize(local_one_minus_erf2) +register_specialize(local_one_minus_erf2) + +# (-1)+erf(x) => -erfc(x) +# There is no need for erf(x)+(-1) nor erf(x) - 1, as the canonicalize will +# convert those to the matched pattern +local_erf_minus_one = PatternNodeRewriter( + (add, -1, (erf, "x")), + (neg, (erfc, "x")), + allow_multiple_clients=True, + name="local_erf_minus_one", + tracks=[erf], + get_nodes=get_clients_at_depth1, +) +register_canonicalize(local_erf_minus_one) +register_stabilize(local_erf_minus_one) +register_specialize(local_erf_minus_one) + +# Only one of the two rewrites below is needed if a canonicalization is added +# for sub(x, y) -> add(x, -y) or a specialization for add(x, -y) -> sub(x, y) +# 1-erfc(x) => erf(x) +local_one_minus_erfc = PatternNodeRewriter( + (sub, 1, (erfc, "x")), + (erf, "x"), + allow_multiple_clients=True, + name="local_one_minus_erfc", + tracks=[erfc], + get_nodes=get_clients_at_depth1, +) +register_canonicalize(local_one_minus_erfc) +register_stabilize(local_one_minus_erfc) +register_specialize(local_one_minus_erfc) + +local_one_minus_erfc2 = PatternNodeRewriter( + (add, 1, (neg, (erfc, "x"))), + (erf, "x"), + allow_multiple_clients=True, + name="local_one_minus_erfc2", + tracks=[erfc], + get_nodes=get_clients_at_depth2, +) +register_canonicalize(local_one_minus_erfc2) +register_stabilize(local_one_minus_erfc2) +register_specialize(local_one_minus_erfc2) + +# (-1)+erfc(-x)=>erf(x) +local_erf_neg_minus_one = PatternNodeRewriter( + (add, -1, (erfc, (neg, "x"))), + (erf, "x"), + allow_multiple_clients=True, + name="local_erf_neg_minus_one", + tracks=[erfc], + get_nodes=get_clients_at_depth1, +) +register_canonicalize(local_erf_neg_minus_one) +register_stabilize(local_erf_neg_minus_one) +register_specialize(local_erf_neg_minus_one) + + +@register_stabilize +@register_specialize +@node_rewriter([log]) +def local_log_erfc(fgraph, node): + """Stability rewrite for ``log(erfc(x))``. + + Notes + ----- + log(erfc(x)) => when x>threshold, + -x**2-log(x)-.5*log(pi)+log(1-1/(2*x**2)+3/(4*x**4)-15/(8*x**6)) + for float64: threshold=26.641747557 was chosen with: + [(i,numpy.log(scipy.special.erfc(numpy.asarray([i],dtype='float64')))) + for i in numpy.arange(26.641747557,26.6417475571,.00000000001)] + for float32: threshold=10.0541949, [(i,numpy.log(scipy.special.erfc( + numpy.asarray([i],dtype='float32')))) for i in numpy.arange( + 10.0541948,10.0541951,.0000001)] + """ + if node.op != log: + return False + if not node.inputs[0].owner or node.inputs[0].owner.op != erfc: + return False + + if hasattr(node.tag, "local_log_erfc_applied"): + # We use that flag to don't apply the rewrite recursively + # TODO FIXME: We shouldn't need to use tags for this. + return False + + node.tag.local_log_erfc_applied = True + + x = node.inputs[0].owner.inputs[0] + stab_value = ( + -(x**2) + - log(x) + - 0.5 * log(np.pi) + + log(1 - 1 / (2 * x**2) + 3 / (4 * x**4) - 15 / (8 * x**6)) + ) + + if node.outputs[0].dtype == "float32" or node.outputs[0].dtype == "float16": + threshold = 10.0541949 + elif node.outputs[0].dtype == "float64": + threshold = 26.641747557 + + ret = switch(x < threshold, node.outputs[0], stab_value) + ret.tag.values_eq_approx = values_eq_approx_remove_inf + return [ret] + + +@register_stabilize +@register_specialize +@node_rewriter([true_div]) +def local_grad_log_erfc_neg(fgraph, node): + """Stability rewrite for the grad of ``log(erfc(x))``. + + Notes + ----- + ([y*]exp(-(x**2)))/erfc(x) # The y* is optional + ([y*]exp(x**2))/erfc(-x) => [y*](when x > threshold, + sqrt(pi)*-x/(1-1/(2*x**2)+3/(4*x**4)-15/(8*x**6))) + + for float64: threshold=26.63 see at the end of the fct for the explanation + for float32: threshold=9.3 see at the end of the fct for the explanation + + TODO: remove the constraint that there are only 2 inputs to exp(x**2) + is the second. + + TODO: at the test point 10 in float32, there is instability in the original + value. The original gives -30.0, the stab -20.1 and in float64 -18.1. + Make it so that the test does not generate an error in that case! + + """ + if node.op != true_div: + return False + if not node.inputs[1].owner or node.inputs[1].owner.op != erfc: + return False + + erfc_in = node.inputs[1] + erfc_x = erfc_in.owner.inputs[0] + + if not node.inputs[0].owner: + return False + + # TODO: All of this should be replaced with a single, simple unification + # The mul is optional. + if node.inputs[0].owner.op != mul: + mul_in = None + y = [] + if not node.inputs[0].owner or node.inputs[0].owner.op != exp: + return False + exp_in = node.inputs[0] + else: + mul_in = node.inputs[0] + exp_in = None + for idx, inp in enumerate(mul_in.owner.inputs): + if inp.owner and inp.owner.op == exp: + exp_in = inp + break + else: + return False + + if len(mul_in.owner.inputs) == 2: + y = [mul_in.owner.inputs[1 - idx]] + else: + y = mul_in.owner.inputs[:] + del y[idx] + + if not exp_in.owner.inputs[0].owner: + return False + + if exp_in.owner.inputs[0].owner.op == neg: + neg_in = exp_in.owner.inputs[0] + if not neg_in.owner.inputs[0].owner or neg_in.owner.inputs[0].owner.op != sqr: + return False + sqr_in = neg_in.owner.inputs[0] + x = sqr_in.owner.inputs[0] + elif exp_in.owner.inputs[0].owner.op == mul: + # We should compare that -(erfc_x**2) is equivalent to mul_neg. + # There is currently no easy way to do this in the general case, + # so we implement some common case for now. + + # In many cases the neg are replaced by mul in the graph. + # This also allows to stabilize log(erfc(cst*x)). + mul_neg = exp_in.owner.inputs[0] + + # In case that multiple mul are not fused together, we do it here. + def check_input(inputs): + new_inputs = [] + for i in inputs: + if i.owner and i.owner.op == mul: + new_inputs.extend(check_input(i.owner.inputs)) + else: + new_inputs.append(i) + return new_inputs + + mul_inputs = check_input(mul_neg.owner.inputs) + + # Put the constant first. + for i in range(len(mul_inputs)): + if isinstance(i, Constant): + if i == 0: + break + else: + tmp = mul_inputs[0] + mul_inputs[0] = mul_inputs[i] + mul_inputs[i] = tmp + break + mul_neg = mul(*mul_inputs) + + try: + cst2 = get_scalar_constant_value( + mul_neg.owner.inputs[0], only_process_constants=True + ) + except NotScalarConstantError: + return False + + if len(mul_neg.owner.inputs) == 2: + if ( + not mul_neg.owner.inputs[1].owner + or mul_neg.owner.inputs[1].owner.op != sqr + ): + return False + sqr_in = mul_neg.owner.inputs[1] + x = sqr_in.owner.inputs[0] + elif len(mul_neg.owner.inputs) == 3: + if mul_neg.owner.inputs[1] is not mul_neg.owner.inputs[2]: + return False + x = mul_neg.owner.inputs[1] + else: + return False + + if cst2 != -1: + if ( + not erfc_x.owner + or erfc_x.owner.op != mul + or len(erfc_x.owner.inputs) != 2 + ): + # todo implement that case + return False + if erfc_x.owner.inputs[1] is not mul_neg.owner.inputs[1]: + return False + + x = erfc_x + try: + cst = get_scalar_constant_value( + erfc_x.owner.inputs[0], only_process_constants=True + ) + except NotScalarConstantError: + return False + if cst2 != -cst * 2: + return False + + # The constant is valid. Must check that the + elif erfc_x is not x: + return False + + else: + return False + + if hasattr(node.tag, "local_grad_log_erfc_neg"): + # We use that flag to don't apply the rewrite recursively + # TODO FIXME: We shouldn't need to use tags for this. + return False + + if erfc_x is not x: + return None + + # we move the y outside the div. + true_div_no_mul = true_div(exp_in, erfc_in) + true_div_no_mul.owner.tag.local_grad_log_erfc_neg = True + + # aaron value + stab_value = ( + x + * at_pow(1 - 1 / (2 * (x**2)) + 3 / (4 * (x**4)) - 15 / (8 * (x**6)), -1) + * cast(sqrt(np.pi), dtype=x.dtype) + ) + + if x.dtype == "float32" or x.dtype == "float16": + threshold = 9.3 + # threshold = 10.1 + elif x.dtype == "float64": + threshold = 26.641747557 + + ret = switch(x < threshold, true_div_no_mul, stab_value) + + if y: + ret = mul(ret, *y) + + ret.tag.values_eq_approx = values_eq_approx_remove_inf_nan + + return [ret] + + +def local_add_mul_fusion(fgraph, node): + """Fuse consecutive add or mul in one such node with more inputs. + + It is better to fuse add/mul that way then in a Composite node as + this make the inner graph of the Composite smaller. This allow to + put more computation in a Composite before hitting the max + recursion limit when pickling Composite. + + """ + if not isinstance(node.op, Elemwise) or not isinstance( + node.op.scalar_op, (aes.Add, aes.Mul) + ): + return False + + s_op = node.op.scalar_op.__class__ + new_inp = [] + fused = False + nb_inputs = len(node.inputs) + max_inputs = float("inf") + if hasattr(node.op, "max_inputs"): + max_inputs = node.op.max_inputs(node) + for inp in node.inputs: + if ( + inp.owner + and isinstance(inp.owner.op, Elemwise) + and isinstance(inp.owner.op.scalar_op, s_op) + and + # Do not duplicate the operation. + len(fgraph.clients[inp]) == 1 + and (nb_inputs + len(inp.owner.inputs) - 1) <= max_inputs + ): + new_inp.extend(inp.owner.inputs) + fused = True + else: + new_inp.append(inp) + + # We can not compare the number of inputs as Mul and Add could have + # 0 or 1 inputs in some corner cases. + if fused: + output = node.op(*new_inp) + copy_stack_trace(node.outputs[0], output) + + # Do the recursion here to help lower the number of + # FusionOptimizer iteration. + if output.owner: + output2 = local_add_mul_fusion(fgraph, output.owner) + if output2: + return output2 + return [output] + + +fuse_seqopt.register( + "local_add_mul_fusion", + FusionOptimizer(local_add_mul_fusion), + "fast_run", + "fusion", + position=0, +) + + +def _skip_mul_1(r): + if r.owner and r.owner.op == mul: + not_is_1 = [i for i in r.owner.inputs if not _is_1(i)] + if len(not_is_1) == 1: + return not_is_1[0] + + +def _is_1(expr): + """ + + Returns + ------- + bool + True iff expr is a constant close to 1. + + """ + try: + v = get_scalar_constant_value(expr) + return np.allclose(v, 1) + except NotScalarConstantError: + return False + + +logsigm_to_softplus = PatternNodeRewriter( + (log, (sigmoid, "x")), + (neg, (softplus, (neg, "x"))), + allow_multiple_clients=True, + values_eq_approx=values_eq_approx_remove_inf, + skip_identities_fn=_skip_mul_1, + tracks=[sigmoid], + get_nodes=get_clients_at_depth1, +) +log1msigm_to_softplus = PatternNodeRewriter( + (log, (sub, dict(pattern="y", constraint=_is_1), (sigmoid, "x"))), + (neg, (softplus, "x")), + allow_multiple_clients=True, + values_eq_approx=values_eq_approx_remove_inf, + skip_identities_fn=_skip_mul_1, + tracks=[sigmoid], + get_nodes=get_clients_at_depth2, +) +log1pexp_to_softplus = PatternNodeRewriter( + (log1p, (exp, "x")), + (softplus, "x"), + values_eq_approx=values_eq_approx_remove_inf, + allow_multiple_clients=True, +) +log1p_neg_sigmoid = PatternNodeRewriter( + (log1p, (neg, (sigmoid, "x"))), + (neg, (softplus, "x")), + values_eq_approx=values_eq_approx_remove_inf, + allow_multiple_clients=True, + tracks=[sigmoid], + get_nodes=get_clients_at_depth2, +) + +register_stabilize(logsigm_to_softplus, name="logsigm_to_softplus") +register_stabilize(log1msigm_to_softplus, name="log1msigm_to_softplus") +register_stabilize(log1pexp_to_softplus, name="log1pexp_to_softplus") +register_stabilize(log1p_neg_sigmoid, name="log1p_neg_sigmoid") +register_specialize(log1p_neg_sigmoid, name="log1p_neg_sigmoid") + + +def is_1pexp(t, only_process_constants=True): + """ + + Returns + ------- + object + If 't' is of the form (1+exp(x)), return (False, x). + Else return None. + + """ + if t.owner and t.owner.op == add: + scalars, scalar_inputs, nonconsts = scalarconsts_rest( + t.owner.inputs, only_process_constants=only_process_constants + ) + # scalar_inputs are potentially dimshuffled and filled with scalars + if len(nonconsts) == 1: + maybe_exp = nonconsts[0] + if maybe_exp.owner and maybe_exp.owner.op == exp: + # Verify that the constant terms sum to 1. + if scalars: + scal_sum = scalars[0] + for s in scalars[1:]: + scal_sum = scal_sum + s + if np.allclose(scal_sum, 1): + return False, maybe_exp.owner.inputs[0] + return None + + +def is_exp(var): + """ + Match a variable with either of the `exp(x)` or `-exp(x)` patterns. + + Parameters + ---------- + var + The Variable to analyze. + + Returns + ------- + tuple + A pair (b, x) with `b` a boolean set to True if `var` is of the + form `-exp(x)` and False if `var` is of the form `exp(x)`. If `var` + cannot be cast into either form, then return `None`. + + """ + _neg = False + neg_info = is_neg(var) + if neg_info is not None: + _neg = True + var = neg_info + if var.owner and var.owner.op == exp: + return _neg, var.owner.inputs[0] + + +def is_mul(var): + """ + Match a variable with `x * y * z * ...`. + + Parameters + ---------- + var + The Variable to analyze. + + Returns + ------- + object + A list [x, y, z, ...] if `var` is of the form `x * y * z * ...`, + or None if `var` cannot be cast into this form. + + """ + if var.owner and var.owner.op == mul: + return var.owner.inputs + else: + return None + + +def partition_num_or_denom(r, f): + if r.owner and r.owner.op == mul: + a = r.owner.inputs + else: + a = [r] + + # ugly 2.4-compatible thing + f_terms = [] + _neg = False + rest = [] + for t in a: + f_t = f(t) + if f_t is None: + rest.append(t) + else: + neg_t, f_t = f_t + f_terms.append(f_t) + _neg ^= neg_t # bit flip if neg_t is true + return f_terms, rest, _neg + + +def is_neg(var): + """ + Match a variable with the `-x` pattern. + + Parameters + ---------- + var + The Variable to analyze. + + Returns + ------- + object + `x` if `var` is of the form `-x`, or None otherwise. + + """ + var_node = var.owner + if not var_node: + return None + # First match against `neg`. + if var_node.op == neg: + return var_node.inputs[0] + # Then match against a multiplication by -1. + if var_node.op == mul and len(var_node.inputs) >= 2: + for idx, mul_input in enumerate(var_node.inputs): + try: + constant = get_scalar_constant_value(mul_input) + is_minus_1 = np.allclose(constant, -1) + except NotScalarConstantError: + is_minus_1 = False + if is_minus_1: + # Found a multiplication by -1. + if len(var_node.inputs) == 2: + # Only return the other input. + return var_node.inputs[1 - idx] + else: + # Return the multiplication of all other inputs. + return mul(*(var_node.inputs[0:idx] + var_node.inputs[idx + 1 :])) + # No match. + return None + + +@register_stabilize +@node_rewriter([true_div]) +def local_exp_over_1_plus_exp(fgraph, node): + """ + + exp(x)/(1+exp(x)) -> sigm(x) + c/(1+exp(x)) -> c*sigm(-x) + + """ + # This rewrite should be done for numerical stability + # so we don't care to check client counts + if node.op == true_div: + + # find all the exp() terms in the numerator + num, denom = node.inputs + num_exp_x, num_rest, num_neg = partition_num_or_denom(num, is_exp) + denom_1pexp, denom_rest, denom_neg = partition_num_or_denom(denom, is_1pexp) + + sigmoids = [] + for t in denom_1pexp: + if t in num_exp_x: + # case: exp(x) /(1+exp(x)) + sigmoids.append(sigmoid(t)) + del num_exp_x[num_exp_x.index(t)] + else: + # case: 1/(1+exp(x)) + sigmoids.append(sigmoid(-t)) + copy_stack_trace(node.outputs[0], sigmoids[-1]) + + if not sigmoids: # we didn't find any. abort + return + # put the new numerator together + new_num = sigmoids + [exp(t) for t in num_exp_x] + num_rest + if len(new_num) == 1: + new_num = new_num[0] + else: + new_num = mul(*new_num) + + if num_neg ^ denom_neg: + new_num = -new_num + + copy_stack_trace(num, new_num) + + if len(denom_rest) == 0: + return [new_num] + elif len(denom_rest) == 1: + out = new_num / denom_rest[0] + else: + out = new_num / mul(*denom_rest) + + copy_stack_trace(node.outputs[0], out) + return [out] + + +def parse_mul_tree(root): + """ + Parse a tree of multiplications starting at the given root. + + Parameters + ---------- + root + The variable at the root of the tree. + + Returns + ------- + object + A tree where each non-leaf node corresponds to a multiplication + in the computation of `root`, represented by the list of its inputs. + Each input is a pair [n, x] with `n` a boolean value indicating whether + sub-tree `x` should be negated. + + Examples + -------- + + .. code-block:: python + + x * y -> [False, [[False, x], [False, y]]] + -(x * y) -> [True, [[False, x], [False, y]]] + -x * y -> [False, [[True, x], [False, y]]] + -x -> [True, x] + (x * y) * -z -> [False, [[False, [[False, x], [False, y]]], + [True, z]]] + + """ + # Is it a multiplication? + mul_info = is_mul(root) + if mul_info is None: + # Is it a negation? + neg_info = is_neg(root) + if neg_info is None: + # Keep the root "as is". + return [False, root] + else: + # Recurse, inverting the negation. + neg, sub_tree = parse_mul_tree(neg_info) + return [not neg, sub_tree] + else: + # Recurse into inputs. + return [False, list(map(parse_mul_tree, mul_info))] + + +def replace_leaf(arg, leaves, new_leaves, op, neg): + """ + Attempt to replace a leaf of a multiplication tree. + + We search for a leaf in `leaves` whose argument is `arg`, and if we find + one, we remove it from `leaves` and add to `new_leaves` a leaf with + argument `arg` and variable `op(arg)`. + + Parameters + ---------- + arg + The argument of the leaf we are looking for. + leaves + List of leaves to look into. Each leaf should be a pair + (x, l) with `x` the argument of the Op found in the leaf, and `l` the + actual leaf as found in a multiplication tree output by `parse_mul_tree` + (i.e. a pair [boolean, variable]). + new_leaves + If a replacement occurred, then the leaf is removed from `leaves` + and added to the list `new_leaves` (after being modified by `op`). + op + A function that, when applied to `arg`, returns the Variable + we want to replace the original leaf variable with. + neg : bool + If True, then the boolean value associated to the leaf should + be swapped. If False, then this value should remain unchanged. + + Returns + ------- + bool + True if a replacement occurred, or False otherwise. + + """ + for idx, x in enumerate(leaves): + if x[0] == arg: + x[1][0] ^= neg + x[1][1] = op(arg) + leaves.pop(idx) + new_leaves.append(x) + return True + return False + + +def simplify_mul(tree): + """ + Simplify a multiplication tree. + + Parameters + ---------- + tree + A multiplication tree (as output by `parse_mul_tree`). + + Returns + ------- + object + A multiplication tree computing the same output as `tree` but without + useless multiplications by 1 nor -1 (identified by leaves of the form + [False, None] or [True, None] respectively). Useless multiplications + (with less than two inputs) are also removed from the tree. + + """ + neg, inputs = tree + if isinstance(inputs, list): + # Recurse through inputs. + s_inputs = [] + for s_i in map(simplify_mul, inputs): + if s_i[1] is None: + # Multiplication by +/-1. + neg ^= s_i[0] + else: + s_inputs.append(s_i) + if not s_inputs: + # The multiplication is empty. + rval = [neg, None] + elif len(s_inputs) == 1: + # The multiplication has a single input. + s_inputs[0][0] ^= neg + rval = s_inputs[0] + else: + rval = [neg, s_inputs] + else: + rval = tree + # print 'simplify_mul: %s -> %s' % (tree, rval) + return rval + + +def compute_mul(tree): + """ + Compute the Variable that is the output of a multiplication tree. + + This is the inverse of the operation performed by `parse_mul_tree`, i.e. + compute_mul(parse_mul_tree(tree)) == tree. + + Parameters + ---------- + tree + A multiplication tree (as output by `parse_mul_tree`). + + Returns + ------- + object + A Variable that computes the multiplication represented by the tree. + + """ + neg, inputs = tree + if inputs is None: + raise AssertionError( + "Function `compute_mul` found a missing leaf, did you forget to " + "call `simplify_mul` on the tree first?" + ) + elif isinstance(inputs, list): + # Recurse through inputs. + rval = mul(*list(map(compute_mul, inputs))) + else: + rval = inputs + if neg: + rval = -rval + return rval + + +def perform_sigm_times_exp( + tree, + exp_x=None, + exp_minus_x=None, + sigm_x=None, + sigm_minus_x=None, + parent=None, + child_idx=None, + full_tree=None, +): + """ + Core processing of the `local_sigm_times_exp` rewrite. + + This recursive function operates on a multiplication tree as output by + `parse_mul_tree`. It walks through the tree and modifies it in-place + by replacing matching pairs (exp, sigmoid) with the desired version. + + Parameters + ---------- + tree + The sub-tree to operate on. + exp_x + List of arguments ``x`` so that ``exp(x)`` exists somewhere in the whole + multiplication tree. Each argument is a pair ``(x, leaf)`` with ``x`` the + argument of the exponential, and ``leaf`` the corresponding leaf in the + multiplication tree (of the form ``[n, exp(x)]`` -- see `parse_mul_tree`). + If ``None``, this argument is initialized to an empty list. + exp_minus_x + Similar to `exp_x`, but for ``exp(-x)``. + sigm_x + Similar to `exp_x`, but for ``sigmoid(x)``. + sigm_minus_x + Similar to `exp_x`, but for ``sigmoid(-x)``. + parent + Parent of `tree` (``None`` if `tree` is the global root). + child_idx + Index of `tree` in its parent's inputs (``None`` if `tree` is the global + root). + full_tree + The global multiplication tree (should not be set except by recursive + calls to this function). Used for debugging only. + + Returns + ------- + bool + ``True`` if a modification was performed somewhere in the whole multiplication + tree, or ``False`` otherwise. + + """ + if exp_x is None: + exp_x = [] + if exp_minus_x is None: + exp_minus_x = [] + if sigm_x is None: + sigm_x = [] + if sigm_minus_x is None: + sigm_minus_x = [] + if full_tree is None: + full_tree = tree + if False: # Debug code. + print("") + print(f" full_tree = {full_tree}") + print(f" tree = {tree}") + print(f" exp_x = {exp_x}") + print(f" exp_minus_x = {exp_minus_x}") + print(f" sigm_x = {sigm_x}") + print(f" sigm_minus_x= {sigm_minus_x}") + neg, inputs = tree + if isinstance(inputs, list): + # Recurse through inputs of the multiplication. + rval = False + for sub_idx, sub_tree in enumerate(inputs): + rval |= perform_sigm_times_exp( + tree=sub_tree, + parent=tree, + child_idx=sub_idx, + exp_x=exp_x, + exp_minus_x=exp_minus_x, + sigm_x=sigm_x, + sigm_minus_x=sigm_minus_x, + full_tree=full_tree, + ) + return rval + else: + # Reached a leaf: if it is an exponential or a sigmoid, then we + # first attempt to find a match in leaves already visited. + # If there is such a match, we modify the already-visited leaf + # accordingly: for instance if we visited a leaf sigmoid(x), then + # find later a -exp(-x), we replace the previous leaf by + # -sigmoid(-x) and remove the -exp(-x) from the tree. + # If no match is found, then we register this leaf so that it can + # be found later while walking the tree. + var = inputs + keep_it = False + exp_info = is_exp(var) + if exp_info is not None: + exp_neg, exp_arg = exp_info + neg ^= exp_neg + neg_arg = is_neg(exp_arg) + if neg_arg is None: + if not replace_leaf(exp_arg, sigm_minus_x, sigm_x, sigmoid, neg): + exp_x.append((exp_arg, tree)) + keep_it = True + else: + if not replace_leaf( + neg_arg, sigm_x, sigm_minus_x, lambda x: sigmoid(-x), neg + ): + exp_minus_x.append((neg_arg, tree)) + keep_it = True + elif var.owner and var.owner.op == sigmoid: + sigm_arg = var.owner.inputs[0] + neg_arg = is_neg(sigm_arg) + if neg_arg is None: + if not replace_leaf( + sigm_arg, exp_minus_x, sigm_minus_x, lambda x: sigmoid(-x), neg + ): + sigm_x.append((sigm_arg, tree)) + keep_it = True + else: + if not replace_leaf(neg_arg, exp_x, sigm_x, sigmoid, neg): + sigm_minus_x.append((neg_arg, tree)) + keep_it = True + else: + # It is not an exponential nor a sigmoid. + keep_it = True + if not keep_it: + # Delete this leaf, i.e. replace it by [False, None] (corresponding + # to a multiplication by 1). + assert parent is not None + parent[1][child_idx] = [False, None] + return not keep_it + + +@register_stabilize +@node_rewriter([mul]) +def local_sigm_times_exp(fgraph, node): + """ + exp(x) * sigm(-x) -> sigm(x) + exp(-x) * sigm(x) -> sigm(-x) + + todo: add stack traces to the intermediate variables + """ + # Bail early if it is not a multiplication. + if node.op != mul: + return None + # Obtain tree of multiplications starting at this node. + mul_tree = parse_mul_tree(node.outputs[0]) + did_something = perform_sigm_times_exp(mul_tree) + if not did_something: + # No change. + return None + # The rewrite may have introduced multiplications by 1 in the tree: + # get rid of them. + mul_tree = simplify_mul(mul_tree) + # Recompute final output based on the updated tree. + out = compute_mul(mul_tree) + # keep the stack trace + copy_stack_trace(node.outputs[0], out) + return [out] + + +@register_stabilize +@node_rewriter([reciprocal]) +def local_reciprocal_1_plus_exp(fgraph, node): + """``reciprocal(1+exp(x)) -> sigm(-x)`` + + TODO: This is redundant; we can just decided on *one* canonical form + for division (e.g. either `true_div` or `reciprocal`) and have this + taken care of with existing rewrites. + """ + # This Rewrite should be done for numerical stability + # so we don't care to check client counts + if node.op == reciprocal: + reciprocal_arg = node.inputs[0] + if reciprocal_arg.owner and reciprocal_arg.owner.op == add: + scalars_, scalar_inputs, nonconsts = scalarconsts_rest( + reciprocal_arg.owner.inputs, only_process_constants=True + ) + # scalar_inputs are potentially dimshuffled and fill'd scalars + if len(nonconsts) == 1: + if nonconsts[0].owner and nonconsts[0].owner.op == exp: + if scalars_ and np.allclose(np.sum(scalars_), 1): + out = fill_chain( + sigmoid(neg(nonconsts[0].owner.inputs[0])), + scalar_inputs, + ) + # keep combined stack traces of + # exp(x): nonconsts[0], + # 1 + exp(x): reciprocal_arg, + # 1 / (1 + exp(x)): node.outputs[0] + copy_stack_trace( + [nonconsts[0], reciprocal_arg, node.outputs[0]], out + ) + return out + + +# 1 - sigmoid(x) -> sigmoid(-x) +local_1msigmoid = PatternNodeRewriter( + (sub, dict(pattern="y", constraint=_is_1), (sigmoid, "x")), + (sigmoid, (neg, "x")), + tracks=[sigmoid], + get_nodes=get_clients_at_depth1, + name="local_1msigmoid", +) +register_stabilize(local_1msigmoid) +register_specialize(local_1msigmoid) + + +log1pmexp_to_log1mexp = PatternNodeRewriter( + (log1p, (neg, (exp, "x"))), + (log1mexp, "x"), + allow_multiple_clients=True, +) +register_stabilize(log1pmexp_to_log1mexp, name="log1pmexp_to_log1mexp") + + +# log(sigmoid(x) / (1 - sigmoid(x))) -> x +# i.e logit(sigmoid(x)) -> x +local_logit_sigmoid = PatternNodeRewriter( + (log, (true_div, (sigmoid, "x"), (sub, 1, (sigmoid, "x")))), + "x", + tracks=[sigmoid], + get_nodes=get_clients_at_depth2, + allow_multiple_clients=True, + name="local_logit_sigmoid", +) +register_canonicalize(local_logit_sigmoid) +register_specialize(local_logit_sigmoid) + + +# sigmoid(log(x / (1-x)) -> x +# i.e., sigmoid(logit(x)) -> x +local_sigmoid_logit = PatternNodeRewriter( + (sigmoid, (log, (true_div, "x", (sub, 1, "x")))), + "x", + allow_multiple_clients=True, + name="local_sigmoid_logit", +) +register_canonicalize(local_sigmoid_logit) +register_specialize(local_sigmoid_logit) + + +@register_canonicalize +@register_useless +@node_rewriter([_conj]) +def local_useless_conj(fgraph, node): + r"""Remove `conj` `Op`\s applied to non-imaginary variable types.""" + x = node.inputs[0] + if x.type.dtype not in complex_dtypes: + return [x] diff --git a/aesara/tensor/rewriting/shape.py b/aesara/tensor/rewriting/shape.py new file mode 100644 index 0000000000..a3b30177f0 --- /dev/null +++ b/aesara/tensor/rewriting/shape.py @@ -0,0 +1,1188 @@ +import traceback +from io import StringIO +from typing import Optional +from typing import cast as type_cast +from warnings import warn + +import numpy as np + +import aesara +from aesara.configdefaults import config +from aesara.graph.basic import Constant, Variable, ancestors, equal_computations +from aesara.graph.features import AlreadyThere, Feature +from aesara.graph.fg import FunctionGraph +from aesara.graph.rewriting.basic import ( + GraphRewriter, + check_chain, + copy_stack_trace, + node_rewriter, +) +from aesara.graph.utils import InconsistencyError, get_variable_trace_string +from aesara.tensor.basic import ( + MakeVector, + as_tensor_variable, + cast, + constant, + extract_constant, + get_scalar_constant_value, + stack, +) +from aesara.tensor.elemwise import DimShuffle, Elemwise +from aesara.tensor.exceptions import NotScalarConstantError, ShapeError +from aesara.tensor.rewriting.basic import ( + register_canonicalize, + register_specialize, + register_stabilize, + register_useless, + topo_constant_folding, +) +from aesara.tensor.shape import ( + Reshape, + Shape, + Shape_i, + SpecifyShape, + Unbroadcast, + shape_i, + specify_shape, + unbroadcast, +) +from aesara.tensor.subtensor import Subtensor, get_idx_list +from aesara.tensor.type import TensorType, discrete_dtypes, integer_dtypes +from aesara.tensor.type_other import NoneConst + + +class ShapeFeature(Feature): + r"""A `Feature` that tracks shape information in a graph. + + This `Feature` aids in the replacement of all `Shape`\s and `Subtensor`\s of `Shape`\s with + `Shape_i` and `MakeVector` `Op`\s. + + This `Feature` and its associated rewrites have several goals: + + 1. to "lift" `Shape`\s to as close to the inputs as possible, + 2. to infer the shape of every node in the graph in terms of the + input shapes, and + 3. remove fill `Op`\s (e.g. `Second`) from the graph. + + Lifting shapes as close to the inputs as possible is important for + canonicalization because it is very bad form to have to compute + something just to know how big it will be. Firstly, it is a waste + of time to compute such outputs. But it is important to get rid + of these outputs as early as possible in the compilation process + because the extra computations make it appear as if many internal + graph nodes have multiple clients. Many rewrites refuse to + work on nodes with multiple clients. + + Lifting is done by using an `.infer_shape` function if one is + present, or else using a conservative default. An Op that + supports shape-lifting should define a infer_shape(self, fgraph, node, + input_shapes) function. The argument input_shapes is a tuple of + tuples... there is an interior tuple for each input to the node. + The tuple has as many elements as dimensions. The element in + position i of tuple j represents the i'th shape component of the + j'th input. The function should return a tuple of tuples. One + output tuple for each node.output. Again, the i'th element of the + j'th output tuple represents the output[j].shape[i] of the + function. If an output is not a TensorType, then None should be + returned instead of a tuple for that output. + + For example the infer_shape for a matrix-matrix product would accept + input_shapes=((x0,x1), (y0,y1)) and return ((x0, y1),). + + Inferring the shape of internal nodes in the graph is important + for doing size-driven rewrites. If we know how big various + intermediate results will be, we can estimate the cost of many Ops + accurately, and generate c-code that is specific [e.g. unrolled] + to particular sizes. + + In cases where you cannot figure out the shape, raise a ShapeError. + + Notes + ----- + Right now there is only the ConvOp that could really take + advantage of this shape inference, but it is worth it even + just for the ConvOp. All that's necessary to do shape + inference is 1) to mark shared inputs as having a particular + shape, either via a .tag or some similar hacking; and 2) to + add an optional In() argument to promise that inputs will + have a certain shape (or even to have certain shapes in + certain dimensions). + + We can't automatically infer the shape of shared variables as they can + change of shape during the execution by default. + + To use this shape information in rewrites, use the + ``shape_of`` dictionary. + + For example: + + .. code-block:: python + + try: + shape_of = fgraph.shape_feature.shape_of + except AttributeError: + # This can happen when the mode doesn't include the ShapeFeature. + return + + shape_of_output_zero = shape_of[node.output[0]] + + The ``shape_of_output_zero`` symbol will contain a tuple, whose + elements are either integers or symbolic integers. + + TODO: check to see if the symbols are necessarily + non-constant... or are integer literals sometimes Aesara + constants?? That would be confusing. + + """ + + def get_node_infer_shape(self, node): + try: + shape_infer = node.op.infer_shape + except AttributeError: + shape_infer = self.default_infer_shape + + try: + o_shapes = shape_infer( + self.fgraph, node, [self.shape_of[r] for r in node.inputs] + ) + except ShapeError: + o_shapes = self.default_infer_shape( + self.fgraph, node, [self.shape_of[r] for r in node.inputs] + ) + except NotImplementedError as e: + raise NotImplementedError( + "Code called by infer_shape failed raising a " + "NotImplementedError. Raising NotImplementedError to " + "indicate that a shape cannot be computed is no longer " + "supported, and one should now use ShapeError " + f"instead. The original exception message is: {e}" + ).with_traceback(e.__traceback__) + except Exception as e: + msg = ( + f"Failed to infer_shape from Op {node.op}.\nInput shapes: " + f"{[self.shape_of[r] for r in node.inputs]}\nException encountered during infer_shape: " + f"{type(e)}\nException message: {str(e)}\nTraceback: {traceback.format_exc()}" + ) + if config.on_shape_error == "raise": + raise Exception(msg).with_traceback(e.__traceback__) + else: + warn(msg) + o_shapes = self.default_infer_shape( + self.fgraph, node, [self.shape_of[r] for r in node.inputs] + ) + + return o_shapes + + def get_shape(self, var, idx): + """Rewrites can call this to get a `Shape_i`. + + It is better to call this then use directly ``shape_of[var][idx]`` + as this method should update `shape_of` if needed. + + TODO: Up to now, we don't update it in all cases. Update in all cases. + """ + r = self.shape_of[var][idx] + if ( + r.owner + and isinstance(r.owner.op, Shape_i) + and r.owner.inputs[0] not in self.fgraph.variables + ): + assert var.owner + node = var.owner + # recur on inputs + for i in node.inputs: + if getattr(i.type, "ndim", None) > 0: + self.get_shape(i, 0) + o_shapes = self.get_node_infer_shape(node) + assert len(o_shapes) == len(node.outputs) + + # Only change the variables and dimensions that would introduce + # extra computation + for new_shps, out in zip(o_shapes, node.outputs): + if not hasattr(out.type, "ndim"): + continue + + merged_shps = list(self.shape_of[out]) + changed = False + for i in range(out.type.ndim): + n_r = merged_shps[i] + if ( + n_r.owner + and isinstance(n_r.owner.op, Shape_i) + and n_r.owner.inputs[0] not in self.fgraph.variables + ): + changed = True + merged_shps[i] = new_shps[i] + if changed: + self.set_shape(out, merged_shps, override=True) + r = self.shape_of[var][idx] + return r + + def shape_ir(self, i, r): + """Return symbolic r.shape[i] for tensor variable r, int i.""" + if hasattr(r.type, "shape") and r.type.shape[i] is not None: + return constant(r.type.shape[i], dtype="int64") + else: + # Do not call make_node for test_value + s = Shape_i(i)(r) + try: + s = get_scalar_constant_value(s) + except NotScalarConstantError: + pass + return s + + def shape_tuple(self, r): + """Return a tuple of symbolic shape vars for tensor variable r.""" + if not hasattr(r.type, "ndim"): + # This happen for NoneConst. + return None + return tuple(self.shape_ir(i, r) for i in range(r.type.ndim)) + + def default_infer_shape(self, fgraph, node, i_shapes): + """Return a list of shape tuple or None for the outputs of node. + + This function is used for Ops that don't implement infer_shape. + Ops that do implement infer_shape should use the i_shapes parameter, + but this default implementation ignores it. + + """ + rval = [] + for r in node.outputs: + try: + rval.append(self.shape_tuple(r)) + except AttributeError: + rval.append(None) + return rval + + def unpack(self, s_i, var): + """Return a symbolic integer scalar for the shape element s_i. + + The s_i argument was produced by the infer_shape() of an Op subclass. + + var: the variable that correspond to s_i. This is just for + error reporting. + + """ + assert s_i is not None + + if s_i == 1: + return self.lscalar_one + if isinstance(s_i, float) and int(s_i) == s_i: + s_i = int(s_i) + if isinstance(s_i, (np.integer, int)) or ( + isinstance(s_i, np.ndarray) and s_i.ndim == 0 + ): + # this shape is a constant + if s_i < 0: + msg = "There is a negative shape in the graph!" + msg += get_variable_trace_string(var) + # The rest of the pipeline don't handle correctly this + # case. So we have 2 choices, stop compilation or + # consider the shape as unknown. As we have more + # chance to give the stack trace here then later, I + # choose that options as it would give better error + # message. + raise AssertionError(msg) + return constant(s_i, dtype="int64") + if isinstance(s_i, (tuple, list)): + # this dimension is the same as many of the inputs + # which tells us that if one of the inputs is known, + # the others all become known. + # TODO: should be implemented in Elemwise, and Dot + # + # worst case, we loop over shape_of and replace things + raise NotImplementedError(s_i) + + # s_i is x.shape[i] for some x, we change it to shape_of[x][i] + if ( + s_i.owner + and isinstance(s_i.owner.op, Subtensor) + and s_i.owner.inputs[0].owner + and isinstance(s_i.owner.inputs[0].owner.op, Shape) + ): + assert s_i.type.ndim == 0 + assert len(s_i.owner.op.idx_list) == 1 + + # The current Subtensor always put constant index in the graph. + # This was not True in the past. So call the Subtensor function + # that will return the right index. + idx = get_idx_list(s_i.owner.inputs, s_i.owner.op.idx_list) + assert len(idx) == 1 + idx = idx[0] + try: + i = get_scalar_constant_value(idx) + except NotScalarConstantError: + pass + else: + # Executed only if no exception was raised + x = s_i.owner.inputs[0].owner.inputs[0] + # x should already have been imported, and should be in shape_of. + s_i = self.shape_of[x][i] + + if s_i.type.dtype in integer_dtypes: + if getattr(s_i.type, "ndim", 0): + raise TypeError("Shape element must be scalar", s_i) + return s_i + else: + raise TypeError( + "Unsupported shape element", s_i, type(s_i), getattr(s_i, "type", None) + ) + + def set_shape(self, r, s, override=False): + """Assign the shape `s` to previously un-shaped variable `r`. + + Parameters + ---------- + r : a variable + s : None or a tuple of symbolic integers + override : If False, it mean r is a new object in the fgraph. + If True, it mean r is already in the fgraph and we want to + override its shape. + + """ + if not override: + assert r not in self.shape_of, "r already in shape_of" + if s is None: + self.shape_of[r] = s + else: + if not isinstance(s, (tuple, list)): + raise TypeError("shapes must be tuple/list", (r, s)) + + if r.type.ndim != len(s): + sio = StringIO() + aesara.printing.debugprint(r, file=sio, print_type=True) + raise AssertionError( + f"Something inferred a shape with {len(s)} dimensions " + f"for a variable with {int(r.type.ndim)} dimensions" + f" for the variable:\n{sio.getvalue()}" + ) + + shape_vars = [] + for i in range(r.type.ndim): + if hasattr(r.type, "shape") and r.type.shape[i] is not None: + shape_vars.append(constant(r.type.shape[i], dtype="int64")) + else: + shape_vars.append(self.unpack(s[i], r)) + assert all( + not hasattr(r.type, "broadcastable") + or not r.type.broadcastable[i] + or self.lscalar_one.equals(shape_vars[i]) + or self.lscalar_one.equals(extract_constant(shape_vars[i])) + for i in range(r.type.ndim) + ) + self.shape_of[r] = tuple(shape_vars) + for sv in shape_vars: + self.shape_of_reverse_index.setdefault(sv, set()).add(r) + + def update_shape(self, r, other_r): + """Replace shape of r by shape of other_r. + + If, on some dimensions, the shape of other_r is not informative, + keep the shape of r on those dimensions. + + """ + # other_r should already have a shape + assert other_r in self.shape_of, ("other_r not in shape_of", other_r) + other_shape = self.shape_of[other_r] + + # If other_shape has no information, call is pointless. + if other_shape is None: + return + + if r in self.shape_of: + r_shape = self.shape_of[r] + else: + # If no info is known on r's shape, use other_shape + self.set_shape(r, other_shape) + return + if ( + other_r.owner + and r.owner + and other_r.owner.inputs == r.owner.inputs + and other_r.owner.op == r.owner.op + ): + # We are doing a merge, so the two shape graphs will be the + # same. This is only done so that we call `ancestors` less + # frequently. + return + + # Merge other_shape with r_shape, giving the priority to other_shape + merged_shape = [] + for i, ps in enumerate(other_shape): + if r_shape is None and other_shape: + merged_shape.append(other_shape[i]) + elif ( + ps.owner + and isinstance(getattr(ps.owner, "op", None), Shape_i) + and ps.owner.op.i == i + and ps.owner.inputs[0] in (r, other_r) + ): + # If other_shape[i] is uninformative, use r_shape[i]. + # For now, we consider 2 cases of uninformative other_shape[i]: + # - Shape_i(i)(other_r); + # - Shape_i(i)(r). + merged_shape.append(r_shape[i]) + elif isinstance(r_shape[i], (Constant, int)): + # We do this to call less often ancestors and make + # sure we have the simplest shape possible. + merged_shape.append(r_shape[i]) + elif isinstance(other_shape[i], (Constant, int)): + # We do this to call less often ancestors and make + # sure we have the simplest shape possible. + merged_shape.append(other_shape[i]) + elif other_shape[i] == r_shape[i]: + # This mean the shape is equivalent + # We do not want to do the ancestor check in those cases + merged_shape.append(r_shape[i]) + elif r_shape[i] in ancestors([other_shape[i]]): + # Another case where we want to use r_shape[i] is when + # other_shape[i] actually depends on r_shape[i]. In that case, + # we do not want to substitute an expression with another that + # is strictly more complex. Such a substitution could also lead + # to cycles: if (in the future) r_shape[i] gets replaced by an + # expression of other_shape[i], other_shape[i] may end up + # depending on itself. + merged_shape.append(r_shape[i]) + else: + merged_shape.append(other_shape[i]) + assert all( + ( + not hasattr(r.type, "broadcastable") + or not r.type.broadcastable[i] + and not other_r.type.broadcastable[i] + ) + or self.lscalar_one.equals(merged_shape[i]) + or self.lscalar_one.equals( + extract_constant(merged_shape[i], only_process_constants=True) + ) + for i in range(r.type.ndim) + ) + self.shape_of[r] = tuple(merged_shape) + for sv in self.shape_of[r]: + self.shape_of_reverse_index.setdefault(sv, set()).add(r) + + def set_shape_i(self, r, i, s_i): + """Replace element i of shape_of[r] by s_i""" + assert r in self.shape_of + prev_shape = self.shape_of[r] + # prev_shape is a tuple, so we cannot change it inplace, + # so we build another one. + new_shape = [] + for j, s_j in enumerate(prev_shape): + if j == i: + new_shape.append(self.unpack(s_i, r)) + else: + new_shape.append(s_j) + assert all( + not hasattr(r.type, "broadcastable") + or not r.type.broadcastable[idx] + or self.lscalar_one.equals(new_shape[idx]) + or self.lscalar_one.equals(extract_constant(new_shape[idx])) + for idx in range(r.type.ndim) + ) + self.shape_of[r] = tuple(new_shape) + for sv in self.shape_of[r]: + self.shape_of_reverse_index.setdefault(sv, set()).add(r) + + def init_r(self, r): + """Register r's shape in the shape_of dictionary.""" + if r not in self.shape_of: + self.set_shape(r, self.shape_tuple(r)) + + def make_vector_shape(self, r): + return as_tensor_variable(self.shape_of[r], ndim=1, dtype="int64") + + def on_attach(self, fgraph): + + if hasattr(fgraph, "shape_feature"): + raise AlreadyThere("This FunctionGraph already has a ShapeFeature") + + if hasattr(self, "fgraph") and self.fgraph != fgraph: + raise Exception("This ShapeFeature is already attached to a graph") + + self.fgraph = fgraph + + fgraph.shape_feature = self + # Must be local to the object as otherwise we reuse the same + # variable for multiple fgraph! + self.lscalar_one = constant(1, dtype="int64") + assert self.lscalar_one.type.dtype == "int64" + + self.fgraph = fgraph + # Variable -> tuple(scalars) or None (All tensor vars map to tuple) + self.shape_of = {} + # Variable -> + self.scheduled = {} + # shape var -> graph v + self.shape_of_reverse_index = {} + + for node in fgraph.toposort(): + self.on_import(fgraph, node, reason="on_attach") + + def on_detach(self, fgraph): + self.shape_of = {} + self.scheduled = {} + self.shape_of_reverse_index = {} + self.fgraph = None + del fgraph.shape_feature + + def on_import(self, fgraph, node, reason): + if node.outputs[0] in self.shape_of: + # this is a revert, not really an import + for r in node.outputs + node.inputs: + assert r in self.shape_of + return + + for i, r in enumerate(node.inputs): + # make sure we have shapes for the inputs + self.init_r(r) + + o_shapes = self.get_node_infer_shape(node) + + # this is packed information + # an element of o_shapes is either None or a tuple + # elements of the tuple can be either strings, or ints + if len(o_shapes) != len(node.outputs): + raise Exception( + ( + f'The infer_shape method for the Op "{node.op}" returned a list ' + f"with the wrong number of element: len(o_shapes) = {len(o_shapes)} " + f" != len(node.outputs) = {len(node.outputs)}" + ) + ) + + # Ensure shapes are in 'int64'. This is to make sure the assert + # found in the `local_useless_subtensor` rewrite does not fail. + for sh_idx, sh in enumerate(o_shapes): + if sh is None: + continue + if not isinstance(sh, (list, tuple)): + raise ValueError( + f"infer_shape of {node} didn't return a list of" + f" list. It returned '{o_shapes}'" + ) + new_shape = [] + for i, d in enumerate(sh): + # Note: we ignore any shape element that is not typed (i.e., + # does not have a 'dtype' attribute). This means there may + # still remain int elements that are int32 on 32-bit platforms, + # but this works with `local_useless_subtensor`, so for now we + # keep it this way. See #266 for a better long-term fix. + if getattr(d, "dtype", "int64") != "int64": + assert d.dtype in discrete_dtypes, (node, d.dtype) + assert str(d.dtype) != "uint64", node + new_shape += sh[len(new_shape) : i + 1] + if isinstance(d, Constant): + casted_d = constant(d.data, dtype="int64") + else: + casted_d = cast(d, "int64") + new_shape[i] = casted_d + if new_shape: + # We replace the shape with wrong dtype by the one with + # 'int64'. + new_shape += sh[len(new_shape) :] + o_shapes[sh_idx] = tuple(new_shape) + + for r, s in zip(node.outputs, o_shapes): + self.set_shape(r, s) + + def on_change_input(self, fgraph, node, i, r, new_r, reason): + if new_r not in self.shape_of: + # It happen that the fgraph didn't called on_import for some + # new_r. This happen when new_r don't have an + # owner(i.e. it is a constant or an input of the graph) + # update_shape suppose that r and new_r are in shape_of. + self.init_r(new_r) + + # This tells us that r and new_r must have the same shape if + # we didn't know that the shapes are related, now we do. + self.update_shape(new_r, r) + + # change_input happens in two cases: + # 1) we are trying to get rid of r, or + # 2) we are putting things back after a failed transaction. + + # In case 1, if r has a shape_i client, we will want to + # replace the shape_i of r with the shape of new_r. Say that + # r is *scheduled*. + # At that point, node is no longer a client of r, but of new_r + for (shpnode, idx) in fgraph.clients[r] + [(node, i)]: + if isinstance(getattr(shpnode, "op", None), Shape_i): + idx = shpnode.op.i + repl = self.shape_of[new_r][idx] + if repl.owner is shpnode: + # This mean the replacement shape object is + # exactly the same as the current shape object. So + # no need for replacement. + continue + if ( + repl.owner + and repl.owner.inputs[0] is shpnode.inputs[0] + and isinstance(repl.owner.op, Shape_i) + and repl.owner.op.i == shpnode.op.i + ): + # The replacement is a shape_i of the same + # input. So no need to do this equivalent + # replacement. + continue + + if shpnode.outputs[0] in ancestors([repl]): + raise InconsistencyError( + "This substitution would insert a cycle in the graph:" + f"node: {node}, i: {i}, r: {r}, new_r: {new_r}" + ) + + self.scheduled[shpnode] = new_r + # In case 2, if r is a variable that we've scheduled for shape update, + # then we should cancel it. + unscheduled = [k for k, v in self.scheduled.items() if v == r] + for k in unscheduled: + del self.scheduled[k] + + # In either case, r could be in shape_of.values(), that is, r itself + # is the shape of something. In that case, we want to update + # the value in shape_of, to keep it up-to-date. + for v in self.shape_of_reverse_index.get(r, []): + # The reverse index is only approximate. It is not updated on + # deletion of variables, or on change_input so it might be the + # case that there are a few extra `v`'s in it that no longer have + # a shape of r or possibly have been deleted from shape_of + # entirely. The important thing is that it permits to recall + # all variables with r in their shape. + for ii, svi in enumerate(self.shape_of.get(v, [])): + if svi == r: + self.set_shape_i(v, ii, new_r) + self.shape_of_reverse_index[r] = set() + + def same_shape( + self, + x: Variable, + y: Variable, + dim_x: Optional[int] = None, + dim_y: Optional[int] = None, + ) -> bool: + """Return ``True`` if `x` and `y` have the same shape. + + Parameters + ========== + x + The `Variable` for which its shape is to be compared with `y`'s shape. + y + The `Variable` for which its shape is to be compared with `x`'s shape. + dim_x + If non ``None``, compare only the dimension of `x` equal to + `dim_x`. + dim_y + If non ``None``, compare only the dimension of `y` equal to + `dim_y`. + + """ + sx = self.shape_of[x] + sy = self.shape_of[y] + + if sx is None or sy is None: + return False + + if dim_x is not None: + sx = [sx[dim_x]] + + if dim_y is not None: + sy = [sy[dim_y]] + + if len(sx) != len(sy): + return False + + # Canonicalize the graphs so that comparisons are reasonable + # TODO FIXME: This should *not* need to be performed manually here. + # Instead, the shape information in `self.shape_of` should be operated + # upon alongside all the other elements in a `FunctionGraph` (e.g. as + # if `self.shape_of.values()` were additional outputs). + shapes_fg = FunctionGraph( + outputs=sx + sy, + # features=[self], + clone=True, + # copy_inputs=False, + ) + from aesara.graph.rewriting.utils import rewrite_graph + + canon_shapes_fg = type_cast( + FunctionGraph, + rewrite_graph(shapes_fg, custom_rewrite=topo_constant_folding), + ) + canon_shapes = canon_shapes_fg.outputs + + sx = canon_shapes[: len(sx)] + sy = canon_shapes[len(sx) :] + + for dx, dy in zip(sx, sy): + if not equal_computations([dx], [dy]): + return False + + return True + + def clone(self): + return type(self)() + + +class ShapeOptimizer(GraphRewriter): + """Rewriter that adds `ShapeFeature` as a feature.""" + + def add_requirements(self, fgraph): + fgraph.attach_feature(ShapeFeature()) + + def apply(self, fgraph): + pass + + +class UnShapeOptimizer(GraphRewriter): + """Rewriter that removes `ShapeFeature` as a feature.""" + + def apply(self, fgraph): + for feature in fgraph._features: + if isinstance(feature, ShapeFeature): + fgraph.remove_feature(feature) + + +# Register it after merge1 optimization at 0. We don't want to track +# the shape of merged node. +aesara.compile.mode.optdb.register( # type: ignore + "ShapeOpt", ShapeOptimizer(), "fast_run", "fast_compile", position=0.1 +) +# Not enabled by default for now. Some crossentropy opt use the +# shape_feature. They are at step 2.01. uncanonicalize is at step +# 3. After it goes to 48.5 that move to the gpu. So 10 seems reasonable. +aesara.compile.mode.optdb.register("UnShapeOpt", UnShapeOptimizer(), position=10) # type: ignore + + +def local_reshape_chain(op): + @node_rewriter([op]) + def f(fgraph, node): + """ + Reshape(Reshape(shape1),shape2) -> Reshape(shape2) + + """ + if not check_chain(node, op, op): + return False + + # TODO: this can permit a failing program to run by eliminating + # the lower reshape + rval = node.op(node.inputs[0].owner.inputs[0], node.inputs[1]) + + # Copy over stacktrace from previous output node, as any error + # in new computational graph would have been caused by last op + # in the old computational graph. + copy_stack_trace(node.outputs, rval) + + # It might happen that the desired output of this node has a + # broadcastable pattern that does not match that of 'rval'. This is + # when originally, we were able to figure out that one of the + # dimensions of the reshape is one, but some other transformation + # replaced the shape by one for which this cannot be guessed. + # We should try to figure out why we lost the information about this + # constant value... but in the meantime, better not apply this + # rewrite. + if rval.broadcastable == node.outputs[0].broadcastable: + return [rval] + else: + return False + + return f + + +register_canonicalize(local_reshape_chain(Reshape), name="local_reshape_chain") + + +@register_useless +@register_canonicalize +@register_stabilize +@node_rewriter([Reshape]) +def local_useless_reshape(fgraph, node): + """ + Remove two kinds of useless reshape. + + Remove Reshape when both the input and output have a single dimension. + Remove Reshape when reshaping to the shape of the input. + + """ + op = node.op + if not isinstance(op, Reshape): + return False + + inp = node.inputs[0] + output = node.outputs[0] + output_shape = node.inputs[1] + + if inp.ndim != output.ndim: + return False + + # Simple case: both input and output have a single dimension. + # This could hide errors if the user provides inconsistent shapes. + if inp.ndim == 1 and output.ndim == 1 and inp.broadcastable == output.broadcastable: + return [inp] + + # Second case: all the shapes match the input shape + # Match Reshape(x, x.shape) + if output_shape.owner and isinstance(output_shape.owner.op, Shape): + shape_input = output_shape.owner.inputs[0] + if shape_input == inp: + return [inp] + + # Match Reshape(x, [x.shape[0], ..., x.shape[-1]]), accounting for + # broadcastable and constant dimensions + if output_shape.owner and isinstance(output_shape.owner.op, MakeVector): + output_shape_is = output_shape.owner.inputs + + shape_feature = getattr(fgraph, "shape_feature", None) + + nb_m1 = 0 + shape_match = [False] * inp.ndim + for dim in range(inp.ndim): + outshp_i = output_shape_is[dim] + # Match Shape_i{dim}(input) + if ( + outshp_i.owner + and isinstance(outshp_i.owner.op, Shape_i) + and outshp_i.owner.op.i == dim + and outshp_i.owner.inputs[0] == inp + ): + shape_match[dim] = True + continue + + # Match Shape(input)[dim] + if ( + outshp_i.owner + and isinstance(outshp_i.owner.op, Subtensor) + and len(outshp_i.owner.inputs) == 2 + and extract_constant(outshp_i.owner.inputs[1]) == dim + ): + subtensor_inp = outshp_i.owner.inputs[0] + if subtensor_inp.owner and isinstance(subtensor_inp.owner.op, Shape): + shape_input_i = subtensor_inp.owner.inputs[0] + if shape_input_i == inp: + shape_match[dim] = True + continue + + # Match 1 if input.broadcastable[dim] is True + cst_outshp_i = extract_constant(outshp_i, only_process_constants=1) + if inp.broadcastable[dim] and cst_outshp_i == 1: + shape_match[dim] = True + continue + + # Match -1 + if cst_outshp_i == -1: + shape_match[dim] = True + nb_m1 += 1 + continue + + # Match shape_of[input][dim] or its constant equivalent + if shape_feature: + inpshp_i = shape_feature.get_shape(inp, dim) + if inpshp_i == outshp_i or ( + extract_constant(inpshp_i, only_process_constants=1) + == extract_constant(outshp_i, only_process_constants=1) + ): + shape_match[dim] = True + continue + + if all(shape_match) and nb_m1 <= 1: + return [inp] + + # TODO later: if all the shapes except one match, we may want to + # consider it useless as well, like we do in the 1-dim case. + return False + + +@register_canonicalize +@node_rewriter([Reshape]) +def local_reshape_to_dimshuffle(fgraph, node): + """ + Broadcastable dimensions in Reshape are replaced with dimshuffle. + + The goal is to avoid using reshape to add or remove broadcastable + dimensions, but use dimshuffle instead, so dimshuffles can cancel out + or be removed later on. + + For example: + - reshape(x, (1, n)) --> dimshuffle{x,0}(reshape(x, (n,)) + - reshape(x, (1, m, 1, n, 1, 1)) + --> dimshuffle{x,0,x,1,x,x}(reshape(x, (m, n))) + """ + op = node.op + if not isinstance(op, Reshape): + return False + + inp = node.inputs[0] + output = node.outputs[0] + output_shape = node.inputs[1] + + dimshuffle_new_order = [] + new_output_shape = [] + index = 0 # index over the output of the new reshape + for i in range(output.ndim): + # Since output_shape is a symbolic vector, we trust extract_constant + # to go through however it is formed to see if its i-th element is 1. + # We need only_process_constants=False for that. + dim = extract_constant( + output_shape[i], only_process_constants=False, elemwise=False + ) + if dim == 1: + dimshuffle_new_order.append("x") + else: + dimshuffle_new_order.append(index) + new_output_shape.append(dim) + index = index + 1 + if index != output.ndim: + inner = op.__class__(len(new_output_shape))(inp, new_output_shape) + copy_stack_trace(output, inner) + new_node = [DimShuffle(inner.type.broadcastable, dimshuffle_new_order)(inner)] + copy_stack_trace(output, new_node) + return new_node + + +@register_canonicalize +@register_stabilize +@node_rewriter([Reshape]) +def local_reshape_lift(fgraph, node): + """ + Reshape(UnaryElemwise(x)) -> UnaryElemwise(Reshape(x)) + + Notes + ----- + This rewrite is needed by `log1msigm_to_softplus` in order to get applied + when there is a reshape. + + """ + if ( + isinstance(node.op, Reshape) + and node.inputs[0].owner + and isinstance(node.inputs[0].owner.op, Elemwise) + and len(node.inputs[0].owner.inputs) == 1 + ): + r = node.op(node.inputs[0].owner.inputs[0], node.inputs[1]) + # Copy stacktrace from previous Reshape op, as an error in new + # Reshape op could only have been caused by old one. + copy_stack_trace(node.outputs, r) + + e = node.inputs[0].owner.op(r) + # Copy stacktrace from both previous Reshape and UnaryElemwise op + # because an error in new cg could have been caused by either ops. + copy_stack_trace(node.outputs + node.inputs, e) + return [e] + + +@register_useless +@register_canonicalize +@node_rewriter([SpecifyShape]) +def local_merge_consecutive_specify_shape(fgraph, node): + """Replace ``specify_shape(specify_shape(x, s1), s2)`` with ``specify_shape(x, s3)``, + where s3 is the union of specified dimensions in s1 and s2, with preference given to s2. + """ + + if not isinstance(node.op, SpecifyShape): + return False + + obj = node.inputs[0] + if not (obj.owner and isinstance(obj.owner.op, SpecifyShape)): + return False + + inner_obj, *shape = obj.owner.inputs + for dim, sh in enumerate(node.inputs[1:]): + if not NoneConst.equals(sh): + shape[dim] = sh + + # TODO: We could make sure that the overlapping shapes of the two `SpecifyShape`s are + # the same. + + return [specify_shape(inner_obj, shape)] + + +@register_useless +@register_canonicalize +@node_rewriter([Shape]) +def local_Shape_of_SpecifyShape(fgraph, node): + """Replace ``specify_shape(x, s).shape`` with ``s``.""" + + if not isinstance(node.op, Shape): + return False + + specified_shape = node.inputs[0] + + if not isinstance(getattr(specified_shape.owner, "op", None), SpecifyShape): + return False + + x, *shape = specified_shape.owner.inputs + + # Replace `NoneConst` by `shape_i` + for i, sh in enumerate(shape): + if NoneConst.equals(sh): + shape[i] = shape_i(x, i, fgraph) + + return [stack(shape).astype(np.int64)] + + +@register_useless +@register_canonicalize +@node_rewriter([Shape_i]) +def local_Shape_i_of_broadcastable(fgraph, node): + """Replace ``shape_i(x, i)`` with ``1`` when ``x.broadcastable[i]`` is ``True``.""" + + if not isinstance(node.op, Shape_i): + return False + + shape_arg = node.inputs[0] + + if not isinstance(shape_arg.type, TensorType): + return False + + if shape_arg.broadcastable[node.op.i]: + return [as_tensor_variable(1, dtype=np.int64)] + + +@register_specialize +@register_canonicalize +@node_rewriter([Shape]) +def local_shape_to_shape_i(fgraph, node): + if isinstance(node.op, Shape): + if not hasattr(fgraph, "shape_feature"): + return + shape_feature = fgraph.shape_feature + ret = shape_feature.make_vector_shape(node.inputs[0]) + + # We need to copy over stack trace from input to output + copy_stack_trace(node.outputs[0], ret) + return [ret] + + +@register_specialize +@register_canonicalize +@node_rewriter([Shape_i]) +def local_track_shape_i(fgraph, node): + if not isinstance(node.op, Shape_i): + return False + + try: + shape_feature = fgraph.shape_feature + except AttributeError: + return False + + if node not in shape_feature.scheduled: + return False + + # Don't unschedule node as it could be reinserted in the + # fgraph as we don't change it in the shapefeature internal + # structure. + replacement = shape_feature.scheduled[node] + return [shape_feature.shape_of[replacement][node.op.i]] + + +@register_canonicalize +@node_rewriter([Reshape]) +def local_useless_dimshuffle_in_reshape(fgraph, node): + """ + Removes useless DimShuffle operation inside Reshape: + + reshape(vector.dimshuffle('x', 0), shp) => reshape(vector, shp) + reshape(matrix.dimshuffle('x', 0, 'x', 1), shp) => reshape(matrix, shp) + reshape(row.dimshuffle(1, 'x'), shp) => reshape(row, shp) + reshape(col.dimshuffle(0), shp) => reshape(col, shp) + + """ + op = node.op + if not isinstance(op, Reshape): + return False + if not ( + node.inputs[0].owner is not None + and isinstance(node.inputs[0].owner.op, DimShuffle) + ): + return False + + new_order = node.inputs[0].owner.op.new_order + inp = node.inputs[0].owner.inputs[0] + broadcastables = node.inputs[0].broadcastable + new_order_of_nonbroadcast = [] + for i, bd in zip(new_order, broadcastables): + if not bd: + new_order_of_nonbroadcast.append(i) + no_change_in_order = all( + new_order_of_nonbroadcast[i] <= new_order_of_nonbroadcast[i + 1] + for i in range(len(new_order_of_nonbroadcast) - 1) + ) + if no_change_in_order: + shape = node.inputs[1] + ret = op.__class__(node.outputs[0].ndim)(inp, shape) + copy_stack_trace(node.outputs[0], ret) + return [ret] + + +@register_useless +@register_canonicalize +@register_specialize +@node_rewriter([Unbroadcast]) +def local_useless_unbroadcast(fgraph, node): + """Remove `Unbroadcast` if it does not actually change the broadcasting pattern. + + TODO: Implement equivalent rewrite for SpecifyShape + """ + if isinstance(node.op, Unbroadcast): + x = node.inputs[0] + if x.broadcastable == node.outputs[0].broadcastable: + # No broadcastable flag was modified + # No need to copy over stack trace, + # because x should already have a stack trace. + return [x] + else: + # Keep the flags that modify something + new_axes = tuple(ax for ax in node.op.axes if x.type.shape[ax] == 1) + if new_axes == node.op.axes: + # All flags are useful + return None + else: + r = unbroadcast(x, *new_axes) + # Copy over stacktrace from previous output + copy_stack_trace(node.outputs, r) + return [r] + + +@register_canonicalize +@register_specialize +@node_rewriter([Unbroadcast]) +def local_unbroadcast_lift(fgraph, node): + """ + Lifts `Unbroadcast` through unary Elemwise operations, + and merges consecutive `Unbroadcast`s. + + Unbroadcast(Elemwise(x)) => Elemwise(Unbroadcast(x)) + Unbroadcast(Unbroadcast(x)) => Unbroadcast(x) + + TODO: Implement equivalent Elemwise lift for SpecifyShape + """ + op = node.op + if not isinstance(op, Unbroadcast): + return False + + inp = node.inputs[0] + inode = inp.owner + if inode and isinstance(inode.op, Elemwise) and len(inode.inputs) == 1: + if len(fgraph.clients.get(inp, ())) == 1: + unbroadcasted = unbroadcast(inode.inputs[0], *op.axes) + copy_stack_trace(node.outputs, unbroadcasted) + + rval = inode.op.make_node(unbroadcasted).outputs + + # Copy over stacktrace from previous output (after unbroadcasting) + # and input (after elemwise operation) to new output, because an + # error in the new graph could have been caused by either of the + # two ops. + copy_stack_trace(node.outputs + node.inputs, rval) + return rval + + if inode and isinstance(inode.op, Unbroadcast): + # Merge axis of each unbroadcast + axis = tuple(set(inode.op.axes).union(set(op.axes))) + iinput = inode.inputs[0] + rval = [unbroadcast(iinput, *axis)] + # Copy over stacktrace from previous output (after second unbroadcasting) + # and from previous input (after first unbroadcasting) because an error in + # the new graph could have been caused by either of the two Unbroadcast ops. + copy_stack_trace(node.outputs + node.inputs, rval) + return rval diff --git a/aesara/tensor/rewriting/subtensor.py b/aesara/tensor/rewriting/subtensor.py new file mode 100644 index 0000000000..c25b77f8ee --- /dev/null +++ b/aesara/tensor/rewriting/subtensor.py @@ -0,0 +1,1842 @@ +import sys +from collections.abc import Iterable + +import numpy as np + +import aesara +import aesara.scalar.basic as aes +from aesara import compile +from aesara.graph.basic import Constant, Variable +from aesara.graph.rewriting.basic import ( + WalkingGraphRewriter, + copy_stack_trace, + in2out, + node_rewriter, +) +from aesara.raise_op import Assert +from aesara.tensor.basic import ( + Alloc, + Join, + MakeVector, + ScalarFromTensor, + TensorFromScalar, + alloc, + as_tensor, + cast, + concatenate, + extract_constant, + get_scalar_constant_value, + switch, +) +from aesara.tensor.elemwise import Elemwise +from aesara.tensor.exceptions import NotScalarConstantError +from aesara.tensor.math import Dot, add +from aesara.tensor.math import all as at_all +from aesara.tensor.math import ( + and_, + ceil_intdiv, + dot, + eq, + ge, + gt, + le, + lt, + maximum, + minimum, + or_, +) +from aesara.tensor.rewriting.basic import ( + register_canonicalize, + register_specialize, + register_stabilize, +) +from aesara.tensor.shape import ( + Shape, + SpecifyShape, + Unbroadcast, + shape_padleft, + shape_tuple, + specify_shape, + unbroadcast, +) +from aesara.tensor.sharedvar import TensorSharedVariable +from aesara.tensor.subtensor import ( + AdvancedIncSubtensor, + AdvancedIncSubtensor1, + AdvancedSubtensor, + AdvancedSubtensor1, + IncSubtensor, + Subtensor, + advanced_inc_subtensor1, + advanced_subtensor, + advanced_subtensor1, + as_index_constant, + as_index_literal, + get_canonical_form_slice, + get_constant_idx, + get_idx_list, + get_slice_elements, + inc_subtensor, + indices_from_subtensor, +) +from aesara.tensor.type import TensorType +from aesara.tensor.type_other import NoneTypeT, SliceConstant, SliceType +from aesara.tensor.var import TensorConstant, TensorVariable + + +def register_useless(lopt, *tags, **kwargs): + if isinstance(lopt, str): + + def register(inner_lopt): + return register_useless(inner_lopt, lopt, *tags, **kwargs) + + return register + else: + name = kwargs.pop("name", None) or lopt.__name__ + + compile.mode.local_useless.register( + name, lopt, "fast_run", *tags, position="last", **kwargs + ) + return lopt + + +def transform_take(a, indices, axis): + r"""Transform ``arr[:,:,:,indices,...]``-like operations into single-dimensional, vector index operations. + + This effectively converts certain `AdvancedSubtensor` `Op`\s into a + combination of `AdvancedSubtensor1`, `Dimshuffle`, and `Reshape` `Op`\s, + which can be more efficient. + + Parameters + ---------- + a : TensorVariable + The source array. + indices : TensorVariable, ndarray, list, tuple + The indices of the values to extract. + axis : int + The axis over which to select values. By default, the flattened + input array is used. + + """ + a = aesara.tensor.as_tensor_variable(a) + indices = aesara.tensor.as_tensor_variable(indices) + # We can use the more efficient `AdvancedSubtensor1` if `indices` is a vector + if indices.ndim == 1: + if axis == 0: + return advanced_subtensor1(a, indices) + else: + shuffle = list(range(a.ndim)) + shuffle[0] = axis + shuffle[axis] = 0 + res = advanced_subtensor1(a.dimshuffle(shuffle), indices).dimshuffle( + shuffle + ) + return res + + # We can reshape and flatten the indices in order to use an + # `AdvancedSubtensor1` `Op` per the above + indices_shape = shape_tuple(indices) + a_shape = shape_tuple(a) + + shape_parts = [ + a_shape[:axis], + indices_shape, + a_shape[axis + 1 :], + ] + + shape_parts = [sp for sp in shape_parts if len(sp) > 0] + + assert len(shape_parts) > 0 + + if len(shape_parts) > 1: + shape = aesara.tensor.concatenate(shape_parts) + else: + shape = shape_parts[0] + + ndim = a.ndim + indices.ndim - 1 + + return transform_take(a, indices.flatten(), axis).reshape(shape, ndim) + + +def is_full_slice(x): + """Determine if `x` is a ``slice(None)`` or a symbolic equivalent.""" + if ( + (isinstance(x, slice) and x == slice(None)) + or (isinstance(x, SliceConstant) and x.value == slice(None)) + or ( + not isinstance(x, SliceConstant) + and isinstance(getattr(x, "type", None), SliceType) + and x.owner is not None + and all( + isinstance(getattr(i, "type", None), NoneTypeT) for i in x.owner.inputs + ) + ) + ): + return True + return False + + +def get_advsubtensor_axis(indices): + """Determine the axis at which an array index is applied. + + This only works for ``take``-like indices: e.g. ``x[:, :, idx, ...]``. For + the above example, `get_advsubtensor_axis` would return ``2``. If it + encounters anything other than a set of `indices` containing full slices + and an array/tensor index, it will return ``None``. + + """ + found_idx = False + axis = 0 + for idx in indices: + if not found_idx and is_full_slice(idx): + # Preceding full slices + axis += 1 + elif found_idx and not is_full_slice(idx): + # We don't handle multiple indices + return + elif found_idx and is_full_slice(idx): + # Trailing full slices + continue + else: + found_idx = True + + if isinstance( + indices[axis], (TensorConstant, TensorVariable, TensorSharedVariable) + ): + return axis + + +@register_specialize +@node_rewriter([AdvancedSubtensor]) +def local_replace_AdvancedSubtensor(fgraph, node): + r""" + This rewrite converts expressions like ``X[..., y]`` into ``X.T[y].T``, for + a vector ``y``, and ``X[z, ...]`` into ``X[z.flatten()].reshape(...)``, for a + matrix ``z``. + + These rewrites replace `AdvancedSubtensor`\s with the more efficient + `AdvancedSubtensor1` and `Subtensor` `Op`\s. + """ + + if not isinstance(node.op, AdvancedSubtensor): + return + + indexed_var = node.inputs[0] + indices = node.inputs[1:] + + axis = get_advsubtensor_axis(indices) + + if axis is None or indices[axis].dtype == "bool": + # Booleans aren't handled + return + + new_res = transform_take(indexed_var, indices[axis], axis) + copy_stack_trace(node.outputs[0], new_res) + return [new_res] + + +@register_specialize +@node_rewriter([AdvancedIncSubtensor]) +def local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1(fgraph, node): + r"""Replace `AdvancedIncSubtensor`\s with `AdvancedIncSubtensor1`\s. + + This is only done when there's a single vector index. + """ + + if not isinstance(node.op, AdvancedIncSubtensor) or node.op.ignore_duplicates: + # `AdvancedIncSubtensor1` does not ignore duplicate index values + return + + res = node.inputs[0] + val = node.inputs[1] + indices = node.inputs[2:] + + axis = get_advsubtensor_axis(indices) + + if axis is None or indices[axis].dtype == "bool": + # Booleans aren't currently handled by `AdvancedIncSubtensor1` + return + + new_subtensor = transform_take(res, indices[axis], axis) + + new_res = inc_subtensor( + new_subtensor, + val, + inplace=node.op.inplace, + set_instead_of_inc=node.op.set_instead_of_inc, + ignore_duplicates=False, + ) + copy_stack_trace(node.outputs[0], new_res) + return [new_res] + + +@register_canonicalize +@register_stabilize +@register_specialize +@node_rewriter([Subtensor]) +def local_subtensor_of_dot(fgraph, node): + """Rewrite ``at.dot(A, B)[idxs]`` into ``at.dot(A[idxs_a], B[idxs_b])``. + ``idxs_a`` is the first ``A.ndim-1`` entries of ``idxs``, and ``idxs_b`` is + the remaining entries of ``idxs`` (if any), modified to skip the + second-to-last dimension of ``B`` (because dot sums over this dimension). + """ + if not isinstance(node.op, Subtensor): + return + if not node.inputs[0].owner or not isinstance(node.inputs[0].owner.op, Dot): + return + # If there is other node that use the outputs of the dot + # We don't want to compute twice the sub part. + if len(fgraph.clients[node.inputs[0]]) > 1: + return + + a = node.inputs[0].owner.inputs[0] + b = node.inputs[0].owner.inputs[1] + + idx_list = get_idx_list(node.inputs, node.op.idx_list) + + num_a_indices = min(a.ndim - 1, len(idx_list)) + a_indices = idx_list[:num_a_indices] + b_indices = idx_list[num_a_indices:] + + # This is necessary because np.dot sums the last index of a with the second to last of b + # so we want to skip the second-to-last index into b. + # This wasn't necessary for a, because we just omitted the last index. + # We skip this if b.ndim = 1, since then we just want b_sub = b, not b_sub = b[:] + # (dot also handles b.ndim < 2 as a special case) + if b.ndim > 1 and len(b_indices) >= b.ndim - 1: + b_indices = ( + b_indices[: b.ndim - 2] + + (slice(None, None, None),) + + b_indices[b.ndim - 2 :] + ) + + a_sub = a.__getitem__(tuple(a_indices)) + b_sub = b.__getitem__(tuple(b_indices)) if b_indices else b + + # Copy over previous output stacktrace to a_sub and b_sub, + # because an error in the subtensor operation (e.g. an index error) + # on either a or b must correspond to an error in the + # subtensor operation on their dot product. + copy_stack_trace(node.outputs[0], [a_sub, b_sub]) + + # Copy over previous output stacktrace and previous dot product stacktrace, + # because an error here may correspond to an either in either the original + # dot product, or in the dot product after the subtensor operation. + r = dot(a_sub, b_sub) + copy_stack_trace([node.outputs[0], node.inputs[0]], r) + + return [r] + + +@register_useless +@register_canonicalize +@register_specialize +@node_rewriter([Subtensor]) +def local_useless_slice(fgraph, node): + """ + Remove Subtensor of the form X[0, :] -> X[0] + """ + if isinstance(node.op, Subtensor): + slices = get_idx_list(node.inputs, node.op.idx_list) + last_slice = len(slices) + for s in slices[::-1]: + # check if slice and then check slice indices + if ( + isinstance(s, slice) + and s.start is None + and s.stop is None + and ( + s.step is None + or extract_constant(s.step, only_process_constants=True) == 1 + ) + ): + last_slice -= 1 + else: + break + # check if we removed something + if last_slice < len(slices): + subtens = Subtensor(slices[:last_slice]) + sl_ins = get_slice_elements( + slices[:last_slice], lambda x: isinstance(x, Variable) + ) + out = subtens(node.inputs[0], *sl_ins) + # Copy over previous output stacktrace + copy_stack_trace(node.outputs, out) + return [out] + + +# fast_compile to allow opt subtensor(cast{float32}(make_vector)) +@register_canonicalize("fast_compile") +@node_rewriter([Subtensor]) +def local_subtensor_lift(fgraph, node): + """ + unary(x)[idx] -> unary(x[idx])#any broadcast pattern. + + Handles the following unary ops: + elemwise(x,...)[idx] -> elemwise(x[idx],...) + when x,... are broadcasted scalar or not broadcasted at all + Unbroadcast(x)[idx] => Unbroadcast(x[idx]) + + """ + if isinstance(node.op, Subtensor): + u = node.inputs[0] + if not u.owner or len(fgraph.clients[u]) > 1: + return False + + if isinstance(u.owner.op, Elemwise) and len(u.owner.inputs) == 1: + idx = node.inputs[1:] + x_idx = node.op(u.owner.inputs[0], *idx) + # Copy over previous output stacktrace + copy_stack_trace(node.outputs, x_idx) + ret = u.owner.op(x_idx) + # Copy over previous output stacktrace + # and stacktrace from previous unary operation + copy_stack_trace([node.outputs[0], node.inputs[0]], ret) + return [ret] + + if isinstance(u.owner.op, Elemwise): + new_inputs = [] + if all(sum(i.type.broadcastable) == 0 for i in u.owner.inputs): + # There is no broadcastable in the inputs + idx = node.inputs[1:] + new_inputs = [node.op(i, *idx) for i in u.owner.inputs] + # Copy over previous output stacktrace + copy_stack_trace(node.outputs[0], new_inputs) + + ret = u.owner.op(*new_inputs) + # Copy over previous output stacktrace + # and stacktrace from previous unary operation + copy_stack_trace([node.outputs[0], node.inputs[0]], ret) + return [ret] + elif all(sum(i.type.broadcastable) in [i.ndim, 0] for i in u.owner.inputs): + # There is no broadcastable in the inputs or it is scalar + idx = node.inputs[1:] + new_inputs = [] + for i in u.owner.inputs: + if sum(i.type.broadcastable) == 0: + new_inputs.append(node.op(i, *idx)) + else: + # If the subtensor remove some dims, we must + # lower the number of dimensions of this scalar. + if node.outputs[0].ndim == i.ndim: + new_inputs.append(i) + else: + new_inputs.append( + i.dimshuffle(["x"] * node.outputs[0].ndim) + ) + + # Copy over previous output stacktrace + copy_stack_trace(node.outputs[0], new_inputs) + + ret = u.owner.op(*new_inputs) + # Copy over previous output stacktrace + # and stacktrace from previous unary operation + copy_stack_trace([node.outputs[0], node.inputs[0]], ret) + return [ret] + + if isinstance(u.owner.op, Unbroadcast): + # Subtensor might reduce dim., adapt broadcast pattern accordingly + old_axes = u.owner.op.axes + new_axes = [] + + # loop through indices being subtensor-ed + # i indexes broadcastable pattern before subtensor + # j indexes broadcastable pattern after subtensor + j = 0 + for (i, x) in enumerate(node.op.idx_list): + # if it is not a slice, it will reduce the dimension, should + # not appear in the broascastable dimensions + if isinstance(x, slice): + if i in old_axes: + new_axes.append(j) + j += 1 + # now keep the broadcastable pattern of all + # items not appearing in subtensor list + for i in range(len(node.op.idx_list), len(u.broadcastable)): + if i in old_axes: + new_axes.append(j) + j += 1 + + subt_x = node.op(u.owner.inputs[0], *node.inputs[1:]) + # Copy over previous output stacktrace + copy_stack_trace(node.outputs[0], subt_x) + + rbcast_subt_x = unbroadcast(subt_x, *new_axes) + # Copy over previous output stacktrace + # and stacktrace from previous unary operation + copy_stack_trace([node.outputs[0], node.inputs[0]], rbcast_subt_x) + + return [rbcast_subt_x] + + +@register_canonicalize +@register_specialize +@node_rewriter([Subtensor]) +def local_subtensor_merge(fgraph, node): + """ + Refactored optimization to deal with all cases of tensor merging. + Given a subgraph of the form Subtensor(Subtensor(u)), the optimization + expresses all slices in a canonical form, and then merges them together. + + """ + + if isinstance(node.op, Subtensor): + u = node.inputs[0] + if u.owner and isinstance(u.owner.op, Subtensor): + # We can merge :) + # x actual tensor on which we are picking slices + x = u.owner.inputs[0] + # slices of the first applied subtensor + slices1 = get_idx_list(u.owner.inputs, u.owner.op.idx_list) + slices2 = get_idx_list(node.inputs, node.op.idx_list) + # Get the shapes of the vectors ! + try: + # try not to introduce new shape into the graph + xshape = fgraph.shape_feature.shape_of[x] + ushape = fgraph.shape_feature.shape_of[u] + except AttributeError: + # Following the suggested use of shape_feature which should + # consider the case when the compilation mode doesn't + # include the ShapeFeature + xshape = x.shape + ushape = u.shape + + merged_slices = [] + pos_2 = 0 + pos_1 = 0 + while (pos_1 < len(slices1)) and (pos_2 < len(slices2)): + slice1 = slices1[pos_1] + if isinstance(slice1, slice): + merged_slices.append( + merge_two_slices( + fgraph, slice1, xshape[pos_1], slices2[pos_2], ushape[pos_2] + ) + ) + pos_2 += 1 + else: + merged_slices.append(slice1) + pos_1 += 1 + + if pos_2 < len(slices2): + merged_slices += slices2[pos_2:] + else: + merged_slices += slices1[pos_1:] + + merged_slices = tuple(as_index_constant(s) for s in merged_slices) + subtens = Subtensor(merged_slices) + + sl_ins = get_slice_elements( + merged_slices, lambda x: isinstance(x, Variable) + ) + # Do not call make_node for test_value + out = subtens(x, *sl_ins) + + # Copy over previous output stacktrace + # and stacktrace from previous slicing operation. + # Why? Because, the merged slicing operation could have failed + # because of either of the two original slicing operations + orig_out = node.outputs[0] + copy_stack_trace([orig_out, node.inputs[0]], out) + return [out] + + +@register_specialize +@register_canonicalize +@node_rewriter([Subtensor]) +def local_subtensor_remove_broadcastable_index(fgraph, node): + """ + Remove broadcastable dimension with index 0 or -1 + a[:,:,:,0] -> a.dimshuffle(0,1,2), when + a.broadcastable = (False, False, False, True) + a[0,:,-1,:] -> a.dimshuffle(1,3), when + a.broadcastable = (True, False, True, False) + + """ + if isinstance(node.op, Subtensor): + idx = node.op.idx_list + else: + return + + remove_dim = [] + node_inputs_idx = 1 + for dim, elem in enumerate(idx): + if isinstance(elem, (aes.ScalarType)): + # The idx is a ScalarType, ie a Type. This means the actual index + # is contained in node.inputs[1] + dim_index = node.inputs[node_inputs_idx] + if isinstance(dim_index, aes.ScalarConstant): + dim_index = dim_index.value + if dim_index in (0, -1) and node.inputs[0].broadcastable[dim]: + remove_dim.append(dim) + node_inputs_idx += 1 + else: + return + elif isinstance(elem, slice): + if elem != slice(None): + return + elif isinstance(elem, (int, np.integer)): + if elem in (0, -1) and node.inputs[0].broadcastable[dim]: + remove_dim.append(dim) + else: + raise TypeError("case not expected") + + if len(remove_dim) == 0: + return + else: + all_dim = range(node.inputs[0].ndim) + remain_dim = [x for x in all_dim if x not in remove_dim] + return [node.inputs[0].dimshuffle(tuple(remain_dim))] + + +@register_useless +@register_canonicalize +@register_specialize +@node_rewriter([Subtensor]) +def local_subtensor_of_alloc(fgraph, node): + """ + + alloc(val)[x:y] -> alloc(val[...]) + alloc(val)[x:y] -> alloc(val) + This can be seen as a lift, but it also reduce the number of computation/memory. + + """ + if not isinstance(node.op, Subtensor): + return False + u = node.inputs[0] + if u.owner is None: + return False + if not isinstance(u.owner.op, Alloc): + return False + slices = get_idx_list(node.inputs, node.op.idx_list) + val = u.owner.inputs[0] + dims = u.owner.inputs[1:] + assert len(slices) <= len(dims) + + # Number of dimensions added to val + n_added_dims = u.ndim - val.ndim + # Dimensions of the returned alloc + nw_dims = [] + # Slices to take from val + val_slices = [] + + for i, (sl, dim) in enumerate(zip(slices, dims)): + # If val was not copied over that dim, + # we need to take the appropriate subtensor on it. + if i >= n_added_dims: + # We check that the corresponding val dimensions was + # not a broadcasted dimensions. + if ( + val.type.ndim > (i - n_added_dims) + and val.type.broadcastable[i - n_added_dims] + ): + val_slices.append(slice(None)) + else: + val_slices.append(sl) + + csl, _ = get_canonical_form_slice(sl, dim) + if type(csl) is not slice: + # That dimension is removed. + pass + else: + nw_dim = csl.stop - csl.start + + if csl.step != 1: + # Do not add the ceil_intdiv() graphs in the graphs + # when this is not needed as it prevent detecting the + # correct broadcast pattern. + nw_dim = ceil_intdiv(nw_dim, csl.step) + nw_dims += [nw_dim] + + nw_val = val[tuple(val_slices)] + nw_dims += dims[len(slices) :] + if nw_val.ndim > len(nw_dims): + return False + rval = alloc(nw_val, *nw_dims) + if not isinstance(rval, (list, tuple)): + rval = [rval] + return rval + + +@register_specialize +@register_canonicalize +@node_rewriter([Subtensor]) +def local_subtensor_inc_subtensor(fgraph, node): + """ + Subtensor(SetSubtensor(x, y, idx), idx) -> y + + """ + if isinstance(node.op, Subtensor): + x = node.inputs[0] + if not x.owner or not isinstance(x.owner.op, IncSubtensor): + return + if not x.owner.op.set_instead_of_inc: + return + + if x.owner.inputs[2:] == node.inputs[1:] and tuple( + x.owner.op.idx_list + ) == tuple(node.op.idx_list): + out = node.outputs[0] + y = x.owner.inputs[1] + # If the dtypes differ, cast y into x.dtype + if x.dtype != y.dtype: + y = y.astype(x.dtype) + if ( + out.type.dtype == y.type.dtype + and out.type.broadcastable == y.type.broadcastable + ): + # if x[idx] and y have the same type, directly return y + return [y] + else: + # The difference is related to broadcasting pattern + assert out.broadcastable != y.broadcastable + # We have to alloc y to the shape of x[idx] + x_subtensor = node.op(x.owner.inputs[0], *x.owner.inputs[2:]) + return [alloc(y, *x_subtensor.shape)] + else: + return + + +@register_specialize +@register_canonicalize("fast_compile") +@register_useless +@node_rewriter([Subtensor, AdvancedSubtensor1]) +def local_subtensor_make_vector(fgraph, node): + """Perform ``*Subtensor*`` operations on ``MakeVector`` outputs when the indices are constant. + + Replace all ``Subtensor`` and ``MakeVector`` cases like: + [a,b,c][0] -> a + [a,b,c][0:2] -> [a,b] + + Replace all ``AdvancedSubtensor1`` and ``MakeVector`` cases like: + [a,b,c][[0,2]] -> [a,c] + + We can do this for constant indexes. + + .. note: + + This optimization implicitly relies on shape optimizations. + + TODO: This only applies to a single indexed dimension; we should have + something more general for constant ``*Subtensor*`` graphs (or perhaps + include this kind of work in the constant folding). + """ + + if not isinstance(node.op, (Subtensor, AdvancedSubtensor1)): + return False + + x = node.inputs[0] + + if not x.owner or not isinstance(x.owner.op, MakeVector): + return False + + make_vector_op = x.owner.op + + if isinstance(node.op, Subtensor): + (idx,) = node.op.idx_list + + if isinstance(idx, (aes.ScalarType, TensorType)): + old_idx, idx = idx, node.inputs[1] + assert idx.type.is_super(old_idx) + elif isinstance(node.op, AdvancedSubtensor1): + idx = node.inputs[1] + + if isinstance(idx, (int, np.integer)): + return [x.owner.inputs[idx]] + elif isinstance(idx, Variable): + if idx.ndim == 0: + try: + v = get_scalar_constant_value(idx, only_process_constants=True) + try: + ret = [x.owner.inputs[v]] + except IndexError: + raise NotScalarConstantError("Bad user graph!") + return ret + except NotScalarConstantError: + pass + elif idx.ndim == 1 and isinstance(idx, Constant): + values = list(map(int, list(idx.value))) + ret = make_vector_op(*[x.owner.inputs[v] for v in values]) + copy_stack_trace(node.outputs[0], ret) + return [ret] + elif isinstance(idx, slice): + # The index is a slice. If it's a constant slice, we can perform the + # index operation here. + try: + const_slice = get_constant_idx( + node.op.idx_list, node.inputs, allow_partial=False + )[0] + ret = make_vector_op(*x.owner.inputs[const_slice]) + copy_stack_trace(node.outputs, ret) + return [ret] + except NotScalarConstantError: + pass + + +@register_useless +@register_canonicalize +@register_specialize +@node_rewriter([IncSubtensor]) +def local_useless_inc_subtensor(fgraph, node): + r"""Remove redundant `IncSubtensor`\s. + + More specifically, ``set_subtensor(x[indices], y)`` is replaced by + ``y[indices]`` when ``indices`` are full `slice`\s and ``y``'s shape is + equal to ``x[indices]``, and ``inc_subtensor(x[indices], y)`` is replaced + by ``y[indices]`` when ``x[indices]`` is some array of ``0``\s, ``indices`` + are full slices, and the shapes are equal. + """ + if not isinstance(node.op, IncSubtensor): + return + + if not hasattr(fgraph, "shape_feature"): + return + + x, y, *index_inputs = node.inputs + + if node.op.set_instead_of_inc is False: + # This is an increment operation, so the array being incremented must + # consist of all zeros in order for the entire operation to be useless + try: + c = get_scalar_constant_value(x) + if c != 0: + return + except NotScalarConstantError: + return + + idx_cst = indices_from_subtensor(list(index_inputs), node.op.idx_list) + + # Check that all indices are full slices with only reversals and no step + # sizes + # TODO: It seems like there should be a basic `IncSubtensor` + # canonicalization that removes these redundant slices. + if all( + isinstance(e, slice) + and e.start is None + and e.stop is None + and ( + e.step is None + or extract_constant(e.step, only_process_constants=True) == -1 + ) + for e in idx_cst + ): + + # `IncSubtensor` broadcasts `x` on `y` based on run-time shapes, so we + # must check that they are the same + if not fgraph.shape_feature.same_shape(x, y): + return + + # There are no reversals, so we don't need a replacement. + if all(e.step is None for e in node.op.idx_list): + # They are exactly the same shapes, so we can remove this `IncSubtensor` + return [y] + + new_node = Subtensor(node.op.idx_list).make_node(y, *index_inputs) + new_out = new_node.outputs[0] + copy_stack_trace(node.outputs, new_out) + + return [new_out] + + +@register_canonicalize +@register_specialize +@node_rewriter([AdvancedIncSubtensor1]) +def local_set_to_inc_subtensor(fgraph, node): + r""" + AdvancedIncSubtensor1(x, x[ilist]+other, ilist, set_instead_of_inc=True) -> + AdvancedIncSubtensor1(x, other, ilist, set_instead_of_inc=False) + + TODO FIXME: Why doesn't this apply to all `*IncSubtensor*` `Op`\s? If it + did this wouldn't need to also be included in the "specialize" pass. + + """ + if ( + isinstance(node.op, AdvancedIncSubtensor1) + and node.op.set_instead_of_inc + and node.inputs[1].owner + and isinstance(node.inputs[1].owner.op, Elemwise) + and isinstance(node.inputs[1].owner.op.scalar_op, aes.Add) + ): + addn = node.inputs[1].owner + subn = None + other = None + + if addn.inputs[0].owner and isinstance( + addn.inputs[0].owner.op, AdvancedSubtensor1 + ): + subn = addn.inputs[0].owner + other = addn.inputs[1] + elif addn.inputs[1].owner and isinstance( + addn.inputs[1].owner.op, AdvancedSubtensor1 + ): + subn = addn.inputs[1].owner + other = addn.inputs[0] + else: + return + if subn.inputs[1] != node.inputs[2] or subn.inputs[0] != node.inputs[0]: + return + ret = advanced_inc_subtensor1(node.inputs[0], other, node.inputs[2]) + + copy_stack_trace(node.outputs, ret) + + return [ret] + + +@register_canonicalize +@register_specialize +@node_rewriter([Subtensor]) +def local_useless_subtensor(fgraph, node): + """Remove `Subtensor` if it takes the full input.""" + # This optimization needs ShapeOpt and fgraph.shape_feature + if not hasattr(fgraph, "shape_feature"): + return + + shape_of = fgraph.shape_feature.shape_of + + cdata = get_constant_idx( + node.op.idx_list, + node.inputs, + allow_partial=True, + only_process_constants=True, + ) + for pos, idx in enumerate(cdata): + if not isinstance(idx, slice): + # If idx is not a slice, this means we remove this dimension + # from the output, so the subtensor is not useless + return False + if idx.start is not None and idx.start != 0: + # If the start of the slice is different from 0, or is a + # variable, then we assume the subtensor is not useless + return False + if idx.step is not None and idx.step != 1: + # If we are going backwards, or skipping elements, then this + # is not a useless subtensor + return False + + length_pos = shape_of[node.inputs[0]][pos] + + if isinstance(idx.stop, (int, np.integer)): + length_pos_data = sys.maxsize + try: + length_pos_data = get_scalar_constant_value( + length_pos, only_process_constants=True + ) + except NotScalarConstantError: + pass + + if idx.stop < length_pos_data: + return False + elif isinstance(idx.stop, Variable): + length_pos_shape_i = idx.stop + # length_pos is a tensor variable, but length_pos_shape_i + # is a scalar variable. We try to see if they represent + # the same underlying variable. + if length_pos_shape_i.owner and isinstance( + length_pos_shape_i.owner.op, ScalarFromTensor + ): + length_pos_shape_i = length_pos_shape_i.owner.inputs[0] + elif length_pos.owner and isinstance(length_pos.owner.op, TensorFromScalar): + length_pos = length_pos.owner.inputs[0] + else: + # We did not find underlying variables of the same type + return False + + # The type can be different: int32 vs int64. length_pos + # should always be int64 as that is what the shape + # tracker keep. Subtensor accept any scalar int{8,16,32,64} + # as index type. + assert str(length_pos.type.dtype) == "int64" + assert str(length_pos_shape_i.type.dtype) in [ + "int8", + "int16", + "int32", + "int64", + ] + + # length_pos_shape_i cannot be None + if length_pos_shape_i != length_pos: + return False + elif idx.stop is None: + continue + else: + return False + + return [node.inputs[0]] + + +@register_canonicalize +@register_specialize +@node_rewriter([AdvancedSubtensor1]) +def local_useless_AdvancedSubtensor1(fgraph, node): + """Remove `AdvancedSubtensor1` if it takes the full input. + + In the `AdvancedSubtensor1` case, the full input is taken when the indices + are equivalent to ``arange(0, input.shape[0], 1)`` using either an explicit + list/vector or the `ARange` `Op`. + + """ + # This optimization needs ShapeOpt and fgraph.shape_feature + if not hasattr(fgraph, "shape_feature"): + return + + shape_of = fgraph.shape_feature.shape_of + + # get length of the indexed tensor along the first axis + try: + length = get_scalar_constant_value( + shape_of[node.inputs[0]][0], only_process_constants=True + ) + except NotScalarConstantError: + return False + + # get index (which must be a vector by definition) + idx = node.inputs[1] + + # `idx` must be equivalent to [0,1,...,shape[0] - 1] to qualify for + # this optimization + if isinstance(idx, Constant): + idx = idx.value + if len(idx) != length: + return False + if np.any(idx != np.arange(length)): + return False + else: + return False + + return [node.inputs[0]] + + +def merge_two_slices(fgraph, slice1, len1, slice2, len2): + """ + This function merges two slices into a single slice. The code works on + the assumption that: + + a) slice1 is actually a slice and not an index, while slice2 + can be just an index. + + b) the two slices **have been applied consecutively** on the same + tensor + + The output slice is **not** in canonical form, but actually just a slice + that can be applied to a tensor to produce the same output as applying + the two consecutive slices. + ``len1`` is the length of the tensor **before** applying the first slice, + while ``len2`` is the length **after** applying the first slice. + """ + + if not isinstance(slice1, slice): + raise ValueError("slice1 should be of type `slice`") + + sl1, reverse1 = get_canonical_form_slice(slice1, len1) + sl2, reverse2 = get_canonical_form_slice(slice2, len2) + + if not isinstance(sl2, slice): + if reverse1 is None: + # The first slice is not in reverse, which makes things a lot + # more clear. + # In this case we need to take care only of the special cases: + # len2 <=0 -> throw index error regardless of sl2 + # sl2 > len2 -> throw index error + # sl2 < -len2 -> throw index error + # To get a index error we simply use len1+1 to indicate we are + # out of bounds, because passing this index through the formula + # of getting the mixed slice is not guaranteed to result in an + # index error. The **issue though** if that the error will + # complain about accessing element len1+1 which is probably not + # too intuitive for the user + val = sl1.start + sl2 * sl1.step + val = switch(le(len2, 0), len1 + 1, val) + val = switch(ge(sl2, len2), len1 + 1, val) + val = switch(lt(sl2, 0), -len1 - 1, val) + if sl1.step: + val = switch(eq(sl1.step, 0), len1 + 1, val) + return val + else: + # We are in the more complex case when we do not actually know + # if the first slice was in reverse or not. + # in case it was not in reverse: + p_val = sl1.start + sl2 * sl1.step + # case it was in reverse we need to realize that we do not want + # the k-th element from sl.start but the k-th element from + # sl.stop backwards + n_val = sl1.stop - 1 - sl2 * sl1.step + # we need to pick either n_val or p_val and then follow same + # steps as above for covering the index error cases + val = switch(lt(reverse1, 0), n_val, p_val) + val = switch(le(len2, 0), len1 + 1, val) + val = switch(ge(sl2, len2), len1 + 1, val) + val = switch(lt(sl2, 0), -len1 - 1, val) + if sl1.step: + val = switch(eq(sl1.step, 0), len1 + 1, val) + return val + else: + # We are deleaing with two slices that need to be put together + # according to the two steps we have 4 different combinations of + # positive/negative. I will denote the case I'm looking at by + # suffixes to the variables (nn,np,pn,pp): + flen = sl2.stop - sl2.start + p_step = sl1.step * sl2.step + n_step = sl1.step * sl2.step * -1 + + pp_start = minimum(sl1.start + sl2.start * sl1.step, sl1.stop) + pp_stop = minimum(sl1.start + sl2.stop * sl1.step, sl1.stop) + + pn_stop = sl1.start + (sl2.start - 1) * sl1.step + pn_stop = switch( + and_(lt(pn_stop, 0), gt(flen, 0)), + -len1 - 1, + minimum(pn_stop, sl1.stop), + ) + pn_start = sl1.start + (sl2.stop - 1) * sl1.step + pn_start = minimum(pn_start, sl1.stop) + pn_start = maximum(pn_start, 0) + + np_stop = sl1.stop - sl2.stop * sl1.step - 1 + np_stop = switch( + and_(lt(np_stop, 0), gt(flen, 0)), + -len1 - 1, + maximum(sl1.start - 1, np_stop), + ) + np_start = maximum(sl1.start, sl1.stop - sl2.start * sl1.step - 1) + + nn_start = maximum(sl1.start, (sl1.stop - 1) - (sl2.stop - 1) * sl1.step) + nn_stop = maximum(sl1.start, sl1.stop - sl2.start * sl1.step) + + start = switch( + lt(reverse2 * reverse1, 0), + switch(lt(reverse1, 0), np_start, pn_start), + switch(lt(reverse1, 0), nn_start, pp_start), + ) + + stop = switch( + lt(reverse2 * reverse1, 0), + switch(lt(reverse1, 0), np_stop, pn_stop), + switch(lt(reverse1, 0), nn_stop, pp_stop), + ) + + step = switch(lt(reverse2 * reverse1, 0), n_step, p_step) + start = switch(le(flen, 0), 0, start) + stop = switch(le(flen, 0), 0, stop) + + return slice(start, stop, step) + + +@register_canonicalize +@node_rewriter([add]) +def local_IncSubtensor_serialize(fgraph, node): + """ + When using Subtensor, gradient graphs can be ugly. + + If we ask for grad(f(a[0]), a), we are going to get something like + + IncSubtensor(Elemwise{second}(a, 0), g(f(a[0])), [0]) + + This might be ugly, but at least it's as fast as you could want. + If we ask for grad(f(a[0], a[1], a[2]), a), it's much worse... + + Elemwise{Add} + IncSubtensor(Elemwise{second}(a, 0), g(f(a[0])), [0]) + IncSubtensor(Elemwise{second}(a, 0), g(f(a[1])), [1]) + IncSubtensor(Elemwise{second}(a, 0), g(f(a[2])), [2]) + + This is much worse because this time we have to produce 3 matrices + the size of 'a', just so we can add them together. + + This Op rearranges IncSubtensor's that all work on the same + initial argument (here, Elemwise{second}(a,0)) into a chain. The + advantage of the chain structure is that each one can be optimized + later in the pipeline to operate inplace. + + Ideally, the op will do something like this: + + # + # add(x, incsubtensor(b, c), incsubtensor(b, d)) + # -> incsubtensor(incsubtensor(add(x,b,b), c), d) + + """ + + def movable(i): + # Return True iff this is a incsubtensor that we can move + return ( + i.owner + and isinstance( + i.owner.op, + ( + IncSubtensor, + AdvancedIncSubtensor1, + AdvancedIncSubtensor, + ), + ) + and i.type.is_super(o_type) + and len(fgraph.clients[i]) == 1 + and not i.owner.op.set_instead_of_inc + ) + + if node.op == add: + o_type = node.outputs[0].type + + movable_inputs = [i for i in node.inputs if movable(i)] + + if movable_inputs: + new_inputs = [i for i in node.inputs if not movable(i)] + [ + mi.owner.inputs[0] for mi in movable_inputs + ] + if len(new_inputs) == 0: + new_add = new_inputs[0] + else: + new_add = add(*new_inputs) + + # Copy over stacktrace from original output, as an error + # (e.g. an index error) in this add operation should + # correspond to an error in the original add operation. + copy_stack_trace(node.outputs[0], new_add) + + # stack up the new incsubtensors + tip = new_add + for mi in movable_inputs: + assert o_type.is_super(tip.type) + assert mi.owner.inputs[0].type.is_super(tip.type) + tip = mi.owner.op(tip, *mi.owner.inputs[1:]) + # Copy over stacktrace from outputs of the original + # "movable" operation to the new operation. + copy_stack_trace(node.outputs + mi.owner.outputs, tip) + + return [tip] + + # print incsub_inputs, [id(i.owner.inputs[0]) for i in incsub_inputs] + + +# We register it in a WalkingGraphRewriter inside the canonizer EQ optimizer. +# Otherwise in some cases it was making the EQ optimizer use 45. In +# the WalkingGraphRewriter, the EQ only use 5 passes. +compile.optdb.register( + "pre_local_IncSubtensor_serialize", + in2out(local_IncSubtensor_serialize), + "fast_run", + # Just before canonizer + position=0.99, +) + + +# after priority 50 Destructive inplace operations +# gemm is the first one now, at priority 70 + + +@node_rewriter([IncSubtensor], inplace=True) +def local_inplace_setsubtensor(fgraph, node): + if isinstance(node.op, IncSubtensor) and not node.op.inplace: + dta = node.op.destroyhandler_tolerate_aliased + new_op = node.op.__class__( + node.op.idx_list, + inplace=True, + set_instead_of_inc=node.op.set_instead_of_inc, + destroyhandler_tolerate_aliased=dta, + ) + new_node = new_op(*node.inputs) + val = getattr(node.outputs[0].tag, "nan_guard_mode_check", True) + new_node.tag.nan_guard_mode_check = val + + # Copy stacktrace from original outputs to new outputs. + # This is sensible, because the new operation is the + # same as the old one, but now with different attributes. + copy_stack_trace(node.outputs, new_node) + return [new_node] + return False + + +compile.optdb.register( + "local_inplace_setsubtensor", + WalkingGraphRewriter( + local_inplace_setsubtensor, failure_callback=WalkingGraphRewriter.warn_inplace + ), + "fast_run", + "inplace", + position=60, +) + + +@node_rewriter([AdvancedIncSubtensor1], inplace=True) +def local_inplace_AdvancedIncSubtensor1(fgraph, node): + if isinstance(node.op, AdvancedIncSubtensor1) and not node.op.inplace: + new_op = node.op.clone_inplace() + new_node = new_op(*node.inputs) + copy_stack_trace(node.outputs, new_node) + return [new_node] + return False + + +compile.optdb.register( + "local_inplace_AdvancedIncSubtensor1", + WalkingGraphRewriter( + local_inplace_AdvancedIncSubtensor1, + failure_callback=WalkingGraphRewriter.warn_inplace, + ), + "fast_run", + "inplace", + position=60, +) + + +@node_rewriter([AdvancedIncSubtensor], inplace=True) +def local_inplace_AdvancedIncSubtensor(fgraph, node): + if isinstance(node.op, AdvancedIncSubtensor) and not node.op.inplace: + new_op = type(node.op)( + inplace=True, + set_instead_of_inc=node.op.set_instead_of_inc, + ignore_duplicates=node.op.ignore_duplicates, + ) + new_node = new_op(*node.inputs) + copy_stack_trace(node.outputs, new_node) + return [new_node] + return False + + +compile.optdb.register( + "local_inplace_AdvancedIncSubtensor", + WalkingGraphRewriter( + local_inplace_AdvancedIncSubtensor, + failure_callback=WalkingGraphRewriter.warn_inplace, + ), + "fast_run", + "inplace", + position=60, +) + + +# Register old name +@register_canonicalize("local_incsubtensor_of_allocs") +@register_stabilize("local_incsubtensor_of_allocs") +@node_rewriter([IncSubtensor, AdvancedIncSubtensor, AdvancedIncSubtensor1]) +def local_incsubtensor_of_zeros(fgraph, node): + """ + IncSubtensor(x, zeros, idx) -> x + + """ + if ( + isinstance(node.op, (IncSubtensor, AdvancedIncSubtensor, AdvancedIncSubtensor1)) + and not node.op.set_instead_of_inc + ): + x = node.inputs[0] + y = node.inputs[1] + try: + # Don't use only_process_constants=True. We need to + # investigate Alloc of 0s but with non constant shape. + if get_scalar_constant_value(y, elemwise=False) == 0: + # No need to copy over the stacktrace, + # because x should already have a stacktrace + return [x] + except NotScalarConstantError: + return + + +@register_canonicalize +@register_specialize +@node_rewriter([IncSubtensor]) +def local_incsubtensor_of_zeros_to_setsubtensor(fgraph, node): + """ + IncSubtensor(zeros, x, ...) -> SetSubtensor(zeros, x, ...) + """ + if isinstance(node.op, (IncSubtensor)) and not node.op.set_instead_of_inc: + x = node.inputs[0] + + if isinstance(x, Constant) and not np.any(x.data): + return [ + IncSubtensor( + node.op.idx_list, + node.op.inplace, + set_instead_of_inc=True, + destroyhandler_tolerate_aliased=node.op.destroyhandler_tolerate_aliased, + )(*node.inputs) + ] + + +@register_canonicalize("local_setsubtensor_of_allocs") +@register_stabilize("local_setsubtensor_of_allocs") +@node_rewriter([IncSubtensor]) +def local_setsubtensor_of_constants(fgraph, node): + """ + SetSubtensor(x, x[idx], idx) -> x + + when x is constant or alloc. + + """ + if isinstance(node.op, IncSubtensor) and node.op.set_instead_of_inc: + x = node.inputs[0] + y = node.inputs[1] + + # Don't use only_process_constants=True. We need to + # investigate Alloc of 0s but with non constant shape. + try: + replace_x = get_scalar_constant_value(x, elemwise=False) + except NotScalarConstantError: + return + + try: + replace_y = get_scalar_constant_value(y, elemwise=False) + except NotScalarConstantError: + return + + if replace_x == replace_y: + + # No need to copy over the stacktrace, + # because x should already have a stacktrace + return [x] + else: + return False + + +@register_canonicalize +@register_specialize +@node_rewriter([AdvancedSubtensor1]) +def local_adv_sub1_adv_inc_sub1(fgraph, node): + """Rewrite graphs like ``AdvancedSubtensor1(AdvancedSetSubtensor1(...), ...)``. + + .. code:: + + AdvancedSubtensor1(AdvancedSetSubtensor1(x, y, idx), idx) -> y + + + Notes + ----- + This rewrite adds an `AssertOp`; otherwise, it would remove shape and index + error. If you want to get rid of them, see the :ref:`unsafe_rewrites` + section. + + A previous version of this rewrite also matched + ``AdvancedSubtensor1(AdvancedIncSubtensor1(x, y, idx), idx)``. + This is incorrect when there are duplicate indices. + The current version warns the user about potential issues. + + """ + if not isinstance(node.op, AdvancedSubtensor1): + return + inp = node.inputs[0] + if not inp.owner or not isinstance(inp.owner.op, AdvancedIncSubtensor1): + return + idx = node.inputs[1] + idx2 = inp.owner.inputs[2] + x = inp.owner.inputs[0] + y = inp.owner.inputs[1] + if idx is not idx2: + return + if ( + not inp.owner.op.set_instead_of_inc + and + # Don't use only_process_constants=True. We need to + # investigate Alloc of 0s but with non constant shape. + extract_constant(x, elemwise=False) != 0 + ): + return + + if not inp.owner.op.set_instead_of_inc: + return + + cond = [at_all(and_(lt(idx, x.shape[0]), ge(idx, -x.shape[0])))] + if not fgraph.shape_feature.same_shape(idx, y, 0, 0): + cond.append(eq(idx.shape[0], y.shape[0])) + r = Assert( + "Bad indexing or shapes in a AdvancedIncSubtensor1 " "that was optimized away" + )(y, *cond) + copy_stack_trace(y, r) + + if r.dtype == node.outputs[0].dtype: + return [r] + # It is possible that y is upcast or downcast to x.dtype. + # In all case, as we set or add with 0, we can just cast y. + r2 = cast(r, node.outputs[0].dtype) + + # Copy over stacktrace from before casting, since + # we don't expect problems in the casting operation, + # and any problems in the indexing would have been spotted above. + copy_stack_trace(r, r2) + return [r2] + + +@register_specialize +@register_stabilize +@register_canonicalize +@register_useless +@node_rewriter([IncSubtensor, AdvancedIncSubtensor, AdvancedIncSubtensor1]) +def local_useless_inc_subtensor_alloc(fgraph, node): + """ + Replaces an [Advanced]IncSubtensor[1], whose increment is an `alloc` of + a fully or partially broadcastable variable, by one that skips the + intermediate `alloc` where possible. + + """ + if isinstance(node.op, (IncSubtensor, AdvancedIncSubtensor, AdvancedIncSubtensor1)): + x = node.inputs[0] + y = node.inputs[1] + i = node.inputs[2:] + + if y.owner is not None and isinstance(y.owner.op, Alloc): + # `z` is the input of the Alloc op, i.e. at.alloc(z, ) + z = y.owner.inputs[0] + + try: + shape_feature = fgraph.shape_feature + except AttributeError: + # The shape feature may not be available in some mode, but we + # need it for this optimization, so don't continue. + return False + + shape_of = shape_feature.shape_of + same_shape = shape_feature.same_shape + + # Get the subtensor of `x` indexed by `i` in order to compare + # shapes later. + if isinstance(node.op, IncSubtensor): + xi = Subtensor(node.op.idx_list)(x, *i) + elif isinstance(node.op, AdvancedIncSubtensor): + xi = advanced_subtensor(x, *i) + elif isinstance(node.op, AdvancedIncSubtensor1): + xi = advanced_subtensor1(x, *i) + else: + raise Exception("Should never happen!") + + reason = "local_useless_incsubtensor_alloc" + + # Add `xi` to the shape feature `fgraph`. This is important for + # shape inference later because the variable must be part of the + # function graph in order to call `same_shape` on it. + if xi not in shape_of: + shape_feature.on_import(fgraph, xi.owner, f"{reason}: add `xi`") + + # `xi` may have more dimensions than `y` since the subtensor ops + # do automatic broadcasting of the increment internally. Thus, we + # need to make the leading implicitly broadcasted dimensions + # explicit for shape comparison later. + if xi.ndim > y.ndim: + y = shape_padleft(y, xi.ndim - y.ndim) + if y not in shape_of: + shape_feature.on_import(fgraph, y.owner, f"{reason}: add `y`") + + # Build `z_broad` explicitly to include extra implicit dimensions. + z_broad = (True,) * (xi.ndim - z.ndim) + z.broadcastable + + cond = [ + # The shapes of `y` and `xi` must either agree or `y` may + # also have shape equal to 1 which may be treated as a + # broadcastable dimension by the subtensor op. + or_(eq(y.shape[k], 1), eq(y.shape[k], xi.shape[k])) + # Loop over all dimensions. + for k in range(xi.ndim) + # We need to check the above shapes, if + # * the pre-alloc increment `z` is broadcastable in + # dimension `k` (if it isn't, then the shapes of `z` and + # `y` are the same by the definition of the `Alloc` op in + # this dimension and replacing `y` by `z` will not hide a + # shape error), and + # * `xi` and `y` do not have the same shape in dimension + # `k` or we cannot infer the shape statically (if the + # shapes of `xi` and `y` are not the same, then replacing + # `y` by `z` will hide the shape error of `y`), and + # * the shape of `y` is not equal to 1 or we cannot infer + # the shape statically (if the shape of `y` is equal to + # 1, then `y` is broadcasted by the inc_subtensor op + # internally, so the shapes of `xi` and `y` do not need + # to match in dimension `k`; else we need to check at + # runtime that the shape of `y` is either 1 or the same + # as `xi` or otherwise replacing `y` by `z` will hide a + # shape error). + if ( + z_broad[k] + and not same_shape(xi, y, dim_x=k, dim_y=k) + and shape_of[y][k] != 1 + ) + ] + + if len(cond) > 0: + msg = "`x[i]` and `y` do not have the same shape." + z = Assert(msg)(z, *cond) + + r = node.op(x, z, *i) + # Copy over stacktrace from previous output, since + # we don't expect problems when removing the intermediate + # alloc operation and so we still want to point at the line + # of the inc_subtensor operation. + copy_stack_trace(node.outputs, r) + + return [r] + + +@register_specialize +@register_canonicalize +@node_rewriter([Subtensor]) +def local_subtensor_shape_constant(fgraph, node): + r"""Simplify constant `Subtensor`\s on `Shape`\s dimensions that are known. + + We want to convert graphs like + + Subtensor{int64} [id A] '' + |Shape [id B] '' + | | [id C] + |ScalarConstant{0} [id D] + + into + + TensorConstant{1} + + TODO: Something like `local_shape_to_shape_i` should be a general + canonicalization, and not a `ShapeFeature`-dependent rewrite. If that were + the case, we could change this to only operate on `Shape_i`\s. + Currently, we're not handling them because they should only appear when + `ShapeFeature` is present, and it will also simplify/remove them. + + """ + if not isinstance(node.op, Subtensor): + return False + + shape = node.inputs[0] + + if not (shape.owner and isinstance(shape.owner.op, Shape)): + return False + + shape_arg = shape.owner.inputs[0] + + (idx,) = get_idx_list(node.inputs, node.op.idx_list) + + try: + idx_val = as_index_literal(idx) + except NotScalarConstantError: + return False + + assert idx_val != np.newaxis + + if not isinstance(shape_arg.type, TensorType): + return False + + shape_parts = shape_arg.type.broadcastable[idx_val] + + if isinstance(shape_parts, Iterable): + if all(shape_parts): + return [as_tensor([1] * len(shape_parts), dtype=np.int64, ndim=1)] + elif shape_parts: + return [as_tensor(1, dtype=np.int64)] + + +@register_canonicalize +@node_rewriter([Subtensor]) +def local_subtensor_SpecifyShape_lift(fgraph, node): + """Lift ``specify_shape(x, s)[i_1, ..., i_n]`` to ``specify_shape(x[i1, ... , i_n], s[n:])``.""" + + if not isinstance(node.op, Subtensor): + return False + + specify_shape_node = node.inputs[0] + + if not ( + specify_shape_node.owner + and isinstance(specify_shape_node.owner.op, SpecifyShape) + ): + return False + + obj_arg = specify_shape_node.owner.inputs[0] + shape_arg = specify_shape_node.owner.inputs[1:] + + indices = get_idx_list(node.inputs, node.op.idx_list) + + if any( + isinstance(index, slice) or isinstance(getattr(index, "type", None), SliceType) + for index in indices + ): + return False + + new_obj_arg = obj_arg[indices] + # No need to specify shape for scalar outputs + if new_obj_arg.ndim == 0: + return [new_obj_arg] + return [specify_shape(new_obj_arg, shape_arg[len(indices) :])] + + +@register_specialize +@node_rewriter([Join]) +def local_join_subtensors(fgraph, node): + r"""Simplify contiguous :class:`Subtensor`\s inside a :class:`Join`. + + `join((x[:3], x[3:5]), axis=0) -> x[:5]` + """ + # TODO: Generalize to AdvancedSubtensors + + axis, tensors = node.inputs[0], node.inputs[1:] + + try: + axis = get_scalar_constant_value(axis) + except NotScalarConstantError: + return + + for subtensor1_idx, (subtensor1, subtensor2) in enumerate( + zip(tensors[:-1], tensors[1:]) + ): + # Check that two consecutive Subtensors are operating on the same base tensor + if not ( + ( + subtensor1.owner is not None + and isinstance(subtensor1.owner.op, Subtensor) + ) + and ( + subtensor2.owner is not None + and isinstance(subtensor2.owner.op, Subtensor) + ) + and (subtensor1.owner.inputs[0] is subtensor2.owner.inputs[0]) + ): + continue + + # Check that subtensors have consecutive indexes across the join axis + idxs_subtensor1 = indices_from_subtensor( + subtensor1.owner.inputs[1:], subtensor1.owner.op.idx_list + ) + idxs_subtensor2 = indices_from_subtensor( + subtensor2.owner.inputs[1:], subtensor2.owner.op.idx_list + ) + try: + idxs_axis_subtensor1 = idxs_subtensor1[axis] + idxs_axis_subtensor2 = idxs_subtensor2[axis] + except IndexError: + continue + if not ( + isinstance(idxs_axis_subtensor1, slice) + and isinstance(idxs_axis_subtensor2, slice) + ): + continue + start_subtensor1, stop_subtensor1, step_subtensor1 = ( + idxs_axis_subtensor1.start, + idxs_axis_subtensor1.stop, + idxs_axis_subtensor1.step, + ) + start_subtensor2, stop_subtensor2, step_subtensor2 = ( + idxs_axis_subtensor2.start, + idxs_axis_subtensor2.stop, + idxs_axis_subtensor2.step, + ) + if not ( + (stop_subtensor1 is not None and start_subtensor2 is not None) + and (stop_subtensor1 == start_subtensor2) + ): + continue + + # Check that step is None or 1 + # For non-unit steps (perhaps except for -1) we would need to know the + # exact values of start and stop to know if they can be merged + for step in (step_subtensor1, step_subtensor2): + if step is None: + continue + try: + if get_scalar_constant_value(step, only_process_constants=True) != 1: + return None + except NotScalarConstantError: + return None + + # Check that all other idxs of subtensor are the same + if all( + idxs_nonaxis_subtensor1 == idxs_nonaxis_subtensor2 + for i, (idxs_nonaxis_subtensor1, idxs_nonaxis_subtensor2) in enumerate( + zip(idxs_subtensor1, idxs_subtensor2) + ) + if i != axis + ): + + base_tensor = subtensor1.owner.inputs[0] + new_idxs = list(idxs_subtensor1) + new_idxs[axis] = slice(start_subtensor1, stop_subtensor2, step_subtensor1) + merged_subtensors = base_tensor[new_idxs] + + new_joined_tensors = [ + *tensors[:subtensor1_idx], + merged_subtensors, + *tensors[subtensor1_idx + 2 :], + ] + if len(new_joined_tensors) > 1: + return [concatenate(new_joined_tensors, axis=axis)] + else: + return [merged_subtensors] + + +@register_specialize +@node_rewriter( + [ + Subtensor, + AdvancedSubtensor1, + AdvancedSubtensor, + IncSubtensor, + AdvancedIncSubtensor, + AdvancedIncSubtensor1, + ] +) +def local_uint_constant_indices(fgraph, node): + """Convert constant indices to unsigned dtypes.""" + + if isinstance(node.op, (IncSubtensor, AdvancedIncSubtensor, AdvancedIncSubtensor1)): + x, y, *indices = node.inputs + else: + x, *indices = node.inputs + y = None + + idx_list = getattr(node.op, "idx_list", None) + new_indices = list(indices_from_subtensor(indices, idx_list)) + has_new_index = False + + for i, index in enumerate(new_indices): + + if not isinstance(index, Constant): + continue + + index_val = index.data + + if index_val is None or isinstance(index_val, slice): + # TODO: If slice index dtypes matter, we can consider converting + # those, as well. + continue + + assert isinstance(index_val, (np.generic, np.ndarray)) + + if index_val.size == 0: + continue + + if index_val.dtype == bool: + continue + + if np.ndim(index_val) > 0: + minval = index_val.min() + else: + minval = index_val + + if minval >= 0: + maxval = index_val.max() + dtype = np.min_scalar_type(maxval) + else: + # If we can't convert to unsigned, then don't attempt to minimize + # the type size either--at least not for now. + # dtype = np.min_scalar_type(-max(-minval, maxval)) + continue + + if dtype == index_val.dtype: + continue + + if index_val.ndim > 0: + new_index = aesara.tensor.as_tensor_variable( + index_val.astype(dtype), dtype=dtype + ) + else: + new_index = aes.constant(index_val.astype(dtype), dtype=dtype) + + new_indices[i] = new_index + has_new_index = True + + if not has_new_index: + return False + + new_out = x[tuple(new_indices)] + + if y is not None: + new_out = inc_subtensor( + new_out, + y, + inplace=node.op.inplace, + set_instead_of_inc=node.op.set_instead_of_inc, + ignore_duplicates=getattr(node.op, "ignore_duplicates", False), + ) + + new_outs = new_out.owner.outputs + copy_stack_trace(node.outputs, new_outs) + + return new_outs diff --git a/aesara/tensor/rewriting/uncanonicalize.py b/aesara/tensor/rewriting/uncanonicalize.py new file mode 100644 index 0000000000..6ef4f29bb6 --- /dev/null +++ b/aesara/tensor/rewriting/uncanonicalize.py @@ -0,0 +1,251 @@ +""" +This file implement specialization optimization that break the +canonization form of the graph. + +Currently there is problem with the order of optimization and the +definition of definition of canonized graph. + +Right now there is a canonization optimization phase that try to make +all equivalent graph identical. This is not always the case, but it do +many of the basic stuff canonical. We need to extend the definition of +canonization to make this true more often. + +The problem this file indent to fix in the future is that in the +"Equilibrium" specialization optimization phase, there is optimization +that request that the graph is canonical, some other request that this +is not true, and some other that break the canonicalization for some +optimization. As we can't control the order of those optimization, there +is case that some optimization requesting a canonical graph won't be +applied as optimization that break the canonicalization form of the +graph executed before. + +To fix this, we need to split the specialization phase into a phase +where optimization can't break the canonicalization form and one where +this is allowed. This is also needed for the stabilized optimization +phase, but as it happen before the specialization phase, this cause less +problem. + +Also, we should make the fgraph refuse optimization that break the +canonization of the graph in the optimizations phases where the graph is +supposed to be canonical. + +""" + +from aesara import scalar as aes +from aesara.graph.rewriting.basic import copy_stack_trace, node_rewriter +from aesara.tensor.basic import Alloc, alloc, constant +from aesara.tensor.elemwise import CAReduce, DimShuffle +from aesara.tensor.math import Argmax, Max, MaxAndArgmax, Min, neg +from aesara.tensor.rewriting.basic import register_uncanonicalize +from aesara.tensor.shape import Reshape, reshape +from aesara.tensor.subtensor import Subtensor + + +@register_uncanonicalize +@node_rewriter([MaxAndArgmax]) +def local_max_and_argmax(fgraph, node): + """ + If we don't use the argmax, change it to a max only. + """ + if isinstance(node.op, MaxAndArgmax): + axis = node.op.get_params(node) + if len(fgraph.clients[node.outputs[1]]) == 0: + new = Max(axis)(node.inputs[0]) + copy_stack_trace(node.outputs[0], new) + return [new, None] + + if len(fgraph.clients[node.outputs[0]]) == 0: + new = Argmax(axis)(node.inputs[0]) + copy_stack_trace(node.outputs[0], new) + return [None, new] + + +@register_uncanonicalize +@node_rewriter([neg]) +def local_max_to_min(fgraph, node): + """ + Change -(max(-x)) to min. + + This is tested in tensor/tests/test_basic.py:test_min_max. + + Notes + ----- + We don't need an opt that will do the reverse as by default + the interface put only MaxAndArgmax into the graph. + + """ + if node.op == neg and node.inputs[0].owner: + max = node.inputs[0] + if ( + max.owner + and isinstance(max.owner.op, CAReduce) + and max.owner.op.scalar_op == aes.scalar_maximum + ): + neg_node = max.owner.inputs[0] + if neg_node.owner and neg_node.owner.op == neg: + new = Min(max.owner.op.axis)(neg_node.owner.inputs[0]) + return [copy_stack_trace(node.outputs[0], new)] + + return False + + +@register_uncanonicalize +@node_rewriter([Alloc]) +def local_alloc_dimshuffle(fgraph, node): + """ + If a dimshuffle is inside an alloc and only adds dimension to the + left, remove it. + + Alloc(DimShuffle(x), ...) - > Alloc(x, ...) + """ + if isinstance(node.op, Alloc): + input_ = node.inputs[0] + if input_.owner and isinstance(input_.owner.op, DimShuffle): + # check if it only adds dimension to the left + new_order = input_.owner.op.new_order + expected_new_order = ("x",) * ( + input_.ndim - input_.owner.inputs[0].ndim + ) + tuple(range(input_.owner.inputs[0].ndim)) + if new_order != expected_new_order: + return False + return [alloc(input_.owner.inputs[0], *node.inputs[1:])] + return False + + +@register_uncanonicalize +@node_rewriter([Reshape]) +def local_reshape_dimshuffle(fgraph, node): + """ + If a dimshuffle is inside a reshape and does not change the order + of dimensions, remove it. + + Reshape(Dimshuffle(x), shp) -> Reshape(x, shp) + """ + if isinstance(node.op, Reshape): + input_ = node.inputs[0] + if input_.owner and isinstance(input_.owner.op, DimShuffle): + new_order = input_.owner.op.new_order + offset = 0 + for dim in new_order: + if dim == "x": + continue + elif dim != offset: + return False + else: + offset += 1 + return [ + reshape( + input_.owner.inputs[0], node.inputs[1], ndim=node.outputs[0].ndim + ) + ] + return False + + +@register_uncanonicalize +@node_rewriter([DimShuffle]) +def local_dimshuffle_alloc(fgraph, node): + """ + If an alloc is inside a dimshuffle which only adds dimension to the left, + scrap the dimshuffle and adds 1 into the alloc + + dimshuffle{x, 0, 1}(alloc([3 4], 3, 2) => alloc([3 4], 1, 3, 2) + """ + if isinstance(node.op, DimShuffle) and node.inputs[0].owner: + input_ = node.inputs[0] + if isinstance(input_.owner.op, Alloc): + # check if it only adds dimension to the left + new_order = node.op.new_order + expected_new_order = ("x",) * (len(new_order) - input_.ndim) + tuple( + range(input_.ndim) + ) + if new_order != expected_new_order: + return False + + # count numbers of 'x' + nb_new_dims = len(new_order) - input_.ndim + new_shape_input = (1,) * nb_new_dims + tuple(input_.owner.inputs[1:]) + + return [alloc(input_.owner.inputs[0], *new_shape_input)] + return False + + +@register_uncanonicalize +@node_rewriter([DimShuffle]) +def local_dimshuffle_subtensor(fgraph, node): + """If a subtensor is inside a dimshuffle which only drop + broadcastable dimensions, scrap the dimshuffle and index the + subtensor with 0 + + x[i:j, :, k:l].dimshuffle(0, 2) => + x[i:j, 0, k:l] if x.broadcastable == (False, True, False) + + """ + if isinstance(node.op, DimShuffle) and node.inputs[0].owner: + # the dimshuffle can only drop dimensions (cannot reshape nor add 'x') + if "x" in node.op.new_order: + return False + new_order = node.op.new_order + # new order could be empty + # Verif that we don't change dimensions order. + if len(new_order) > 1: + past_dim = new_order[0] + for dim in new_order[1:]: + if not dim > past_dim: + return False + else: + past_dim = dim + + input_ = node.inputs[0] + if isinstance(input_.owner.op, Subtensor): + # the arguments missing from the dimshuffles must be dims + # that are broadcastable + broadcastable = input_.broadcastable + + missing_dims = list(range(input_.ndim)) + for dim in new_order: + missing_dims.remove(dim) + + if not all(broadcastable[i] for i in missing_dims): + return False + + # create a new idx_list for a new Subtensor object + # have to loop on idx_list and inputs + # inputs has the length of sum of non None elements of idx_list + # (check in slice!). + # len(missing_dims) can be < len(idx_list), this happens if + # tensor was indexed such as x[scalar, :, :], check that as well + new_idx_list = list(input_.owner.op.idx_list) + new_inputs = [input_.owner.inputs[0]] + zero = constant(0) + slice_attr_list = ["start", "stop", "step"] + j = 0 + slice_i = -1 + subtensor_removed_dims = 0 + for i, idx in enumerate(input_.owner.op.idx_list): + if isinstance(idx, slice): + past_j = j + slice_i += 1 + for slice_attr in slice_attr_list: + if getattr(idx, slice_attr) is not None: + new_inputs += [input_.owner.inputs[1 + j]] + j += 1 + # if past_j == j indicates a slice(None, None, None), + # that's where we want to index with 0 if it is also at + # the same spot of a missing dim + if past_j == j and slice_i in missing_dims: + new_idx_list[i] = zero + new_inputs += [zero] + else: + new_inputs += [input_.owner.inputs[1 + j]] + j += 1 + subtensor_removed_dims += 1 + # Verify the trailing dimensions the subtensor didn't look at. + for idx in range(len(input_.owner.op.idx_list), new_inputs[0].ndim): + if (idx - subtensor_removed_dims) in missing_dims: + while len(new_idx_list) < idx: + new_idx_list.append(slice(None)) + + new_idx_list.append(zero) + new_inputs.append(zero) + return [Subtensor(new_idx_list)(*new_inputs)] + return False diff --git a/aesara/tensor/slinalg.py b/aesara/tensor/slinalg.py index 0adba5cbef..8bcd353a72 100644 --- a/aesara/tensor/slinalg.py +++ b/aesara/tensor/slinalg.py @@ -101,9 +101,8 @@ def tril_and_halve_diagonal(mtx): def conjugate_solve_triangular(outer, inner): """Computes L^{-T} P L^{-1} for lower-triangular L.""" - return solve_upper_triangular( - outer.T, solve_upper_triangular(outer.T, inner.T).T - ) + solve_upper = SolveTriangular(lower=False) + return solve_upper(outer.T, solve_upper(outer.T, inner.T).T) s = conjugate_solve_triangular( chol_x, tril_and_halve_diagonal(chol_x.T.dot(dz)) @@ -507,15 +506,6 @@ def solve(a, b, assume_a="gen", lower=False, check_finite=True): )(a, b) -# TODO: These are deprecated; emit a warning -solve_lower_triangular = SolveTriangular(lower=True) -solve_upper_triangular = SolveTriangular(lower=False) -solve_symmetric = Solve(assume_a="sym") - -# TODO: Optimizations to replace multiplication by matrix inverse -# with solve() Op (still unwritten) - - class Eigvalsh(Op): """ Generalized eigenvalues of a Hermitian positive definite eigensystem. @@ -748,10 +738,45 @@ def perform(self, node, inputs, outputs): __all__ = [ "cholesky", "solve", - "solve_lower_triangular", - "solve_upper_triangular", - "solve_symmetric", "eigvalsh", "kron", "expm", ] + +DEPRECATED_NAMES = [ + ( + "solve_lower_triangular", + "`solve_lower_triangular` is deprecated; use `solve` instead.", + SolveTriangular(lower=True), + ), + ( + "solve_upper_triangular", + "`solve_upper_triangular` is deprecated; use `solve` instead.", + SolveTriangular(lower=False), + ), + ( + "solve_symmetric", + "`solve_symmetric` is deprecated; use `solve` instead.", + Solve(assume_a="sym"), + ), +] + + +def __getattr__(name): + """Intercept module-level attribute access of deprecated symbols. + + Adapted from https://stackoverflow.com/a/55139609/3006474. + + """ + from warnings import warn + + for old_name, msg, old_object in DEPRECATED_NAMES: + if name == old_name: + warn(msg, DeprecationWarning, stacklevel=2) + return old_object + + raise AttributeError(f"module {__name__} has no attribute {name}") + + +def __dir__(): + return sorted(__all__ + [names[0] for names in DEPRECATED_NAMES]) diff --git a/aesara/tensor/subtensor.py b/aesara/tensor/subtensor.py index 340e85686f..f9abfaf784 100644 --- a/aesara/tensor/subtensor.py +++ b/aesara/tensor/subtensor.py @@ -41,6 +41,10 @@ iscalar, lscalar, tensor, + ubscalar, + uiscalar, + ulscalar, + uwscalar, wscalar, zscalar, ) @@ -50,12 +54,25 @@ _logger = logging.getLogger("aesara.tensor.subtensor") invalid_scal_types = (aes.float64, aes.float32, aes.float16) -scal_types = (aes.int64, aes.int32, aes.int16, aes.int8) +scal_types = ( + aes.int64, + aes.int32, + aes.int16, + aes.int8, + aes.uint64, + aes.uint32, + aes.uint16, + aes.uint8, +) tensor_types = ( lscalar, iscalar, wscalar, bscalar, + ulscalar, + uiscalar, + uwscalar, + ubscalar, ) invalid_tensor_types = ( fscalar, @@ -376,7 +393,7 @@ def slice_len(slc, n): def is_basic_idx(idx): """Determine if an index is of the NumPy basic type. - XXX: This only checks a single index, so an integers is *not* considered a + XXX: This only checks a single index, so an integer is *not* considered a basic index, because--depending on the other indices its used with--an integer can indicate advanced indexing. diff --git a/aesara/tensor/subtensor_opt.py b/aesara/tensor/subtensor_opt.py index bdf777ecfd..3047683330 100644 --- a/aesara/tensor/subtensor_opt.py +++ b/aesara/tensor/subtensor_opt.py @@ -1,1753 +1,10 @@ -import sys -from collections.abc import Iterable +import warnings -import numpy as np -import aesara -import aesara.scalar.basic as aes -from aesara import compile -from aesara.graph.basic import Constant, Variable -from aesara.graph.opt import TopoOptimizer, copy_stack_trace, in2out, local_optimizer -from aesara.raise_op import Assert -from aesara.tensor.basic import ( - Alloc, - ARange, - Join, - MakeVector, - ScalarFromTensor, - TensorFromScalar, - alloc, - as_tensor, - cast, - concatenate, - extract_constant, - get_scalar_constant_value, - switch, +warnings.warn( + "The module `aesara.tensor.subtensor_opt` is deprecated; use `aesara.tensor.rewriting.subtensor` instead.", + DeprecationWarning, + stacklevel=2, ) -from aesara.tensor.basic_opt import ( - register_canonicalize, - register_specialize, - register_stabilize, -) -from aesara.tensor.elemwise import Elemwise -from aesara.tensor.exceptions import NotScalarConstantError -from aesara.tensor.math import Dot, add -from aesara.tensor.math import all as at_all -from aesara.tensor.math import ( - and_, - ceil_intdiv, - dot, - eq, - ge, - gt, - le, - lt, - maximum, - minimum, - or_, -) -from aesara.tensor.shape import ( - Shape, - SpecifyShape, - Unbroadcast, - shape_padleft, - shape_tuple, - specify_shape, - unbroadcast, -) -from aesara.tensor.sharedvar import TensorSharedVariable -from aesara.tensor.subtensor import ( - AdvancedIncSubtensor, - AdvancedIncSubtensor1, - AdvancedSubtensor, - AdvancedSubtensor1, - IncSubtensor, - Subtensor, - advanced_inc_subtensor1, - advanced_subtensor, - advanced_subtensor1, - as_index_constant, - as_index_literal, - get_canonical_form_slice, - get_constant_idx, - get_idx_list, - get_slice_elements, - inc_subtensor, - indices_from_subtensor, -) -from aesara.tensor.type import TensorType -from aesara.tensor.type_other import NoneTypeT, SliceConstant, SliceType -from aesara.tensor.var import TensorConstant, TensorVariable - - -def register_useless(lopt, *tags, **kwargs): - if isinstance(lopt, str): - - def register(inner_lopt): - return register_useless(inner_lopt, lopt, *tags, **kwargs) - - return register - else: - name = kwargs.pop("name", None) or lopt.__name__ - - compile.mode.local_useless.register( - name, lopt, "fast_run", *tags, position="last", **kwargs - ) - return lopt - - -def transform_take(a, indices, axis): - r"""Transform ``arr[:,:,:,indices,...]``-like operations into single-dimensional, vector index operations. - - This effectively converts certain `AdvancedSubtensor` `Op`\s into a - combination of `AdvancedSubtensor1`, `Dimshuffle`, and `Reshape` `Op`\s, - which can be more efficient. - - Parameters - ---------- - a : TensorVariable - The source array. - indices : TensorVariable, ndarray, list, tuple - The indices of the values to extract. - axis : int - The axis over which to select values. By default, the flattened - input array is used. - - """ - a = aesara.tensor.as_tensor_variable(a) - indices = aesara.tensor.as_tensor_variable(indices) - # We can use the more efficient `AdvancedSubtensor1` if `indices` is a vector - if indices.ndim == 1: - if axis == 0: - return advanced_subtensor1(a, indices) - else: - shuffle = list(range(a.ndim)) - shuffle[0] = axis - shuffle[axis] = 0 - res = advanced_subtensor1(a.dimshuffle(shuffle), indices).dimshuffle( - shuffle - ) - return res - - # We can reshape and flatten the indices in order to use an - # `AdvancedSubtensor1` `Op` per the above - indices_shape = shape_tuple(indices) - a_shape = shape_tuple(a) - - shape_parts = [ - a_shape[:axis], - indices_shape, - a_shape[axis + 1 :], - ] - - shape_parts = [sp for sp in shape_parts if len(sp) > 0] - - assert len(shape_parts) > 0 - - if len(shape_parts) > 1: - shape = aesara.tensor.concatenate(shape_parts) - else: - shape = shape_parts[0] - - ndim = a.ndim + indices.ndim - 1 - - return transform_take(a, indices.flatten(), axis).reshape(shape, ndim) - - -def is_full_slice(x): - """Determine if `x` is a ``slice(None)`` or a symbolic equivalent.""" - if ( - (isinstance(x, slice) and x == slice(None)) - or (isinstance(x, SliceConstant) and x.value == slice(None)) - or ( - not isinstance(x, SliceConstant) - and isinstance(getattr(x, "type", None), SliceType) - and x.owner is not None - and all( - isinstance(getattr(i, "type", None), NoneTypeT) for i in x.owner.inputs - ) - ) - ): - return True - return False - - -def get_advsubtensor_axis(indices): - """Determine the axis at which an array index is applied. - - This only works for ``take``-like indices: e.g. ``x[:, :, idx, ...]``. For - the above example, `get_advsubtensor_axis` would return ``2``. If it - encounters anything other than a set of `indices` containing full slices - and an array/tensor index, it will return ``None``. - - """ - found_idx = False - axis = 0 - for idx in indices: - if not found_idx and is_full_slice(idx): - # Preceding full slices - axis += 1 - elif found_idx and not is_full_slice(idx): - # We don't handle multiple indices - return - elif found_idx and is_full_slice(idx): - # Trailing full slices - continue - else: - found_idx = True - - if isinstance( - indices[axis], (TensorConstant, TensorVariable, TensorSharedVariable) - ): - return axis - - -@register_specialize -@local_optimizer([AdvancedSubtensor]) -def local_replace_AdvancedSubtensor(fgraph, node): - r""" - This rewrite converts expressions like ``X[..., y]`` into ``X.T[y].T``, for - a vector ``y``, and ``X[z, ...]`` into ``X[z.flatten()].reshape(...)``, for a - matrix ``z``. - - These rewrites replace `AdvancedSubtensor`\s with the more efficient - `AdvancedSubtensor1` and `Subtensor` `Op`\s. - """ - - if not isinstance(node.op, AdvancedSubtensor): - return - - indexed_var = node.inputs[0] - indices = node.inputs[1:] - - axis = get_advsubtensor_axis(indices) - - if axis is None or indices[axis].dtype == "bool": - # Booleans aren't handled - return - - new_res = transform_take(indexed_var, indices[axis], axis) - copy_stack_trace(node.outputs[0], new_res) - return [new_res] - - -@register_specialize -@local_optimizer([AdvancedIncSubtensor]) -def local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1(fgraph, node): - r"""Replace `AdvancedIncSubtensor`\s with `AdvancedIncSubtensor1`\s. - - This is only done when there's a single vector index. - """ - - if not isinstance(node.op, AdvancedIncSubtensor) or node.op.ignore_duplicates: - # `AdvancedIncSubtensor1` does not ignore duplicate index values - return - - res = node.inputs[0] - val = node.inputs[1] - indices = node.inputs[2:] - - axis = get_advsubtensor_axis(indices) - - if axis is None or indices[axis].dtype == "bool": - # Booleans aren't currently handled by `AdvancedIncSubtensor1` - return - - new_subtensor = transform_take(res, indices[axis], axis) - - new_res = inc_subtensor( - new_subtensor, - val, - inplace=node.op.inplace, - set_instead_of_inc=node.op.set_instead_of_inc, - ignore_duplicates=False, - ) - copy_stack_trace(node.outputs[0], new_res) - return [new_res] - - -@register_canonicalize -@register_stabilize -@register_specialize -@local_optimizer([Subtensor]) -def local_subtensor_of_dot(fgraph, node): - """Rewrite ``at.dot(A, B)[idxs]`` into ``at.dot(A[idxs_a], B[idxs_b])``. - ``idxs_a`` is the first ``A.ndim-1`` entries of ``idxs``, and ``idxs_b`` is - the remaining entries of ``idxs`` (if any), modified to skip the - second-to-last dimension of ``B`` (because dot sums over this dimension). - """ - if not isinstance(node.op, Subtensor): - return - if not node.inputs[0].owner or not isinstance(node.inputs[0].owner.op, Dot): - return - # If there is other node that use the outputs of the dot - # We don't want to compute twice the sub part. - if len(fgraph.clients[node.inputs[0]]) > 1: - return - - a = node.inputs[0].owner.inputs[0] - b = node.inputs[0].owner.inputs[1] - - idx_list = get_idx_list(node.inputs, node.op.idx_list) - - num_a_indices = min(a.ndim - 1, len(idx_list)) - a_indices = idx_list[:num_a_indices] - b_indices = idx_list[num_a_indices:] - - # This is necessary because np.dot sums the last index of a with the second to last of b - # so we want to skip the second-to-last index into b. - # This wasn't necessary for a, because we just omitted the last index. - # We skip this if b.ndim = 1, since then we just want b_sub = b, not b_sub = b[:] - # (dot also handles b.ndim < 2 as a special case) - if b.ndim > 1 and len(b_indices) >= b.ndim - 1: - b_indices = ( - b_indices[: b.ndim - 2] - + (slice(None, None, None),) - + b_indices[b.ndim - 2 :] - ) - - a_sub = a.__getitem__(tuple(a_indices)) - b_sub = b.__getitem__(tuple(b_indices)) if b_indices else b - - # Copy over previous output stacktrace to a_sub and b_sub, - # because an error in the subtensor operation (e.g. an index error) - # on either a or b must correspond to an error in the - # subtensor operation on their dot product. - copy_stack_trace(node.outputs[0], [a_sub, b_sub]) - - # Copy over previous output stacktrace and previous dot product stacktrace, - # because an error here may correspond to an either in either the original - # dot product, or in the dot product after the subtensor operation. - r = dot(a_sub, b_sub) - copy_stack_trace([node.outputs[0], node.inputs[0]], r) - - return [r] - - -@register_useless -@register_canonicalize -@register_specialize -@local_optimizer([Subtensor]) -def local_useless_slice(fgraph, node): - """ - Remove Subtensor of the form X[0, :] -> X[0] - """ - if isinstance(node.op, Subtensor): - slices = get_idx_list(node.inputs, node.op.idx_list) - last_slice = len(slices) - for s in slices[::-1]: - # check if slice and then check slice indices - if ( - isinstance(s, slice) - and s.start is None - and s.stop is None - and ( - s.step is None - or extract_constant(s.step, only_process_constants=True) == 1 - ) - ): - last_slice -= 1 - else: - break - # check if we removed something - if last_slice < len(slices): - subtens = Subtensor(slices[:last_slice]) - sl_ins = get_slice_elements( - slices[:last_slice], lambda x: isinstance(x, Variable) - ) - out = subtens(node.inputs[0], *sl_ins) - # Copy over previous output stacktrace - copy_stack_trace(node.outputs, out) - return [out] - - -# fast_compile to allow opt subtensor(cast{float32}(make_vector)) -@register_canonicalize("fast_compile") -@local_optimizer([Subtensor]) -def local_subtensor_lift(fgraph, node): - """ - unary(x)[idx] -> unary(x[idx])#any broadcast pattern. - - Handles the following unary ops: - elemwise(x,...)[idx] -> elemwise(x[idx],...) - when x,... are broadcasted scalar or not broadcasted at all - Unbroadcast(x)[idx] => Unbroadcast(x[idx]) - - """ - if isinstance(node.op, Subtensor): - u = node.inputs[0] - if not u.owner or len(fgraph.clients[u]) > 1: - return False - - if isinstance(u.owner.op, Elemwise) and len(u.owner.inputs) == 1: - idx = node.inputs[1:] - x_idx = node.op(u.owner.inputs[0], *idx) - # Copy over previous output stacktrace - copy_stack_trace(node.outputs, x_idx) - ret = u.owner.op(x_idx) - # Copy over previous output stacktrace - # and stacktrace from previous unary operation - copy_stack_trace([node.outputs[0], node.inputs[0]], ret) - return [ret] - - if isinstance(u.owner.op, Elemwise): - new_inputs = [] - if all(sum(i.type.broadcastable) == 0 for i in u.owner.inputs): - # There is no broadcastable in the inputs - idx = node.inputs[1:] - new_inputs = [node.op(i, *idx) for i in u.owner.inputs] - # Copy over previous output stacktrace - copy_stack_trace(node.outputs[0], new_inputs) - - ret = u.owner.op(*new_inputs) - # Copy over previous output stacktrace - # and stacktrace from previous unary operation - copy_stack_trace([node.outputs[0], node.inputs[0]], ret) - return [ret] - elif all(sum(i.type.broadcastable) in [i.ndim, 0] for i in u.owner.inputs): - # There is no broadcastable in the inputs or it is scalar - idx = node.inputs[1:] - new_inputs = [] - for i in u.owner.inputs: - if sum(i.type.broadcastable) == 0: - new_inputs.append(node.op(i, *idx)) - else: - # If the subtensor remove some dims, we must - # lower the number of dimensions of this scalar. - if node.outputs[0].ndim == i.ndim: - new_inputs.append(i) - else: - new_inputs.append( - i.dimshuffle(["x"] * node.outputs[0].ndim) - ) - - # Copy over previous output stacktrace - copy_stack_trace(node.outputs[0], new_inputs) - - ret = u.owner.op(*new_inputs) - # Copy over previous output stacktrace - # and stacktrace from previous unary operation - copy_stack_trace([node.outputs[0], node.inputs[0]], ret) - return [ret] - - if isinstance(u.owner.op, Unbroadcast): - # Subtensor might reduce dim., adapt broadcast pattern accordingly - old_axes = u.owner.op.axes - new_axes = [] - - # loop through indices being subtensor-ed - # i indexes broadcastable pattern before subtensor - # j indexes broadcastable pattern after subtensor - j = 0 - for (i, x) in enumerate(node.op.idx_list): - # if it is not a slice, it will reduce the dimension, should - # not appear in the broascastable dimensions - if isinstance(x, slice): - if i in old_axes: - new_axes.append(j) - j += 1 - # now keep the broadcastable pattern of all - # items not appearing in subtensor list - for i in range(len(node.op.idx_list), len(u.broadcastable)): - if i in old_axes: - new_axes.append(j) - j += 1 - - subt_x = node.op(u.owner.inputs[0], *node.inputs[1:]) - # Copy over previous output stacktrace - copy_stack_trace(node.outputs[0], subt_x) - - rbcast_subt_x = unbroadcast(subt_x, *new_axes) - # Copy over previous output stacktrace - # and stacktrace from previous unary operation - copy_stack_trace([node.outputs[0], node.inputs[0]], rbcast_subt_x) - - return [rbcast_subt_x] - - -@register_canonicalize -@register_specialize -@local_optimizer([Subtensor]) -def local_subtensor_merge(fgraph, node): - """ - Refactored optimization to deal with all cases of tensor merging. - Given a subgraph of the form Subtensor(Subtensor(u)), the optimization - expresses all slices in a canonical form, and then merges them together. - - """ - - if isinstance(node.op, Subtensor): - u = node.inputs[0] - if u.owner and isinstance(u.owner.op, Subtensor): - # We can merge :) - # x actual tensor on which we are picking slices - x = u.owner.inputs[0] - # slices of the first applied subtensor - slices1 = get_idx_list(u.owner.inputs, u.owner.op.idx_list) - slices2 = get_idx_list(node.inputs, node.op.idx_list) - # Get the shapes of the vectors ! - try: - # try not to introduce new shape into the graph - xshape = fgraph.shape_feature.shape_of[x] - ushape = fgraph.shape_feature.shape_of[u] - except AttributeError: - # Following the suggested use of shape_feature which should - # consider the case when the compilation mode doesn't - # include the ShapeFeature - xshape = x.shape - ushape = u.shape - - merged_slices = [] - pos_2 = 0 - pos_1 = 0 - while (pos_1 < len(slices1)) and (pos_2 < len(slices2)): - slice1 = slices1[pos_1] - if isinstance(slice1, slice): - merged_slices.append( - merge_two_slices( - fgraph, slice1, xshape[pos_1], slices2[pos_2], ushape[pos_2] - ) - ) - pos_2 += 1 - else: - merged_slices.append(slice1) - pos_1 += 1 - - if pos_2 < len(slices2): - merged_slices += slices2[pos_2:] - else: - merged_slices += slices1[pos_1:] - - merged_slices = tuple(as_index_constant(s) for s in merged_slices) - subtens = Subtensor(merged_slices) - - sl_ins = get_slice_elements( - merged_slices, lambda x: isinstance(x, Variable) - ) - # Do not call make_node for test_value - out = subtens(x, *sl_ins) - - # Copy over previous output stacktrace - # and stacktrace from previous slicing operation. - # Why? Because, the merged slicing operation could have failed - # because of either of the two original slicing operations - orig_out = node.outputs[0] - copy_stack_trace([orig_out, node.inputs[0]], out) - return [out] - - -@register_specialize -@register_canonicalize -@local_optimizer([Subtensor]) -def local_subtensor_remove_broadcastable_index(fgraph, node): - """ - Remove broadcastable dimension with index 0 or -1 - a[:,:,:,0] -> a.dimshuffle(0,1,2), when - a.broadcastable = (False, False, False, True) - a[0,:,-1,:] -> a.dimshuffle(1,3), when - a.broadcastable = (True, False, True, False) - - """ - if isinstance(node.op, Subtensor): - idx = node.op.idx_list - else: - return - - remove_dim = [] - node_inputs_idx = 1 - for dim, elem in enumerate(idx): - if isinstance(elem, (aes.ScalarType)): - # The idx is a ScalarType, ie a Type. This means the actual index - # is contained in node.inputs[1] - dim_index = node.inputs[node_inputs_idx] - if isinstance(dim_index, aes.ScalarConstant): - dim_index = dim_index.value - if dim_index in (0, -1) and node.inputs[0].broadcastable[dim]: - remove_dim.append(dim) - node_inputs_idx += 1 - else: - return - elif isinstance(elem, slice): - if elem != slice(None): - return - elif isinstance(elem, (int, np.integer)): - if elem in (0, -1) and node.inputs[0].broadcastable[dim]: - remove_dim.append(dim) - else: - raise TypeError("case not expected") - - if len(remove_dim) == 0: - return - else: - all_dim = range(node.inputs[0].ndim) - remain_dim = [x for x in all_dim if x not in remove_dim] - return [node.inputs[0].dimshuffle(tuple(remain_dim))] - - -@register_useless -@register_canonicalize -@register_specialize -@local_optimizer([Subtensor]) -def local_subtensor_of_alloc(fgraph, node): - """ - - alloc(val)[x:y] -> alloc(val[...]) - alloc(val)[x:y] -> alloc(val) - This can be seen as a lift, but it also reduce the number of computation/memory. - - """ - if not isinstance(node.op, Subtensor): - return False - u = node.inputs[0] - if u.owner is None: - return False - if not isinstance(u.owner.op, Alloc): - return False - slices = get_idx_list(node.inputs, node.op.idx_list) - val = u.owner.inputs[0] - dims = u.owner.inputs[1:] - assert len(slices) <= len(dims) - - # Number of dimensions added to val - n_added_dims = u.ndim - val.ndim - # Dimensions of the returned alloc - nw_dims = [] - # Slices to take from val - val_slices = [] - - for i, (sl, dim) in enumerate(zip(slices, dims)): - # If val was not copied over that dim, - # we need to take the appropriate subtensor on it. - if i >= n_added_dims: - # We check that the corresponding val dimensions was - # not a broadcasted dimensions. - if ( - val.type.ndim > (i - n_added_dims) - and val.type.broadcastable[i - n_added_dims] - ): - val_slices.append(slice(None)) - else: - val_slices.append(sl) - - csl, _ = get_canonical_form_slice(sl, dim) - if type(csl) is not slice: - # That dimension is removed. - pass - else: - nw_dim = csl.stop - csl.start - - if csl.step != 1: - # Do not add the ceil_intdiv() graphs in the graphs - # when this is not needed as it prevent detecting the - # correct broadcast pattern. - nw_dim = ceil_intdiv(nw_dim, csl.step) - nw_dims += [nw_dim] - - nw_val = val[tuple(val_slices)] - nw_dims += dims[len(slices) :] - if nw_val.ndim > len(nw_dims): - return False - rval = alloc(nw_val, *nw_dims) - if not isinstance(rval, (list, tuple)): - rval = [rval] - return rval - - -@register_specialize -@register_canonicalize -@local_optimizer([Subtensor]) -def local_subtensor_inc_subtensor(fgraph, node): - """ - Subtensor(SetSubtensor(x, y, idx), idx) -> y - - """ - if isinstance(node.op, Subtensor): - x = node.inputs[0] - if not x.owner or not isinstance(x.owner.op, IncSubtensor): - return - if not x.owner.op.set_instead_of_inc: - return - - if x.owner.inputs[2:] == node.inputs[1:] and tuple( - x.owner.op.idx_list - ) == tuple(node.op.idx_list): - out = node.outputs[0] - y = x.owner.inputs[1] - # If the dtypes differ, cast y into x.dtype - if x.dtype != y.dtype: - y = y.astype(x.dtype) - if ( - out.type.dtype == y.type.dtype - and out.type.broadcastable == y.type.broadcastable - ): - # if x[idx] and y have the same type, directly return y - return [y] - else: - # The difference is related to broadcasting pattern - assert out.broadcastable != y.broadcastable - # We have to alloc y to the shape of x[idx] - x_subtensor = node.op(x.owner.inputs[0], *x.owner.inputs[2:]) - return [alloc(y, *x_subtensor.shape)] - else: - return - - -@register_specialize -@register_canonicalize("fast_compile") -@register_useless -@local_optimizer([Subtensor, AdvancedSubtensor1]) -def local_subtensor_make_vector(fgraph, node): - """Perform ``*Subtensor*`` operations on ``MakeVector`` outputs when the indices are constant. - - Replace all ``Subtensor`` and ``MakeVector`` cases like: - [a,b,c][0] -> a - [a,b,c][0:2] -> [a,b] - - Replace all ``AdvancedSubtensor1`` and ``MakeVector`` cases like: - [a,b,c][[0,2]] -> [a,c] - - We can do this for constant indexes. - - .. note: - - This optimization implicitly relies on shape optimizations. - - TODO: This only applies to a single indexed dimension; we should have - something more general for constant ``*Subtensor*`` graphs (or perhaps - include this kind of work in the constant folding). - """ - - if not isinstance(node.op, (Subtensor, AdvancedSubtensor1)): - return False - - x = node.inputs[0] - - if not x.owner or not isinstance(x.owner.op, MakeVector): - return False - - make_vector_op = x.owner.op - - if isinstance(node.op, Subtensor): - (idx,) = node.op.idx_list - - if isinstance(idx, (aes.ScalarType, TensorType)): - old_idx, idx = idx, node.inputs[1] - assert idx.type.is_super(old_idx) - elif isinstance(node.op, AdvancedSubtensor1): - idx = node.inputs[1] - - if isinstance(idx, (int, np.integer)): - return [x.owner.inputs[idx]] - elif isinstance(idx, Variable): - if idx.ndim == 0: - try: - v = get_scalar_constant_value(idx, only_process_constants=True) - try: - ret = [x.owner.inputs[v]] - except IndexError: - raise NotScalarConstantError("Bad user graph!") - return ret - except NotScalarConstantError: - pass - elif idx.ndim == 1 and isinstance(idx, Constant): - values = list(map(int, list(idx.value))) - ret = make_vector_op(*[x.owner.inputs[v] for v in values]) - copy_stack_trace(node.outputs[0], ret) - return [ret] - elif isinstance(idx, slice): - # The index is a slice. If it's a constant slice, we can perform the - # index operation here. - try: - const_slice = get_constant_idx( - node.op.idx_list, node.inputs, allow_partial=False - )[0] - ret = make_vector_op(*x.owner.inputs[const_slice]) - copy_stack_trace(node.outputs, ret) - return [ret] - except NotScalarConstantError: - pass - - -@register_useless -@register_canonicalize -@register_specialize -@local_optimizer([IncSubtensor]) -def local_useless_inc_subtensor(fgraph, node): - r"""Remove redundant `IncSubtensor`\s. - - More specifically, ``set_subtensor(x[indices], y)`` is replaced by - ``y[indices]`` when ``indices`` are full `slice`\s and ``y``'s shape is - equal to ``x[indices]``, and ``inc_subtensor(x[indices], y)`` is replaced - by ``y[indices]`` when ``x[indices]`` is some array of ``0``\s, ``indices`` - are full slices, and the shapes are equal. - """ - if not isinstance(node.op, IncSubtensor): - return - - if not hasattr(fgraph, "shape_feature"): - return - - x, y, *index_inputs = node.inputs - - if node.op.set_instead_of_inc is False: - # This is an increment operation, so the array being incremented must - # consist of all zeros in order for the entire operation to be useless - try: - c = get_scalar_constant_value(x) - if c != 0: - return - except NotScalarConstantError: - return - - idx_cst = indices_from_subtensor(list(index_inputs), node.op.idx_list) - - # Check that all indices are full slices with only reversals and no step - # sizes - # TODO: It seems like there should be a basic `IncSubtensor` - # canonicalization that removes these redundant slices. - if all( - isinstance(e, slice) - and e.start is None - and e.stop is None - and ( - e.step is None - or extract_constant(e.step, only_process_constants=True) == -1 - ) - for e in idx_cst - ): - - # `IncSubtensor` broadcasts `x` on `y` based on run-time shapes, so we - # must check that they are the same - if not fgraph.shape_feature.same_shape(x, y): - return - - # There are no reversals, so we don't need a replacement. - if all(e.step is None for e in node.op.idx_list): - # They are exactly the same shapes, so we can remove this `IncSubtensor` - return [y] - - new_node = Subtensor(node.op.idx_list).make_node(y, *index_inputs) - new_out = new_node.outputs[0] - copy_stack_trace(node.outputs, new_out) - - return [new_out] - - -@register_canonicalize -@register_specialize -@local_optimizer([AdvancedIncSubtensor1]) -def local_set_to_inc_subtensor(fgraph, node): - r""" - AdvancedIncSubtensor1(x, x[ilist]+other, ilist, set_instead_of_inc=True) -> - AdvancedIncSubtensor1(x, other, ilist, set_instead_of_inc=False) - - TODO FIXME: Why doesn't this apply to all `*IncSubtensor*` `Op`\s? If it - did this wouldn't need to also be included in the "specialize" pass. - - """ - if ( - isinstance(node.op, AdvancedIncSubtensor1) - and node.op.set_instead_of_inc - and node.inputs[1].owner - and isinstance(node.inputs[1].owner.op, Elemwise) - and isinstance(node.inputs[1].owner.op.scalar_op, aes.Add) - ): - addn = node.inputs[1].owner - subn = None - other = None - - if addn.inputs[0].owner and isinstance( - addn.inputs[0].owner.op, AdvancedSubtensor1 - ): - subn = addn.inputs[0].owner - other = addn.inputs[1] - elif addn.inputs[1].owner and isinstance( - addn.inputs[1].owner.op, AdvancedSubtensor1 - ): - subn = addn.inputs[1].owner - other = addn.inputs[0] - else: - return - if subn.inputs[1] != node.inputs[2] or subn.inputs[0] != node.inputs[0]: - return - ret = advanced_inc_subtensor1(node.inputs[0], other, node.inputs[2]) - - copy_stack_trace(node.outputs, ret) - - return [ret] - - -@register_canonicalize -@register_specialize -@local_optimizer([Subtensor, AdvancedSubtensor1]) -def local_useless_subtensor(fgraph, node): - """ - Remove Subtensor/AdvancedSubtensor1 if it takes the full input. In the - AdvancedSubtensor1 case, the full input is taken when the indices are - equivalent to `arange(0, input.shape[0], 1)` using either an explicit - list/vector or the ARange op. - - """ - # This optimization needs ShapeOpt and fgraph.shape_feature - if not hasattr(fgraph, "shape_feature"): - return - - shape_of = fgraph.shape_feature.shape_of - - if isinstance(node.op, Subtensor): - cdata = get_constant_idx( - node.op.idx_list, - node.inputs, - allow_partial=True, - only_process_constants=True, - ) - for pos, idx in enumerate(cdata): - if not isinstance(idx, slice): - # If idx is not a slice, this means we remove this dimension - # from the output, so the subtensor is not useless - return False - if idx.start is not None and idx.start != 0: - # If the start of the slice is different from 0, or is a - # variable, then we assume the subtensor is not useless - return False - if idx.step is not None and idx.step != 1: - # If we are going backwards, or skipping elements, then this - # is not a useless subtensor - return False - - for pos, idx in enumerate(cdata): - - length_pos = shape_of[node.inputs[0]][pos] - - if isinstance(idx.stop, (int, np.integer)): - length_pos_data = sys.maxsize - try: - length_pos_data = get_scalar_constant_value( - length_pos, only_process_constants=True - ) - except NotScalarConstantError: - pass - - if idx.stop < length_pos_data: - return False - elif isinstance(idx.stop, Variable): - length_pos_shape_i = idx.stop - # length_pos is a tensor variable, but length_pos_shape_i - # is a scalar variable. We try to see if they represent - # the same underlying variable. - if length_pos_shape_i.owner and isinstance( - length_pos_shape_i.owner.op, ScalarFromTensor - ): - length_pos_shape_i = length_pos_shape_i.owner.inputs[0] - elif length_pos.owner and isinstance( - length_pos.owner.op, TensorFromScalar - ): - length_pos = length_pos.owner.inputs[0] - else: - # We did not find underlying variables of the same type - return False - - # The type can be different: int32 vs int64. length_pos - # should always be int64 as that is what the shape - # tracker keep. Subtensor accept any scalar int{8,16,32,64} - # as index type. - assert str(length_pos.type.dtype) == "int64" - assert str(length_pos_shape_i.type.dtype) in [ - "int8", - "int16", - "int32", - "int64", - ] - - # length_pos_shape_i cannot be None - if length_pos_shape_i != length_pos: - return False - elif idx.stop is None: - pass - else: - return False - elif isinstance(node.op, AdvancedSubtensor1): - # get length of the indexed tensor along the first axis - try: - length = get_scalar_constant_value( - shape_of[node.inputs[0]][0], only_process_constants=True - ) - except NotScalarConstantError: - return False - - # get index (which must be a vector by definition) - idx = node.inputs[1] - - # `idx` must be equivalent to [0,1,...,shape[0] - 1] to qualify for - # this optimization - if isinstance(idx, Constant): - idx = idx.value - if len(idx) != length: - return False - if np.any(idx != np.arange(length)): - return False - elif idx.owner is not None and isinstance(idx.owner.op, ARange): - try: - start, stop, step = map( - lambda x: get_scalar_constant_value(x, only_process_constants=True), - idx.owner.inputs, - ) - except NotScalarConstantError: - return False - - if start != 0: - return False - if stop != length: - return False - if step != 1: - return False - else: - return False - else: - return False - - # We don't need to copy over any stacktrace here, - # because previous stacktrace should suffice. - return [node.inputs[0]] - - -def merge_two_slices(fgraph, slice1, len1, slice2, len2): - """ - This function merges two slices into a single slice. The code works on - the assumption that: - - a) slice1 is actually a slice and not an index, while slice2 - can be just an index. - - b) the two slices **have been applied consecutively** on the same - tensor - - The output slice is **not** in canonical form, but actually just a slice - that can be applied to a tensor to produce the same output as applying - the two consecutive slices. - ``len1`` is the length of the tensor **before** applying the first slice, - while ``len2`` is the length **after** applying the first slice. - """ - - if not isinstance(slice1, slice): - raise ValueError("slice1 should be of type `slice`") - - sl1, reverse1 = get_canonical_form_slice(slice1, len1) - sl2, reverse2 = get_canonical_form_slice(slice2, len2) - - if not isinstance(sl2, slice): - if reverse1 is None: - # The first slice is not in reverse, which makes things a lot - # more clear. - # In this case we need to take care only of the special cases: - # len2 <=0 -> throw index error regardless of sl2 - # sl2 > len2 -> throw index error - # sl2 < -len2 -> throw index error - # To get a index error we simply use len1+1 to indicate we are - # out of bounds, because passing this index through the formula - # of getting the mixed slice is not guaranteed to result in an - # index error. The **issue though** if that the error will - # complain about accessing element len1+1 which is probably not - # too intuitive for the user - val = sl1.start + sl2 * sl1.step - val = switch(le(len2, 0), len1 + 1, val) - val = switch(ge(sl2, len2), len1 + 1, val) - val = switch(lt(sl2, 0), -len1 - 1, val) - if sl1.step: - val = switch(eq(sl1.step, 0), len1 + 1, val) - return val - else: - # We are in the more complex case when we do not actually know - # if the first slice was in reverse or not. - # in case it was not in reverse: - p_val = sl1.start + sl2 * sl1.step - # case it was in reverse we need to realize that we do not want - # the k-th element from sl.start but the k-th element from - # sl.stop backwards - n_val = sl1.stop - 1 - sl2 * sl1.step - # we need to pick either n_val or p_val and then follow same - # steps as above for covering the index error cases - val = switch(lt(reverse1, 0), n_val, p_val) - val = switch(le(len2, 0), len1 + 1, val) - val = switch(ge(sl2, len2), len1 + 1, val) - val = switch(lt(sl2, 0), -len1 - 1, val) - if sl1.step: - val = switch(eq(sl1.step, 0), len1 + 1, val) - return val - else: - # We are deleaing with two slices that need to be put together - # according to the two steps we have 4 different combinations of - # positive/negative. I will denote the case I'm looking at by - # suffixes to the variables (nn,np,pn,pp): - flen = sl2.stop - sl2.start - p_step = sl1.step * sl2.step - n_step = sl1.step * sl2.step * -1 - - pp_start = minimum(sl1.start + sl2.start * sl1.step, sl1.stop) - pp_stop = minimum(sl1.start + sl2.stop * sl1.step, sl1.stop) - - pn_stop = sl1.start + (sl2.start - 1) * sl1.step - pn_stop = switch( - and_(lt(pn_stop, 0), gt(flen, 0)), - -len1 - 1, - minimum(pn_stop, sl1.stop), - ) - pn_start = sl1.start + (sl2.stop - 1) * sl1.step - pn_start = minimum(pn_start, sl1.stop) - pn_start = maximum(pn_start, 0) - - np_stop = sl1.stop - sl2.stop * sl1.step - 1 - np_stop = switch( - and_(lt(np_stop, 0), gt(flen, 0)), - -len1 - 1, - maximum(sl1.start - 1, np_stop), - ) - np_start = maximum(sl1.start, sl1.stop - sl2.start * sl1.step - 1) - - nn_start = maximum(sl1.start, (sl1.stop - 1) - (sl2.stop - 1) * sl1.step) - nn_stop = maximum(sl1.start, sl1.stop - sl2.start * sl1.step) - - start = switch( - lt(reverse2 * reverse1, 0), - switch(lt(reverse1, 0), np_start, pn_start), - switch(lt(reverse1, 0), nn_start, pp_start), - ) - - stop = switch( - lt(reverse2 * reverse1, 0), - switch(lt(reverse1, 0), np_stop, pn_stop), - switch(lt(reverse1, 0), nn_stop, pp_stop), - ) - - step = switch(lt(reverse2 * reverse1, 0), n_step, p_step) - start = switch(le(flen, 0), 0, start) - stop = switch(le(flen, 0), 0, stop) - - return slice(start, stop, step) - - -@register_canonicalize -@local_optimizer([add]) -def local_IncSubtensor_serialize(fgraph, node): - """ - When using Subtensor, gradient graphs can be ugly. - - If we ask for grad(f(a[0]), a), we are going to get something like - - IncSubtensor(Elemwise{second}(a, 0), g(f(a[0])), [0]) - - This might be ugly, but at least it's as fast as you could want. - If we ask for grad(f(a[0], a[1], a[2]), a), it's much worse... - - Elemwise{Add} - IncSubtensor(Elemwise{second}(a, 0), g(f(a[0])), [0]) - IncSubtensor(Elemwise{second}(a, 0), g(f(a[1])), [1]) - IncSubtensor(Elemwise{second}(a, 0), g(f(a[2])), [2]) - - This is much worse because this time we have to produce 3 matrices - the size of 'a', just so we can add them together. - - This Op rearranges IncSubtensor's that all work on the same - initial argument (here, Elemwise{second}(a,0)) into a chain. The - advantage of the chain structure is that each one can be optimized - later in the pipeline to operate inplace. - - Ideally, the op will do something like this: - - # - # add(x, incsubtensor(b, c), incsubtensor(b, d)) - # -> incsubtensor(incsubtensor(add(x,b,b), c), d) - - """ - - def movable(i): - # Return True iff this is a incsubtensor that we can move - return ( - i.owner - and isinstance( - i.owner.op, - ( - IncSubtensor, - AdvancedIncSubtensor1, - AdvancedIncSubtensor, - ), - ) - and i.type.is_super(o_type) - and len(fgraph.clients[i]) == 1 - and not i.owner.op.set_instead_of_inc - ) - - if node.op == add: - o_type = node.outputs[0].type - - movable_inputs = [i for i in node.inputs if movable(i)] - - if movable_inputs: - new_inputs = [i for i in node.inputs if not movable(i)] + [ - mi.owner.inputs[0] for mi in movable_inputs - ] - if len(new_inputs) == 0: - new_add = new_inputs[0] - else: - new_add = add(*new_inputs) - - # Copy over stacktrace from original output, as an error - # (e.g. an index error) in this add operation should - # correspond to an error in the original add operation. - copy_stack_trace(node.outputs[0], new_add) - - # stack up the new incsubtensors - tip = new_add - for mi in movable_inputs: - assert o_type.is_super(tip.type) - assert mi.owner.inputs[0].type.is_super(tip.type) - tip = mi.owner.op(tip, *mi.owner.inputs[1:]) - # Copy over stacktrace from outputs of the original - # "movable" operation to the new operation. - copy_stack_trace(node.outputs + mi.owner.outputs, tip) - - return [tip] - - # print incsub_inputs, [id(i.owner.inputs[0]) for i in incsub_inputs] - - -# We register it in a TopoOptimizer inside the canonizer EQ optimizer. -# Otherwise in some cases it was making the EQ optimizer use 45. In -# the TopoOptimizer, the EQ only use 5 passes. -compile.optdb.register( - "pre_local_IncSubtensor_serialize", - in2out(local_IncSubtensor_serialize), - "fast_run", - # Just before canonizer - position=0.99, -) - - -# after priority 50 Destructive inplace operations -# gemm is the first one now, at priority 70 - - -@local_optimizer([IncSubtensor], inplace=True) -def local_inplace_setsubtensor(fgraph, node): - if isinstance(node.op, IncSubtensor) and not node.op.inplace: - dta = node.op.destroyhandler_tolerate_aliased - new_op = node.op.__class__( - node.op.idx_list, - inplace=True, - set_instead_of_inc=node.op.set_instead_of_inc, - destroyhandler_tolerate_aliased=dta, - ) - new_node = new_op(*node.inputs) - val = getattr(node.outputs[0].tag, "nan_guard_mode_check", True) - new_node.tag.nan_guard_mode_check = val - - # Copy stacktrace from original outputs to new outputs. - # This is sensible, because the new operation is the - # same as the old one, but now with different attributes. - copy_stack_trace(node.outputs, new_node) - return [new_node] - return False - - -compile.optdb.register( - "local_inplace_setsubtensor", - TopoOptimizer( - local_inplace_setsubtensor, failure_callback=TopoOptimizer.warn_inplace - ), - "fast_run", - "inplace", - position=60, -) - - -@local_optimizer([AdvancedIncSubtensor1], inplace=True) -def local_inplace_AdvancedIncSubtensor1(fgraph, node): - if isinstance(node.op, AdvancedIncSubtensor1) and not node.op.inplace: - new_op = node.op.clone_inplace() - new_node = new_op(*node.inputs) - copy_stack_trace(node.outputs, new_node) - return [new_node] - return False - - -compile.optdb.register( - "local_inplace_AdvancedIncSubtensor1", - TopoOptimizer( - local_inplace_AdvancedIncSubtensor1, failure_callback=TopoOptimizer.warn_inplace - ), - "fast_run", - "inplace", - position=60, -) - - -@local_optimizer([AdvancedIncSubtensor], inplace=True) -def local_inplace_AdvancedIncSubtensor(fgraph, node): - if isinstance(node.op, AdvancedIncSubtensor) and not node.op.inplace: - new_op = type(node.op)( - inplace=True, - set_instead_of_inc=node.op.set_instead_of_inc, - ignore_duplicates=node.op.ignore_duplicates, - ) - new_node = new_op(*node.inputs) - copy_stack_trace(node.outputs, new_node) - return [new_node] - return False - - -compile.optdb.register( - "local_inplace_AdvancedIncSubtensor", - TopoOptimizer( - local_inplace_AdvancedIncSubtensor, failure_callback=TopoOptimizer.warn_inplace - ), - "fast_run", - "inplace", - position=60, -) - - -# Register old name -@register_canonicalize("local_incsubtensor_of_allocs") -@register_stabilize("local_incsubtensor_of_allocs") -@local_optimizer([IncSubtensor, AdvancedIncSubtensor, AdvancedIncSubtensor1]) -def local_incsubtensor_of_zeros(fgraph, node): - """ - IncSubtensor(x, zeros, idx) -> x - - """ - if ( - isinstance(node.op, (IncSubtensor, AdvancedIncSubtensor, AdvancedIncSubtensor1)) - and not node.op.set_instead_of_inc - ): - x = node.inputs[0] - y = node.inputs[1] - try: - # Don't use only_process_constants=True. We need to - # investigate Alloc of 0s but with non constant shape. - if get_scalar_constant_value(y, elemwise=False) == 0: - # No need to copy over the stacktrace, - # because x should already have a stacktrace - return [x] - except NotScalarConstantError: - return - - -@register_canonicalize -@register_specialize -@local_optimizer([IncSubtensor]) -def local_incsubtensor_of_zeros_to_setsubtensor(fgraph, node): - """ - IncSubtensor(zeros, x, ...) -> SetSubtensor(zeros, x, ...) - """ - if isinstance(node.op, (IncSubtensor)) and not node.op.set_instead_of_inc: - x = node.inputs[0] - - if isinstance(x, Constant) and not np.any(x.data): - return [ - IncSubtensor( - node.op.idx_list, - node.op.inplace, - set_instead_of_inc=True, - destroyhandler_tolerate_aliased=node.op.destroyhandler_tolerate_aliased, - )(*node.inputs) - ] - - -@register_canonicalize("local_setsubtensor_of_allocs") -@register_stabilize("local_setsubtensor_of_allocs") -@local_optimizer([IncSubtensor]) -def local_setsubtensor_of_constants(fgraph, node): - """ - SetSubtensor(x, x[idx], idx) -> x - - when x is constant or alloc. - - """ - if isinstance(node.op, IncSubtensor) and node.op.set_instead_of_inc: - x = node.inputs[0] - y = node.inputs[1] - - # Don't use only_process_constants=True. We need to - # investigate Alloc of 0s but with non constant shape. - try: - replace_x = get_scalar_constant_value(x, elemwise=False) - except NotScalarConstantError: - return - - try: - replace_y = get_scalar_constant_value(y, elemwise=False) - except NotScalarConstantError: - return - - if replace_x == replace_y: - - # No need to copy over the stacktrace, - # because x should already have a stacktrace - return [x] - else: - return False - - -@register_canonicalize -@register_specialize -@local_optimizer([AdvancedSubtensor1]) -def local_adv_sub1_adv_inc_sub1(fgraph, node): - """Optimize the possible AdvSub1(AdvSetSub1(...), ...). - - AdvancedSubtensor1(AdvancedSetSubtensor1(x, y, idx), idx) -> y - - Notes - ----- - This opt add AssertOp. Otherwise, it would remove shape and - index error. If you want to get rid of them, see the - :ref:`unsafe_optimization` section. - - WARNING: - A previous version of this optimization also matched - AdvancedSubtensor1(AdvancedIncSubtensor1(0s, y, idx), idx) -> y - This is incorrect when there are duplicate indices. - The current version warns the user about potential past issues. - - """ - if not isinstance(node.op, AdvancedSubtensor1): - return - inp = node.inputs[0] - if not inp.owner or not isinstance(inp.owner.op, AdvancedIncSubtensor1): - return - idx = node.inputs[1] - idx2 = inp.owner.inputs[2] - x = inp.owner.inputs[0] - y = inp.owner.inputs[1] - if idx is not idx2: - return - if ( - not inp.owner.op.set_instead_of_inc - and - # Don't use only_process_constants=True. We need to - # investigate Alloc of 0s but with non constant shape. - extract_constant(x, elemwise=False) != 0 - ): - return - - if not inp.owner.op.set_instead_of_inc: - return - - cond = [at_all(and_(lt(idx, x.shape[0]), ge(idx, -x.shape[0])))] - if not fgraph.shape_feature.same_shape(idx, y, 0, 0): - cond.append(eq(idx.shape[0], y.shape[0])) - r = Assert( - "Bad indexing or shapes in a AdvancedIncSubtensor1 " "that was optimized away" - )(y, *cond) - copy_stack_trace(y, r) - - if r.dtype == node.outputs[0].dtype: - return [r] - # It is possible that y is upcast or downcast to x.dtype. - # In all case, as we set or add with 0, we can just cast y. - r2 = cast(r, node.outputs[0].dtype) - - # Copy over stacktrace from before casting, since - # we don't expect problems in the casting operation, - # and any problems in the indexing would have been spotted above. - copy_stack_trace(r, r2) - return [r2] - - -@register_specialize -@register_stabilize -@register_canonicalize -@register_useless -@local_optimizer([IncSubtensor, AdvancedIncSubtensor, AdvancedIncSubtensor1]) -def local_useless_inc_subtensor_alloc(fgraph, node): - """ - Replaces an [Advanced]IncSubtensor[1], whose increment is an `alloc` of - a fully or partially broadcastable variable, by one that skips the - intermediate `alloc` where possible. - - """ - if isinstance(node.op, (IncSubtensor, AdvancedIncSubtensor, AdvancedIncSubtensor1)): - x = node.inputs[0] - y = node.inputs[1] - i = node.inputs[2:] - - if y.owner is not None and isinstance(y.owner.op, Alloc): - # `z` is the input of the Alloc op, i.e. at.alloc(z, ) - z = y.owner.inputs[0] - - try: - shape_feature = fgraph.shape_feature - except AttributeError: - # The shape feature may not be available in some mode, but we - # need it for this optimization, so don't continue. - return False - - shape_of = shape_feature.shape_of - same_shape = shape_feature.same_shape - - # Get the subtensor of `x` indexed by `i` in order to compare - # shapes later. - if isinstance(node.op, IncSubtensor): - xi = Subtensor(node.op.idx_list)(x, *i) - elif isinstance(node.op, AdvancedIncSubtensor): - xi = advanced_subtensor(x, *i) - elif isinstance(node.op, AdvancedIncSubtensor1): - xi = advanced_subtensor1(x, *i) - else: - raise Exception("Should never happen!") - - reason = "local_useless_incsubtensor_alloc" - - # Add `xi` to the shape feature `fgraph`. This is important for - # shape inference later because the variable must be part of the - # function graph in order to call `same_shape` on it. - if xi not in shape_of: - shape_feature.on_import(fgraph, xi.owner, f"{reason}: add `xi`") - - # `xi` may have more dimensions than `y` since the subtensor ops - # do automatic broadcasting of the increment internally. Thus, we - # need to make the leading implicitly broadcasted dimensions - # explicit for shape comparison later. - if xi.ndim > y.ndim: - y = shape_padleft(y, xi.ndim - y.ndim) - if y not in shape_of: - shape_feature.on_import(fgraph, y.owner, f"{reason}: add `y`") - - # Build `z_broad` explicitly to include extra implicit dimensions. - z_broad = (True,) * (xi.ndim - z.ndim) + z.broadcastable - - cond = [ - # The shapes of `y` and `xi` must either agree or `y` may - # also have shape equal to 1 which may be treated as a - # broadcastable dimension by the subtensor op. - or_(eq(y.shape[k], 1), eq(y.shape[k], xi.shape[k])) - # Loop over all dimensions. - for k in range(xi.ndim) - # We need to check the above shapes, if - # * the pre-alloc increment `z` is broadcastable in - # dimension `k` (if it isn't, then the shapes of `z` and - # `y` are the same by the definition of the `Alloc` op in - # this dimension and replacing `y` by `z` will not hide a - # shape error), and - # * `xi` and `y` do not have the same shape in dimension - # `k` or we cannot infer the shape statically (if the - # shapes of `xi` and `y` are not the same, then replacing - # `y` by `z` will hide the shape error of `y`), and - # * the shape of `y` is not equal to 1 or we cannot infer - # the shape statically (if the shape of `y` is equal to - # 1, then `y` is broadcasted by the inc_subtensor op - # internally, so the shapes of `xi` and `y` do not need - # to match in dimension `k`; else we need to check at - # runtime that the shape of `y` is either 1 or the same - # as `xi` or otherwise replacing `y` by `z` will hide a - # shape error). - if ( - z_broad[k] - and not same_shape(xi, y, dim_x=k, dim_y=k) - and shape_of[y][k] != 1 - ) - ] - - if len(cond) > 0: - msg = "`x[i]` and `y` do not have the same shape." - z = Assert(msg)(z, *cond) - - r = node.op(x, z, *i) - # Copy over stacktrace from previous output, since - # we don't expect problems when removing the intermediate - # alloc operation and so we still want to point at the line - # of the inc_subtensor operation. - copy_stack_trace(node.outputs, r) - - return [r] - - -@register_specialize -@register_canonicalize -@local_optimizer([Subtensor]) -def local_subtensor_shape_constant(fgraph, node): - r"""Simplify constant `Subtensor`\s on `Shape`\s dimensions that are known. - - We want to convert graphs like - - Subtensor{int64} [id A] '' - |Shape [id B] '' - | | [id C] - |ScalarConstant{0} [id D] - - into - - TensorConstant{1} - - TODO: Something like `local_shape_to_shape_i` should be a general - canonicalization, and not a `ShapeFeature`-dependent rewrite. If that were - the case, we could change this to only operate on `Shape_i`\s. - Currently, we're not handling them because they should only appear when - `ShapeFeature` is present, and it will also simplify/remove them. - - """ - if not isinstance(node.op, Subtensor): - return False - - shape = node.inputs[0] - - if not (shape.owner and isinstance(shape.owner.op, Shape)): - return False - - shape_arg = shape.owner.inputs[0] - - (idx,) = get_idx_list(node.inputs, node.op.idx_list) - - try: - idx_val = as_index_literal(idx) - except NotScalarConstantError: - return False - - assert idx_val != np.newaxis - - if not isinstance(shape_arg.type, TensorType): - return False - - shape_parts = shape_arg.type.broadcastable[idx_val] - - if isinstance(shape_parts, Iterable): - if all(shape_parts): - return [as_tensor([1] * len(shape_parts), dtype=np.int64, ndim=1)] - elif shape_parts: - return [as_tensor(1, dtype=np.int64)] - - -@register_canonicalize -@local_optimizer([Subtensor]) -def local_subtensor_SpecifyShape_lift(fgraph, node): - """Lift ``specify_shape(x, s)[i_1, ..., i_n]`` to ``specify_shape(x[i1, ... , i_n], s[n:])``.""" - - if not isinstance(node.op, Subtensor): - return False - - specify_shape_node = node.inputs[0] - - if not ( - specify_shape_node.owner - and isinstance(specify_shape_node.owner.op, SpecifyShape) - ): - return False - - obj_arg = specify_shape_node.owner.inputs[0] - shape_arg = specify_shape_node.owner.inputs[1:] - - indices = get_idx_list(node.inputs, node.op.idx_list) - - if any( - isinstance(index, slice) or isinstance(getattr(index, "type", None), SliceType) - for index in indices - ): - return False - - new_obj_arg = obj_arg[indices] - # No need to specify shape for scalar outputs - if new_obj_arg.ndim == 0: - return [new_obj_arg] - return [specify_shape(new_obj_arg, shape_arg[len(indices) :])] - - -@register_specialize -@local_optimizer([Join]) -def local_join_subtensors(fgraph, node): - r"""Simplify contiguous :class:`Subtensor`\s inside a :class:`Join`. - - `join((x[:3], x[3:5]), axis=0) -> x[:5]` - """ - # TODO: Generalize to AdvancedSubtensors - - axis, tensors = node.inputs[0], node.inputs[1:] - - try: - axis = get_scalar_constant_value(axis) - except NotScalarConstantError: - return - - for subtensor1_idx, (subtensor1, subtensor2) in enumerate( - zip(tensors[:-1], tensors[1:]) - ): - # Check that two consecutive Subtensors are operating on the same base tensor - if not ( - ( - subtensor1.owner is not None - and isinstance(subtensor1.owner.op, Subtensor) - ) - and ( - subtensor2.owner is not None - and isinstance(subtensor2.owner.op, Subtensor) - ) - and (subtensor1.owner.inputs[0] is subtensor2.owner.inputs[0]) - ): - continue - - # Check that subtensors have consecutive indexes across the join axis - idxs_subtensor1 = indices_from_subtensor( - subtensor1.owner.inputs[1:], subtensor1.owner.op.idx_list - ) - idxs_subtensor2 = indices_from_subtensor( - subtensor2.owner.inputs[1:], subtensor2.owner.op.idx_list - ) - try: - idxs_axis_subtensor1 = idxs_subtensor1[axis] - idxs_axis_subtensor2 = idxs_subtensor2[axis] - except IndexError: - continue - if not ( - isinstance(idxs_axis_subtensor1, slice) - and isinstance(idxs_axis_subtensor2, slice) - ): - continue - start_subtensor1, stop_subtensor1, step_subtensor1 = ( - idxs_axis_subtensor1.start, - idxs_axis_subtensor1.stop, - idxs_axis_subtensor1.step, - ) - start_subtensor2, stop_subtensor2, step_subtensor2 = ( - idxs_axis_subtensor2.start, - idxs_axis_subtensor2.stop, - idxs_axis_subtensor2.step, - ) - if not ( - (stop_subtensor1 is not None and start_subtensor2 is not None) - and (stop_subtensor1 == start_subtensor2) - ): - continue - - # Check that step is None or 1 - # For non-unit steps (perhaps except for -1) we would need to know the - # exact values of start and stop to know if they can be merged - for step in (step_subtensor1, step_subtensor2): - if step is None: - continue - try: - if get_scalar_constant_value(step, only_process_constants=True) != 1: - return None - except NotScalarConstantError: - return None - - # Check that all other idxs of subtensor are the same - if all( - idxs_nonaxis_subtensor1 == idxs_nonaxis_subtensor2 - for i, (idxs_nonaxis_subtensor1, idxs_nonaxis_subtensor2) in enumerate( - zip(idxs_subtensor1, idxs_subtensor2) - ) - if i != axis - ): - - base_tensor = subtensor1.owner.inputs[0] - new_idxs = list(idxs_subtensor1) - new_idxs[axis] = slice(start_subtensor1, stop_subtensor2, step_subtensor1) - merged_subtensors = base_tensor[new_idxs] - new_joined_tensors = [ - *tensors[:subtensor1_idx], - merged_subtensors, - *tensors[subtensor1_idx + 2 :], - ] - if len(new_joined_tensors) > 1: - return [concatenate(new_joined_tensors, axis=axis)] - else: - return [merged_subtensors] +from aesara.tensor.rewriting.subtensor import * # noqa: F401 E402 F403 diff --git a/aesara/tensor/type.py b/aesara/tensor/type.py index 115c609b5b..06cf964142 100644 --- a/aesara/tensor/type.py +++ b/aesara/tensor/type.py @@ -92,9 +92,12 @@ def __init__( ) shape = broadcastable - if isinstance(dtype, str) and dtype == "floatX": + if str(dtype) == "floatX": self.dtype = config.floatX else: + if np.obj2sctype(dtype) is None: + raise TypeError(f"Invalid dtype: {dtype}") + self.dtype = np.dtype(dtype).name def parse_bcast_and_shape(s): @@ -319,16 +322,13 @@ def is_super(self, otype): def convert_variable(self, var): if self.is_super(var.type): - # `var.type` is at least as specific as `self`, so we return - # `var` as-is + # `var.type` is as specific as `self`, so we return `var` as-is return var - elif var.type.is_super(self): - # `var.type` is less specific than `self`, so we convert - # `var` to `self`'s `Type`. - # Note that, in this case, `var.type != self`, because that's - # covered by the branch above. - # Use the more specific static shape information of the two + if (self.ndim == var.type.ndim) and (self.dtype == var.type.dtype): + # `var.type` only differs from `self` in that its shape is (at least partially) + # less specific than `self`, so we convert `var` to `self`'s `Type`. + # `specify_shape` will combine the more precise shapes of the two types return aesara.tensor.specify_shape(var, self.shape) def value_zeros(self, shape): @@ -778,6 +778,10 @@ def tensor(*args, **kwargs): wscalar = TensorType("int16", ()) iscalar = TensorType("int32", ()) lscalar = TensorType("int64", ()) +ubscalar = TensorType("uint8", ()) +uwscalar = TensorType("uint16", ()) +uiscalar = TensorType("uint32", ()) +ulscalar = TensorType("uint64", ()) def scalar(name=None, dtype=None): diff --git a/aesara/tensor/utils.py b/aesara/tensor/utils.py index e8ed778703..88e73d1ef3 100644 --- a/aesara/tensor/utils.py +++ b/aesara/tensor/utils.py @@ -63,7 +63,9 @@ def shape_of_variables(fgraph, input_shapes): """ if not hasattr(fgraph, "shape_feature"): - fgraph.attach_feature(aesara.tensor.basic_opt.ShapeFeature()) + from aesara.tensor.rewriting.shape import ShapeFeature + + fgraph.attach_feature(ShapeFeature()) input_dims = [ dimension diff --git a/aesara/tensor/var.py b/aesara/tensor/var.py index a5a7ad457e..8b281e6bd0 100644 --- a/aesara/tensor/var.py +++ b/aesara/tensor/var.py @@ -15,6 +15,7 @@ from aesara.tensor import _get_vector_length, as_tensor_variable from aesara.tensor.exceptions import AdvancedIndexingError from aesara.tensor.type import TensorType +from aesara.tensor.type_other import NoneConst from aesara.tensor.utils import hash_from_ndarray @@ -466,7 +467,7 @@ def includes_bool(args_el): ellipses = [] index_dim_count = 0 for i, arg in enumerate(args): - if arg is np.newaxis: + if arg is np.newaxis or arg is NoneConst: # no increase in index_dim_count pass elif arg is Ellipsis: @@ -515,13 +516,13 @@ def is_empty_array(val): isinstance(val, np.ndarray) and val.size == 0 ) - # Force input to be int64 datatype if input is an empty list or tuple + # Force input to be an int datatype if input is an empty list or tuple # Else leave it as is if it is a real number # Convert python literals to aesara constants args = tuple( [ at.subtensor.as_index_constant( - np.array(inp, dtype=np.int64) if is_empty_array(inp) else inp + np.array(inp, dtype=np.uint8) if is_empty_array(inp) else inp ) for inp in args ] @@ -537,7 +538,7 @@ def is_empty_array(val): advanced = True break - if arg is not np.newaxis: + if arg is not np.newaxis and arg is not NoneConst: try: at.subtensor.index_vars_to_types(arg) except AdvancedIndexingError: @@ -549,7 +550,7 @@ def is_empty_array(val): if advanced: return at.subtensor.advanced_subtensor(self, *args) else: - if np.newaxis in args: + if np.newaxis in args or NoneConst in args: # `np.newaxis` (i.e. `None`) in NumPy indexing mean "add a new # broadcastable dimension at this location". Since Aesara adds # new broadcastable dimensions via the `DimShuffle` `Op`, the @@ -561,7 +562,7 @@ def is_empty_array(val): pattern = [] new_args = [] for arg in args: - if arg == np.newaxis: + if arg is np.newaxis or arg is NoneConst: pattern.append("x") new_args.append(slice(None, None, None)) else: @@ -579,9 +580,9 @@ def is_empty_array(val): # with some symbolic variable. if not ( isinstance(arg, slice) - and arg.start is None - and arg.stop is None - and arg.step is None + and (arg.start is None or arg.start is NoneConst) + and (arg.stop is None or arg.stop is NoneConst) + and (arg.step is None or arg.step is NoneConst) ): full_slices = False if full_slices: @@ -877,10 +878,18 @@ def _get_vector_length_TensorVariable(op_or_var, var): class TensorConstantSignature(tuple): - """ - A Signature object for comparing TensorConstant instances. + r"""A signature object for comparing `TensorConstant` instances. + + An instance is a pair with the type ``(Type, ndarray)``. + + TODO FIXME: Subclassing `tuple` is unnecessary, and it appears to be + preventing the use of a much more convenient `__init__` that removes the + need for all these lazy computations and their safety checks. + + Also, why do we even need this signature stuff? We could simply implement + good `Constant.__eq__` and `Constant.__hash__` implementations. - An instance is a pair: (Type instance, ndarray). + We could also produce plain `tuple`\s with hashable values. """ @@ -929,19 +938,27 @@ def aesara_hash(self): _, d = self return hash_from_ndarray(d) - def _get_sum(self): + @property + def sum(self): """Compute sum of non NaN / Inf values in the array.""" try: return self._sum except AttributeError: - self._sum = self.no_nan.sum() - # The following 2 lines are needede as in Python 3.3 with NumPy + + # Prevent warnings when there are `inf`s and `-inf`s present + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=RuntimeWarning) + self._sum = self.no_nan.sum() + + # The following 2 lines are needed as in Python 3.3 with NumPy # 1.7.1, numpy.ndarray and numpy.memmap aren't hashable. if isinstance(self._sum, np.memmap): self._sum = np.asarray(self._sum).item() + if self.has_nan and self.no_nan.mask.all(): # In this case the sum is not properly computed by numpy. self._sum = 0 + if np.isinf(self._sum) or np.isnan(self._sum): # NaN may happen when there are both -inf and +inf values. if self.has_nan: @@ -956,25 +973,22 @@ def _get_sum(self): self._sum = np.ma.masked_array(self[1], mask).sum() # At this point there should be no more NaN. assert not np.isnan(self._sum) - return self._sum - sum = property(_get_sum) + if isinstance(self._sum, np.ma.core.MaskedConstant): + self._sum = 0 + + return self._sum - def _get_no_nan(self): + @property + def no_nan(self): try: return self._no_nan except AttributeError: - nan_mask = np.isnan(self[1]) - if nan_mask.any(): - self._no_nan = np.ma.masked_array(self[1], nan_mask) - self.has_nan = True - else: - self._no_nan = self[1] - self.has_nan = False + nans = np.isnan(self[1]) + self._no_nan = np.ma.masked_array(self[1], nans) + self.has_nan = np.any(nans) return self._no_nan - no_nan = property(_get_no_nan) - def get_unique_value(x: TensorVariable) -> Optional[Number]: """Return the unique value of a tensor, if there is one""" diff --git a/aesara/typed_list/__init__.py b/aesara/typed_list/__init__.py index 7ebb89a826..75635b2f31 100644 --- a/aesara/typed_list/__init__.py +++ b/aesara/typed_list/__init__.py @@ -1,3 +1,3 @@ -from . import opt -from .basic import * -from .type import TypedListType +from aesara.typed_list import rewriting +from aesara.typed_list.basic import * +from aesara.typed_list.type import TypedListType diff --git a/aesara/typed_list/opt.py b/aesara/typed_list/rewriting.py similarity index 55% rename from aesara/typed_list/opt.py rename to aesara/typed_list/rewriting.py index 5254a88ecd..ee57425d32 100644 --- a/aesara/typed_list/opt.py +++ b/aesara/typed_list/rewriting.py @@ -1,10 +1,10 @@ from aesara.compile import optdb -from aesara.graph.opt import TopoOptimizer, local_optimizer +from aesara.graph.rewriting.basic import WalkingGraphRewriter, node_rewriter from aesara.typed_list.basic import Append, Extend, Insert, Remove, Reverse -@local_optimizer([Append, Extend, Insert, Reverse, Remove], inplace=True) -def typed_list_inplace_opt(fgraph, node): +@node_rewriter([Append, Extend, Insert, Reverse, Remove], inplace=True) +def typed_list_inplace_rewrite(fgraph, node): if ( isinstance(node.op, (Append, Extend, Insert, Reverse, Remove)) and not node.op.inplace @@ -17,8 +17,10 @@ def typed_list_inplace_opt(fgraph, node): optdb.register( - "typed_list_inplace_opt", - TopoOptimizer(typed_list_inplace_opt, failure_callback=TopoOptimizer.warn_inplace), + "typed_list_inplace_rewrite", + WalkingGraphRewriter( + typed_list_inplace_rewrite, failure_callback=WalkingGraphRewriter.warn_inplace + ), "fast_run", "inplace", position=60, diff --git a/aesara/utils.py b/aesara/utils.py index d6e8537e1e..613d39ad16 100644 --- a/aesara/utils.py +++ b/aesara/utils.py @@ -158,12 +158,18 @@ def deprecated(message: str = ""): def decorator_wrapper(func): @wraps(func) def function_wrapper(*args, **kwargs): + nonlocal message + current_call_source = "|".join( traceback.format_stack(inspect.currentframe()) ) if current_call_source not in function_wrapper.last_call_source: + + if not message: + message = f"Function {func.__name__} is deprecated." + warnings.warn( - "Function {} is now deprecated! {}".format(func.__name__, message), + message, category=DeprecationWarning, stacklevel=2, ) diff --git a/conftest.py b/conftest.py index 8971c59919..bd4f8bab67 100644 --- a/conftest.py +++ b/conftest.py @@ -7,7 +7,7 @@ def pytest_sessionstart(session): os.environ["AESARA_FLAGS"] = ",".join( [ os.environ.setdefault("AESARA_FLAGS", ""), - "warn__ignore_bug_before=all,on_opt_error=raise,on_shape_error=raise", + "warn__ignore_bug_before=all,on_opt_error=raise,on_shape_error=raise,cmodule__warn_no_version=True", ] ) diff --git a/doc/environment.yml b/doc/environment.yml index 8bbc5c30d0..93fc557e2e 100644 --- a/doc/environment.yml +++ b/doc/environment.yml @@ -1,7 +1,7 @@ name: aesara-docs channels: - - defaults - conda-forge + - nodefaults dependencies: - python=3.7 - gcc_linux-64 @@ -13,5 +13,6 @@ dependencies: - sphinx_rtd_theme - mock - pillow + - pip - pip: - -e ..[doc] diff --git a/doc/extending/creating_a_c_op.rst b/doc/extending/creating_a_c_op.rst index abe653d219..3c0a56736b 100644 --- a/doc/extending/creating_a_c_op.rst +++ b/doc/extending/creating_a_c_op.rst @@ -408,6 +408,9 @@ commonly used. this function should return a tuple of integers as previously described. + Also, do not use the built-in ``hash``; it will produce different values + between Python sessions and confound the caching process. + Important restrictions when implementing a :class:`COp` ======================================================= @@ -719,8 +722,8 @@ simple but it still involves defining many methods as well as mixing, in the same file, both Python and C code which tends to make the result less readable. -To help with this, Aesara defines a class, ``ExternalCOp``, from which new C :class:`Op`\s -can inherit. The class ``ExternalCOp`` aims to simplify the process of implementing +To help with this, Aesara defines a class, `ExternalCOp`, from which new C :class:`Op`\s +can inherit. The class `ExternalCOp` aims to simplify the process of implementing C :class:`Op`\s by doing the following : * It allows you to define the C implementation of your :class:`Op` in a distinct @@ -728,15 +731,15 @@ C :class:`Op`\s by doing the following : readable and well indented. * It can automatically handle all the methods that return C code, - in addition to :meth:`Op.c_code_cache_version()` based on the + in addition to :meth:`Op.c_code_cache_version` based on the provided external C implementation. -To illustrate how much simpler the class ``ExternalCOp`` makes the process of defining +To illustrate how much simpler the class `ExternalCOp` makes the process of defining a new :class:`Op` with a C implementation, let's revisit the second example of this -tutorial, the ``VectorTimesVector`` :class:`Op`. In that example, we implemented an :class:`Op` +tutorial, the `VectorTimesVector`\ :class:`Op`. In that example, we implemented an :class:`Op` to perform the task of element-wise vector-vector multiplication. The two following blocks of code illustrate what the :class:`Op` would look like if it was -implemented using the ``ExternalCOp`` class. +implemented using the `ExternalCOp` class. The new :class:`Op` is defined inside a Python file with the following code : @@ -850,23 +853,23 @@ in the same file and the C code contained string formatting markers. Now that we have motivated the `ExternalCOp` class, we can have a more precise look at what it does for us. For this, we go through the various elements that make up -this new version of the ``VectorTimesVector`` `Op` : +this new version of the `VectorTimesVector`\ `Op` : * Parent class : instead of inheriting from the class :class:`Op`, - VectorTimesVector inherits from the class ``ExternalCOp``. + VectorTimesVector inherits from the class `ExternalCOp`. -* Constructor : in our new `COp`, the ``__init__()`` method has an - important use; to inform the constructor of the ``ExternalCOp`` class +* Constructor : in our new `COp`, the :meth:`COp.__init__` method has an + important use; to inform the constructor of the `ExternalCOp` class of the location, on the filesystem of the C implementation of this `COp`. To do this, it gives a list of file paths containing the C code for this `COp`. To auto-generate the c_code method with a function call you can specify the function name as the second parameter. The paths should be given as a relative - path from the folder where the descendant of the ``ExternalCOp`` class + path from the folder where the descendant of the `ExternalCOp` class is defined. -* ``make_node()`` : the ``make_node()`` method is absolutely - identical to the one in our old example. Using the ``ExternalCOp`` +* :meth:`ExternalCOp.make_node` : this method is absolutely + identical to the one in our old example. Using the `ExternalCOp` class doesn't change anything here. * External C code : the external C code implements the various @@ -877,8 +880,8 @@ this new version of the ``VectorTimesVector`` `Op` : Main function ------------- -If you pass a function name to the ``__init__()`` method of the -``ExternalCOp`` class, it must respect the following constraints: +If you pass a function name to :meth:`ExternalCOp.__init___`, it must respect +the following constraints: * It must return an int. The value of that int indicates whether the `Op` could perform its task or not. A value of 0 indicates diff --git a/doc/extending/creating_an_op.rst b/doc/extending/creating_an_op.rst index c908bef49a..a5b906b0e6 100644 --- a/doc/extending/creating_an_op.rst +++ b/doc/extending/creating_an_op.rst @@ -9,7 +9,7 @@ a function that does what you want. If you can implement something in terms of an existing :ref:`Op`, you should do that. Odds are your function that uses existing Aesara expressions is short, -has no bugs, and potentially profits from optimizations that have already been +has no bugs, and potentially profits from rewrites that have already been implemented. However, if you cannot implement an :class:`Op` in terms of an existing :class:`Op`, you have to @@ -55,8 +55,8 @@ details how to write such an :class:`Op` instance. Please refers to structure. -Op's basic methods ------------------- +:class:`Op`'s basic methods +--------------------------- An :class:`Op` is any Python object which inherits from :class:`Op`. This section provides an overview of the basic methods you typically have to @@ -114,11 +114,11 @@ possibilities you may encounter or need. For that refer to An :class:`Op` has to implement some methods defined in the the interface of :class:`Op`. More specifically, it is mandatory for an :class:`Op` to define either -the method :func:`make_node` or :attr:`itypes`, :attr:`otypes` and one of the -implementation methods, either :func:`perform`, :meth:`COp.c_code` -or :func:`make_thunk`. +the method :meth:`Op.make_node` or :attr:`Op.itypes`, :attr:`Op.otypes` and one of the +implementation methods, either :meth:`Op.perform`, :meth:`COp.c_code` +or :meth:`Op.make_thunk`. - :func:`make_node` method creates an Apply node representing the application + :meth:`Op.make_node` method creates an Apply node representing the application of the :class:`Op` on the inputs provided. This method is responsible for three things: - it first checks that the input :class:`Variable`\s types are compatible @@ -134,11 +134,11 @@ or :func:`make_thunk`. - :func:`perform` method defines the Python implementation of an :class:`Op`. + :meth:`Op.perform` method defines the Python implementation of an :class:`Op`. It takes several arguments: - ``node`` is a reference to an Apply node which was previously - obtained via the :func:`make_node` method. It is typically not + obtained via the :meth:`Op.make_node` method. It is typically not used in a simple :class:`Op`, but it contains symbolic information that could be required by a complex :class:`Op`. - ``inputs`` is a list of references to data which can be operated on using @@ -156,7 +156,7 @@ or :func:`make_thunk`. preallocated in the ``output_storage``, it will be of the good dtype, but can have the wrong shape and have any stride pattern. - :func:`perform` method must be determined by the inputs. That is to say, + :meth:`Op.perform` method must be determined by the inputs. That is to say, when applied to identical inputs the method must return the same outputs. An :class:`Op`\s implementation can be defined in other ways, as well. @@ -165,7 +165,7 @@ or :func:`make_thunk`. :meth:`COp.c_code` and other related ``c_**`` methods. Note that an :class:`Op` can provide both Python and C implementations. - :func:`make_thunk` method is another alternative to :func:`perform`. + :meth:`Op.make_thunk` method is another alternative to :meth:`Op.perform`. It returns a thunk. A thunk is defined as a zero-arguments function which encapsulates the computation to be performed by an :class:`Op` on the arguments of its corresponding node. It takes several parameters: @@ -187,78 +187,73 @@ or :func:`make_thunk`. - ``impl`` allow to select between multiple implementation. It should have a default value of ``None``. - :func:`make_thunk` is useful if you want to generate code and compile + :meth:`Op.make_thunk` is useful if you want to generate code and compile it yourself. - If :func:`make_thunk()` is defined by an :class:`Op`, it will be used by Aesara + If :meth:`Op.make_thunk` is defined by an :class:`Op`, it will be used by Aesara to obtain the :class:`Op`'s implementation. - :func:`perform` and :meth:`COp.c_code` will be ignored. + :meth:`Op.perform` and :meth:`COp.c_code` will be ignored. - If :func:`make_node` is not defined, the :attr:`itypes` and :attr:`otypes` - are used by the :class:`Op`'s :func:`make_node` method to implement the functionality - of :func:`make_node` method mentioned above. + If :meth:`Op.make_node` is not defined, the :attr:`Op.itypes` and :attr:`Op.otypes` + are used by the :class:`Op`'s :meth:`Op.make_node` method to implement the functionality + of :meth:`Op.make_node` method mentioned above. :class:`Op`'s auxiliary methods ------------------------------- There are other methods that can be optionally defined by the :class:`Op`: - The :func:`__str__` method provides a meaningful string representation of - your :class:`Op`. - - :func:`__eq__` and :func:`__hash__` define respectivelly equality + :meth:`Op.__eq__` and :meth:`Op.__hash__` define respectivelly equality between two :class:`Op`\s and the hash of an :class:`Op` instance. - They will be used by the optimization - phase to merge nodes that are doing equivalent computations (same - inputs, same operation). - Two :class:`Op`\s that are equal according :func:`__eq__` + They will be used during the rewriting phase to merge nodes that are doing + equivalent computations (same inputs, same operation). + Two :class:`Op`\s that are equal according :meth:`Op.__eq__` should return the same output when they are applied on the same inputs. - The :attr:`__props__` attribute lists the properties that influence how the computation - is performed (usually these are set in :func:`__init__`). It must be a tuple. + The :attr:`Op.__props__` attribute lists the properties that influence how the computation + is performed. Usually these are set in :meth:`Op.__init__`. It must be a tuple. If you don't have any properties, then you should set this attribute to the empty tuple ``()``. - :attr:`__props__` enables the automatic generation of appropriate - :func:`__eq__` and :func:`__hash__`. + :attr:`Op.__props__` enables the automatic generation of appropriate + :meth:`Op.__eq__` and :meth:`Op.__hash__`. Given the method :func:`__eq__`, automatically generated from - :attr:`__props__`, two :class:`Op`\s will be equal if they have the same values for all - the properties listed in :attr:`__props__`. - Given to the method :func:`__hash__` automatically generated from - :attr:`__props__`, two :class:`Op`\s will be have the same hash if they have the same - values for all the properties listed in :attr:`__props__`. - :attr:`__props__` will also generate a suitable :func:`__str__` for your :class:`Op`. - This requires development version after September 1st, 2014 or version 0.7. - - The :func:`infer_shape` method allows an :class:`Op` to infer the shape of its + :attr:`Op.__props__`, two :class:`Op`\s will be equal if they have the same values for all + the properties listed in :attr:`Op.__props__`. + Given to the method :meth:`Op.__hash__` automatically generated from + :attr:`Op.__props__`, two :class:`Op`\s will be have the same hash if they have the same + values for all the properties listed in :attr:`Op.__props__`. + :attr:`Op.__props__` will also generate a suitable :meth:`Op.__str__` for your :class:`Op`. + + The :meth:`Op.infer_shape` method allows an :class:`Op` to infer the shape of its output variables without actually computing them. It takes as input ``fgraph``, a :class:`FunctionGraph`; ``node``, a reference to the :class:`Op`'s :class:`Apply` node; and a list of :class:`Variables`\s (e.g. ``i0_shape``, ``i1_shape``, ...) which are the dimensions of the :class:`Op` input :class:`Variable`\s. - :func:`infer_shape` returns a list where each element is a tuple representing + :meth:`Op.infer_shape` returns a list where each element is a tuple representing the shape of one output. This could be helpful if one only needs the shape of the output instead of the - actual outputs, which can be useful, for instance, for optimization + actual outputs, which can be useful, for instance, for rewriting procedures. - The :func:`grad` method is required if you want to differentiate some cost + The :meth:`Op.grad` method is required if you want to differentiate some cost whose expression includes your :class:`Op`. The gradient may be specified symbolically in this method. It takes two arguments ``inputs`` and ``output_gradients``, which are both lists of :class:`Variable`\s, and - those must be operated on using Aesara's symbolic language. The :func:`grad` + those must be operated on using Aesara's symbolic language. The :meth:`Op.grad` method must return a list containing one :class:`Variable` for each input. Each returned :class:`Variable` represents the gradient with respect to that input computed based on the symbolic gradients with respect to each output. If the output is not differentiable with respect to an input then - this method should be defined to return a variable of type ``NullType`` - for that input. Likewise, if you have not implemented the grad + this method should be defined to return a variable of type :class:`NullType` + for that input. Likewise, if you have not implemented the gradient computation for some input, you may return a variable of type - ``NullType`` for that input. Please refer to :func:`grad` for a more detailed + :class:`NullType` for that input. Please refer to :meth:`Op.grad` for a more detailed view. - The :func:`R_op` method is needed if you want ``aesara.gradient.Rop`` to + The :meth:`Op.R_op` method is needed if you want :func:`aesara.gradient.Rop` to work with your :class:`Op`. This function implements the application of the R-operator on the function represented by your :class:`Op`. Let assume that function is :math:`f`, @@ -541,7 +536,7 @@ How To Test it -------------- Aesara has some functionalities to simplify testing. These help test the -:meth:`infer_shape`, :meth:`grad` and :meth:`R_op` methods. Put the following code +:meth:`Op.infer_shape`, :meth:`Op.grad` and :meth:`Op.R_op` methods. Put the following code in a file and execute it with the ``pytest`` program. Basic Tests @@ -549,8 +544,8 @@ Basic Tests Basic tests are done by you just by using the :class:`Op` and checking that it returns the right answer. If you detect an error, you must raise an -*exception*. You can use the ``assert`` keyword to automatically raise an -``AssertionError``. +exception. You can use the ``assert`` keyword to automatically raise an +`AssertionError`. .. testcode:: tests @@ -583,14 +578,14 @@ the default tolerance can be changed with the Aesara flags default value do the most strict comparison, 1 and 2 make less strict comparison. -Testing the infer_shape -^^^^^^^^^^^^^^^^^^^^^^^ +Testing the :meth:`Op.infer_shape` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ When a class inherits from the :class:`InferShapeTester` class, it gets the :meth:`InferShapeTester._compile_and_check` method that tests the :meth:`Op.infer_shape` -method. It tests that the :class:`Op` gets optimized out of the graph if only +method. It tests that the :class:`Op` gets rewritten out of the graph if only the shape of the output is needed and not the output -itself. Additionally, it checks that the optimized graph computes +itself. Additionally, it checks that the rewritten graph computes the correct shape, by comparing it to the actual shape of the computed output. @@ -607,7 +602,7 @@ When testing with input values with shapes that take the same value over different dimensions (for instance, a square matrix, or a ``tensor3`` with shape ``(n, n, n)``, or ``(m, n, m)``), it is not possible to detect if the output shape was computed correctly, or if some shapes with the -same value have been mixed up. For instance, if the infer_shape uses +same value have been mixed up. For instance, if the :meth:`Op.infer_shape` uses the width of a matrix instead of its height, then testing with only square matrices will not detect the problem. This is why the :meth:`InferShapeTester._compile_and_check` method prints a warning in such a case. If @@ -721,17 +716,17 @@ Modify and execute the example to return two outputs: ``x + y`` and `jx - yj`. You can omit the :meth:`Rop` functions. Try to implement the testing apparatus described above. -(Notice that Aesara's current *elemwise fusion* optimization is +(Notice that Aesara's current *elemwise fusion* rewrite is only applicable to computations involving a single output. Hence, to gain efficiency over the basic solution that is asked here, the two operations would -have to be jointly optimized explicitly in the code.) +have to be jointly rewritten explicitly in the code.) Random numbers in tests """"""""""""""""""""""" -Making tests errors more reproducible is a good practice. To make your -tests more reproducible, you need a way to get the same random -numbers. You can do this by seeding NumPy's random number +Making tests errors more reproducible is a good practice. To make +tests more reproducible, one needs a way to get the same random +numbers. This can be done by seeding NumPy's random number generator. For convenience, the classes :class:`InferShapeTester` and :class:`RopLop_checker` @@ -751,7 +746,7 @@ basic Aesara :class:`Op` that will call the supplied function during execution. This isn't the recommended way to build an :class:`Op`, but allows for a quick implementation. -It takes an optional :func:`infer_shape` parameter that must have this +It takes an optional :meth:`Op.infer_shape` parameter that must have this signature: .. code-block:: none @@ -766,8 +761,8 @@ signature: .. warning:: - Not providing a :obj:`infer_shape` prevents shape-related - optimizations from working with this :class:`Op`. For example + Not providing a :meth:`Op.infer_shape` prevents shape-related + rewrites from working with this :class:`Op`. For example ``your_op(inputs, ...).shape`` will need the :class:`Op` to be executed just to get the shape. @@ -838,7 +833,7 @@ checked it, it would generate a false positive. Another case is related to You can tell :class:`NanGuardMode` to do not check a variable with: :attr:`variable.tag.nan_guard_mode_check`. Also, this tag automatically -follow that variable during optimization. This mean if you tag a +follows that variable during rewriting. This mean if you tag a variable that get replaced by an inplace version, it will keep that tag. diff --git a/doc/extending/extending_faq.rst b/doc/extending/extending_faq.rst index 3a654eafd7..4630ae220a 100644 --- a/doc/extending/extending_faq.rst +++ b/doc/extending/extending_faq.rst @@ -5,26 +5,26 @@ Extending Aesara: FAQ and Troubleshooting ========================================= -I wrote a new Op/Type, and weird stuff is happening... ------------------------------------------------------- +I wrote a new `Op`\/`Type`, and weird stuff is happening... +----------------------------------------------------------- First, check the :ref:`op_contract` and the :ref:`type_contract` and make sure you're following the rules. -Then try running your program in :ref:`using_debugmode`. DebugMode might catch +Then try running your program in :ref:`using_debugmode`. `DebugMode` might catch something that you're not seeing. -I wrote a new optimization, but it's not getting used... ---------------------------------------------------------- +I wrote a new rewrite, but it's not getting used... +--------------------------------------------------- -Remember that you have to register optimizations with the :ref:`optdb` +Remember that you have to register rewrites with the :ref:`optdb` for them to get used by the normal modes like FAST_COMPILE, FAST_RUN, -and DebugMode. +and `DebugMode`. -I wrote a new optimization, and it changed my results even though I'm pretty sure it is correct. ------------------------------------------------------------------------------------------------- +I wrote a new rewrite, and it changed my results even though I'm pretty sure it is correct. +------------------------------------------------------------------------------------------- First, check the :ref:`op_contract` and make sure you're following the rules. -Then try running your program in :ref:`using_debugmode`. DebugMode might +Then try running your program in :ref:`using_debugmode`. `DebugMode` might catch something that you're not seeing. diff --git a/doc/extending/graph_rewriting.rst b/doc/extending/graph_rewriting.rst index bec02cfc6e..9eb8d282ec 100644 --- a/doc/extending/graph_rewriting.rst +++ b/doc/extending/graph_rewriting.rst @@ -6,80 +6,81 @@ Graph Rewriting =============== In this document we will explain how graph rewriting works and how graph -optimizations can be constructed using graph rewriting. +rewrites can be constructed in Aesara. .. todo:: - The old "optimization" nomenclature is still in use throughout these documents - and the codebase; however, this is being changed to more accurately distinguish - between general graph rewriting for any purpose and the kind that is explicitly - intended to "optimize" a graph in some way. + The old "optimization" nomenclature is still in use throughout some of these + documents and the codebase; however, this is being changed to more accurately + distinguish between general graph rewriting for any purpose and the kind that + is explicitly intended to "optimize" a graph in some way. -Global and Local Optimizations -============================== +Graph and Node Rewriters +======================== -First, let's lay out the way optimizations work in Aesara. There are -two types of optimizations: *global* optimizations and *local* -optimizations. A global optimization takes a :class:`FunctionGraph` object (see its +There are two types of basic rewriters: *graph* rewriters and *node* rewriters. + +A graph rewriter takes a :class:`FunctionGraph` object (see its :doc:`documentation ` for more details) and navigates through it -in a suitable way, replacing some :class:`Variable`\s by others in the process. A -local optimization, on the other hand, is defined as a function on a +in a suitable way, replacing some :class:`Variable`\s by others in the process. +A node rewriter, on the other hand, is defined as a function on a *single* :ref:`apply` node and must return either ``False`` (to mean that nothing is to be done) or a list of new :class:`Variable`\s that we would like to -replace the node's outputs with. A :ref:`navigator` is a special kind -of global optimization which navigates the computation graph in some -fashion (e.g. in topological order, reverse-topological order, random -order, etc.) and applies one or more local optimizations at each step. +substitute for the node's current outputs. + +Some graph rewriters navigate the computation graph in a particular fashion +(e.g. in topological order, reverse-topological order, random order, etc.) and +apply one or more node rewriters at each step. :class:`WalkingGraphRewriter` is +one such example. -Optimizations which are holistic, meaning that they must take into -account dependencies that might be all over the graph, should be -global. Optimizations that can be done with a narrow perspective are -better defined as local optimizations. The majority of optimizations -we want to define are local. +Rewriters that are holistic, meaning that they must take into +account dependencies that might be all over the graph, should usually be +graph rewriters. Rewrites that only need a narrow view of sub-graphs are +better defined as node rewrites. -.. optimizer: +.. rewriter: -Global optimization -------------------- +Graph Rewriting +--------------- -.. class:: GlobalOptimizer +.. class:: GraphRewriter .. method:: apply(fgraph) This method takes a :class:`FunctionGraph` object which contains the computation graph - and does modifications in line with what the optimization is meant - to do. This is one of the main methods of the optimizer. + and does modifications in line with what the rewriter is meant + to do. This is one of the main methods of the rewriter. .. method:: add_requirements(fgraph) This method takes a :class:`FunctionGraph` object and adds :ref:`features ` to it. These features are "plugins" that are needed - for the :meth:`GlobalOptimizer.apply` method to do its job properly. + for the :meth:`GraphRewriter.apply` method to do its job properly. - .. method:: optimize(fgraph) + .. method:: rewrite(fgraph) This is the interface function called by Aesara. It calls - :meth:`GlobalOptimizer.apply` by default. + :meth:`GraphRewriter.apply` by default. -Local optimization ------------------- +Node Rewriting +-------------- -A local optimization is an object which defines the following methods: +A node rewriter is an object which defines the following methods: -.. class:: LocalOptimizer +.. class:: NodeRewriter .. method:: transform(fgraph, node) This method takes a :class:`FunctionGraph` and an :class:`Apply` node and returns either ``False`` to signify that no changes are to be done or a list of :class:`Variable`\s which matches the length of the node's ``outputs`` - list. When the :class:`LocalOptimizer` is applied by a :class:`NavigatorOptimizer`, the outputs - of the node passed as argument to the :class:`LocalOptimizer` will be replaced by + list. When the :class:`NodeRewriter` is applied by a :class:`NodeProcessingGraphRewriter`, the outputs + of the node passed as argument to the :class:`NodeRewriter` will be replaced by the list returned. -A simplification rule +A Simplification Rule ===================== For starters, let's define the following simplification: @@ -88,23 +89,23 @@ For starters, let's define the following simplification: \frac{xy}{y} = x -We will implement it in three ways: using a global optimization, a -local optimization with a :class:`NavigatorOptimizer` and then using the :class:`PatternSub` -facility. +We will implement it in three ways: using a graph rewriter, a node rewriter with +a :class:`NodeProcessingGraphRewriter`, and then using the +:class:`PatternNodeRewriter`. -Global optimization -------------------- +Graph Rewriter Implementation +----------------------------- -Here is the code for a global optimization implementing the +Here is the code for a graph rewriter implementing the simplification described above: .. testcode:: import aesara - from aesara.graph.opt import GlobalOptimizer + from aesara.graph.rewriting.basic import GraphRewriter from aesara.graph.features import ReplaceValidate - class Simplify(GlobalOptimizer): + class Simplify(GraphRewriter): def add_requirements(self, fgraph): fgraph.attach_feature(ReplaceValidate()) @@ -136,7 +137,7 @@ another while respecting certain validation constraints. As an exercise, try to rewrite :class:`Simplify` using :class:`NodeFinder`. (Hint: you want to use the method it publishes instead of the call to toposort) -Then, in :meth:`GlobalOptimizer.apply` we do the actual job of simplification. We start by +Then, in :meth:`GraphRewriter.apply` we do the actual job of simplification. We start by iterating through the graph in topological order. For each node encountered, we check if it's a ``div`` node. If not, we have nothing to do here. If so, we put in ``x``, ``y`` and ``z`` the numerator, @@ -149,7 +150,7 @@ we can now say that ``z == (a*b)/y``. If ``y==a`` then ``z==b`` and if ``z`` by either ``a`` or ``b`` using :meth:`FunctionGraph.replace_validate`; otherwise, we do nothing. -Now, we test the optimization: +Now, we test the rewriter: >>> from aesara.scalar import float64, add, mul, true_div >>> x = float64('x') @@ -159,14 +160,14 @@ Now, we test the optimization: >>> e = aesara.graph.fg.FunctionGraph([x, y, z], [a]) >>> e FunctionGraph(add(z, mul(true_div(mul(y, x), y), true_div(z, x)))) ->>> simplify.optimize(e) +>>> simplify.rewrite(e) >>> e FunctionGraph(add(z, mul(x, true_div(z, x)))) You can check what happens if you put many instances of :math:`\frac{xy}{y}` in the graph. Note that it sometimes won't work for reasons that have nothing to do with the quality of the -optimization you wrote. For example, consider the following: +rewrite you wrote. For example, consider the following: >>> x = float64('x') >>> y = float64('y') @@ -175,7 +176,7 @@ optimization you wrote. For example, consider the following: >>> e = aesara.graph.fg.FunctionGraph([x, y, z], [a]) >>> e FunctionGraph(true_div(mul(add(y, z), x), add(y, z))) ->>> simplify.optimize(e) +>>> simplify.rewrite(e) >>> e FunctionGraph(true_div(mul(add(y, z), x), add(y, z))) @@ -183,14 +184,14 @@ Nothing happened here. The reason is: ``add(y, z) != add(y, z)``. That is the case for efficiency reasons. To fix this problem we first need to merge the parts of the graph that represent the same computation, using the :class:`MergeOptimizer` defined in -:mod:`aesara.graph.opt`. +:mod:`aesara.graph.rewriting.basic`. ->>> from aesara.graph.opt import MergeOptimizer ->>> MergeOptimizer().optimize(e) # doctest: +ELLIPSIS +>>> from aesara.graph.rewriting.basic import MergeOptimizer +>>> MergeOptimizer().rewrite(e) # doctest: +ELLIPSIS (0, ..., None, None, {}, 1, 0) >>> e FunctionGraph(true_div(mul(*1 -> add(y, z), x), *1)) ->>> simplify.optimize(e) +>>> simplify.rewrite(e) >>> e FunctionGraph(x) @@ -198,30 +199,30 @@ Once the merge is done, both occurrences of ``add(y, z)`` are collapsed into a single one and is used as an input in two places. Note that ``add(x, y)`` and ``add(y, x)`` are still considered to be different because Aesara has no clue that ``add`` is -commutative. You may write your own global optimizer to identify +commutative. You may write your own graph rewrite to identify computations that are identical with full knowledge of the rules of arithmetic that your Ops implement. Aesara might provide facilities for this somewhere in the future. .. note:: - :class:`FunctionGraph` is an Aesara structure intended for the optimization + :class:`FunctionGraph` is an Aesara structure intended for the rewrite phase. It is used internally by :func:`aesara.function` and is rarely exposed to the end user. -Local Optimization ------------------- +Node Rewriter Implementation +---------------------------- The local version of the above code would be the following: .. testcode:: - from aesara.graph.opt import LocalOptimizer + from aesara.graph.rewriting.basic import NodeRewriter - class LocalSimplify(LocalOptimizer): + class LocalSimplify(NodeRewriter): def transform(self, fgraph, node): if node.op == true_div: x, y = node.inputs @@ -234,7 +235,7 @@ The local version of the above code would be the following: return False def tracks(self): - # This tells certain navigators to only apply this `LocalOptimizer` + # This tells certain navigators to only apply this `NodeRewriter` # on these kinds of `Op`s return [true_div] @@ -242,7 +243,7 @@ The local version of the above code would be the following: In this case, the transformation is defined in the -:meth:`LocalOptimizer.transform` method, which is given an explicit +:meth:`NodeRewriter.transform` method, which is given an explicit :class:`Apply` node on which to work. The entire graph--as a ``fgraph``--is also provided, in case global information is needed. @@ -252,10 +253,10 @@ outputs are returned. This list must have the same length as (e.g. available via ``fgraph.clients``), then it is not used elsewhere in the graph and you can put ``None`` in the returned list to remove it. -In order to apply the local optimizer we can use it in conjunction -with a :class:`NavigatorOptimizer`. Basically, a :class:`NavigatorOptimizer` is -a global optimizer that loops through all nodes in the graph (or a well-defined -subset of them) and applies one or several local optimizers. +In order to apply the node rewriter throughout a graph, we use it in conjunction +with a :class:`NodeProcessingGraphRewriter`. A :class:`NodeProcessingGraphRewriter` is +a graph rewriter that loops through all nodes in the graph (or a well-defined +subset of them) and applies one or several node rewriters. >>> x = float64('x') >>> y = float64('y') @@ -264,69 +265,69 @@ subset of them) and applies one or several local optimizers. >>> e = aesara.graph.fg.FunctionGraph([x, y, z], [a]) >>> e FunctionGraph(add(z, mul(true_div(mul(y, x), y), true_div(z, x)))) ->>> simplify = aesara.graph.opt.TopoOptimizer(local_simplify) ->>> simplify.optimize(e) -(, 1, 5, 3, ..., ..., ...) +>>> simplify = aesara.graph.rewriting.basic.WalkingGraphRewriter(local_simplify) +>>> simplify.rewrite(e) +(, 1, 5, 3, ..., ..., ...) >>> e FunctionGraph(add(z, mul(x, true_div(z, x)))) -:class:`OpSub`, :class:`OpRemove`, :class:`PatternSub` -++++++++++++++++++++++++++++++++++++++++++++++++++++++ +:class:`SubstitutionNodeRewriter`, :class:`RemovalNodeRewriter`, :class:`PatternNodeRewriter` ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ -Aesara defines some shortcuts to make :class:`LocalOptimizer`\s: +Aesara defines some shortcuts to make :class:`NodeRewriter`\s: -.. function:: OpSub(op1, op2) +.. function:: SubstitutionNodeRewriter(op1, op2) Replaces all uses of ``op1`` by ``op2``. In other words, the outputs of all :class:`Apply` nodes using ``op1`` by the outputs of :class:`Apply` nodes involving ``op2``, where their inputs are the same. -.. function:: OpRemove(op) +.. function:: RemovalNodeRewriter(op) Removes all uses of ``op`` in the following way: if ``y = op(x)`` then ``y`` is replaced by ``x``. ``op`` must have as many outputs as it has inputs. The first output becomes the first input, the second output becomes the second input, and so on. -.. function:: PatternSub(pattern1, pattern2) +.. function:: PatternNodeRewriter(pattern1, pattern2) Replaces all occurrences of the first pattern by the second pattern. - See :class:`PatternSub`. + See :class:`PatternNodeRewriter`. .. code:: from aesara.scalar import identity - from aesara.graph.opt import OpSub, OpRemove, PatternSub + from aesara.graph.rewriting.basic import SubstitutionNodeRewriter, RemovalNodeRewriter, PatternNodeRewriter # Replacing `add` by `mul` (this is not recommended for primarily # mathematical reasons): - add_to_mul = OpSub(add, mul) + add_to_mul = SubstitutionNodeRewriter(add, mul) # Removing `identity` - remove_identity = OpRemove(identity) + remove_identity = RemovalNodeRewriter(identity) # The "simplify" operation we've been defining in the past few # sections. Note that we need two patterns to account for the # permutations of the arguments to `mul`. - local_simplify_1 = PatternSub((true_div, (mul, 'x', 'y'), 'y'), 'x') - local_simplify_2 = PatternSub((true_div, (mul, 'x', 'y'), 'x'), 'y') + local_simplify_1 = PatternNodeRewriter((true_div, (mul, 'x', 'y'), 'y'), 'x') + local_simplify_2 = PatternNodeRewriter((true_div, (mul, 'x', 'y'), 'x'), 'y') .. note:: - :class:`OpSub`, :class:`OpRemove` and :class:`PatternSub` produce local optimizers, which - means that everything we said previously about local optimizers - apply (e.g. they need to be wrapped in a :class:`NavigatorOptimizer`, etc.) + :class:`SubstitutionNodeRewriter`, :class:`RemovalNodeRewriter` and :class:`PatternNodeRewriter` produce node rewriters, which + means that everything we said previously about node rewriters + apply (e.g. they need to be wrapped in a :class:`NodeProcessingGraphRewriter`, etc.) -When an optimization can be naturally expressed using :class:`OpSub`, :class:`OpRemove` -or :class:`PatternSub`, it is highly recommended to use them. +When a rewriter can be naturally expressed using :class:`SubstitutionNodeRewriter`, :class:`RemovalNodeRewriter` +or :class:`PatternNodeRewriter`, it is highly recommended to use them. .. _unification: Unification and reification =========================== -The :class:`PatternSub` class uses `unification and reification +The :class:`PatternNodeRewriter` class uses `unification and reification `_ to implement a more succinct and reusable form of "pattern matching and replacement". In general, *use of the unification and reification tools is preferable when @@ -345,7 +346,7 @@ In order to use :func:`unify` and :func:`reify` with Aesara graphs, we need an i structure that will allow us to represent Aesara graphs that contain :class:`var`\s, because Aesara :class:`Op`\s and :class:`Apply` nodes will not accept these foreign objects as inputs. -:class:`PatternSub` uses Python ``tuple``\s to effectively represent :class:`Apply` nodes and +:class:`PatternNodeRewriter` uses Python ``tuple``\s to effectively represent :class:`Apply` nodes and ``str``\s to represent logic variables (i.e. :class:`var`\s in the :mod:`unification` library). Behind the scenes, these ``tuple``\s are converted to a ``tuple`` subclass called :class:`ExpressionTuple`\s, which behave just like normal ``tuple``\s except for some special caching features that allow for easy @@ -432,8 +433,8 @@ it does so in the context of relational operators (e.g. equations like :math:`x This means that a relation that--say--represents :math:`x + x = 2 x` can be utilized in both directions. -Currently, the local optimizer :class:`KanrenRelationSub` provides a means of -turning :mod:`kanren` relations into :class:`LocalOptimizer`\s; however, +Currently, the node rewriter :class:`KanrenRelationSub` provides a means of +turning :mod:`kanren` relations into :class:`NodeRewriter`\s; however, :mod:`kanren` can always be used directly from within a custom :class:`Rewriter`, so :class:`KanrenRelationSub` is not necessary. @@ -443,9 +444,9 @@ The following is an example that distributes dot products across additions. import aesara import aesara.tensor as at - from aesara.graph.kanren import KanrenRelationSub - from aesara.graph.opt import EquilibriumOptimizer - from aesara.graph.opt_utils import optimize_graph + from aesara.graph.rewriting.kanren import KanrenRelationSub + from aesara.graph.rewriting.basic import EquilibriumGraphRewriter + from aesara.graph.rewriting.utils import rewrite_graph from aesara.tensor.math import _dot from etuples import etuple from kanren import conso, eq, fact, heado, tailo @@ -484,10 +485,10 @@ The following is an example that distributes dot products across additions. ) - dot_distribute_opt = EquilibriumOptimizer([KanrenRelationSub(dot_distributeo)], max_use_ratio=10) + dot_distribute_rewrite = EquilibriumGraphRewriter([KanrenRelationSub(dot_distributeo)], max_use_ratio=10) -Below, we apply `dot_distribute_opt` to a few example graphs. First we create simple test graph: +Below, we apply `dot_distribute_rewrite` to a few example graphs. First we create simple test graph: >>> x_at = at.vector("x") >>> y_at = at.vector("y") @@ -498,7 +499,7 @@ Below, we apply `dot_distribute_opt` to a few example graphs. First we create s Next we apply the rewrite to the graph: ->>> res = optimize_graph(test_at, include=[], custom_opt=dot_distribute_opt, clone=False) +>>> res = rewrite_graph(test_at, include=[], custom_rewrite=dot_distribute_rewrite, clone=False) >>> print(aesara.pprint(res)) ((A @ x) + (A @ y)) @@ -510,7 +511,7 @@ few more test cases: >>> test_at = A_at.dot((x_at + y_at) + (z_at + w_at)) >>> print(aesara.pprint(test_at)) (A @ ((x + y) + (z + w))) ->>> res = optimize_graph(test_at, include=[], custom_opt=dot_distribute_opt, clone=False) +>>> res = rewrite_graph(test_at, include=[], custom_rewrite=dot_distribute_rewrite, clone=False) >>> print(aesara.pprint(res)) (((A @ x) + (A @ y)) + ((A @ z) + (A @ w))) @@ -519,7 +520,7 @@ few more test cases: >>> test_at = A_at.dot(x_at + (y_at + B_at.dot(z_at + w_at))) >>> print(aesara.pprint(test_at)) (A @ (x + (y + ((B @ z) + (B @ w))))) ->>> res = optimize_graph(test_at, include=[], custom_opt=dot_distribute_opt, clone=False) +>>> res = rewrite_graph(test_at, include=[], custom_rewrite=dot_distribute_rewrite, clone=False) >>> print(aesara.pprint(res)) ((A @ x) + ((A @ y) + ((A @ (B @ z)) + (A @ (B @ w))))) @@ -531,8 +532,8 @@ relational properties. To do that, we will create another :class:`Rewriter` that simply reverses the arguments to the relation :func:`dot_distributeo` and apply it to the distributed result in ``res``: ->>> dot_gather_opt = EquilibriumOptimizer([KanrenRelationSub(lambda x, y: dot_distributeo(y, x))], max_use_ratio=10) ->>> rev_res = optimize_graph(res, include=[], custom_opt=dot_gather_opt, clone=False) +>>> dot_gather_rewrite = EquilibriumGraphRewriter([KanrenRelationSub(lambda x, y: dot_distributeo(y, x))], max_use_ratio=10) +>>> rev_res = rewrite_graph(res, include=[], custom_rewrite=dot_gather_rewrite, clone=False) >>> print(aesara.pprint(rev_res)) (A @ (x + (y + (B @ (z + w))))) @@ -549,32 +550,34 @@ high-level overview of miniKanren's use as a tool for symbolic computation see .. _optdb: -The optimization database (:obj:`optdb`) +The Optimization Database (:obj:`optdb`) ======================================== -Aesara exports a symbol called :obj:`optdb` which acts as a sort of -ordered database of optimizations. When you make a new optimization, -you must insert it at the proper place in the database. Furthermore, -you can give each optimization in the database a set of tags that can -serve as a basis for filtering. - -The point of :obj:`optdb` is that you might want to apply many optimizations -to a computation graph in many unique patterns. For example, you might -want to do optimization X, then optimization Y, then optimization Z. And then -maybe optimization Y is an :class:`EquilibriumOptimizer` containing :class:`LocalOptimizer`\s A, B -and C which are applied on every node of the graph until they all fail to change -it. If some optimizations act up, we want an easy way to turn them off. Ditto if -some optimizations are very CPU-intensive and we don't want to take the time to -apply them. - -The :obj:`optdb` system allows us to tag each optimization with a unique name -as well as informative tags such as 'stable', 'buggy' or -'cpu_intensive', all this without compromising the structure of the -optimizations. - -For instance, the optimization tag ``cxx_only`` is used for optimizations that +Aesara exports a symbol called :obj:`optdb` which acts as a sort of ordered +database of rewrites. When a new rewrite is constructed, it must be inserted at +the proper place in the database in order for it to be deployed during function +compilation. + +Each rewrite in a database can be assigned a set of tags that serve as a basis +for filtering/querying. + +The point of :obj:`optdb` is that one might want to apply many rewrites +to a graph in many unique patterns. + +For example, one might want to perform rewrite X, then rewrite Y, then +rewrite Z. Perhaps rewrite Y is an :class:`EquilibriumGraphRewriter` containing +:class:`NodeRewriter`\s A, B and C, which are applied on every node of until +they all fail to change it. If some rewrites fail, we may want an easy way to +turn them off. Similarly, if some rewrites are very CPU-intensive and we don't +want to take the time to apply them, then we should be able to disable them. + +The :obj:`optdb` system allows us to tag each rewrite with a unique name, +as well as informative descriptions such as 'stable', 'buggy' or +'cpu_intensive'. + +For instance, the rewrite tag ``cxx_only`` is used for rewrites that insert :class:`Op`\s that have no Python implementation (i.e. they only have C -implementations). Optimizations with this tag can be skipped when the C backend +implementations). Rewrites with this tag can be skipped when the C backend is not being used. @@ -582,164 +585,164 @@ Definition of :obj:`optdb` -------------------------- :obj:`optdb` is an object which is an instance of -:class:`SequenceDB `, -itself a subclass of :class:`OptimizationDatabase `. -There exist (for now) two types of :class:`OptimizationDatabase`, :class:`SequenceDB` and :class:`EquilibriumDB`. -When given an appropriate :class:`OptimizationQuery`, :class:`OptimizationDatabase` objects build an :class:`Optimizer` matching +:class:`SequenceDB`, +itself a subclass of :class:`RewriteDatabase`. +There exist (for now) two types of :class:`RewriteDatabase`, :class:`SequenceDB` and :class:`EquilibriumDB`. +When given an appropriate :class:`RewriteDatabaseQuery`, :class:`RewriteDatabase` objects build an :class:`Rewriter` matching the query. -A :class:`SequenceDB` contains :class:`Optimizer` or :class:`OptimizationDatabase` objects. Each of them +A :class:`SequenceDB` contains :class:`Rewriter` or :class:`RewriteDatabase` objects. Each of them has a name, an arbitrary number of tags and an integer representing their order -in the sequence. When a :class:`OptimizationQuery` is applied to a :class:`SequenceDB`, all :class:`Optimizer`\s whose -tags match the query are inserted in proper order in a :class:`SequenceOptimizer`, which -is returned. If the :class:`SequenceDB` contains :class:`OptimizationDatabase` -instances, the :class:`OptimizationQuery` will be passed to them as well and the -optimizers they return will be put in their places. - -An :class:`EquilibriumDB` contains :class:`LocalOptimizer` or :class:`OptimizationDatabase` objects. Each of them -has a name and an arbitrary number of tags. When a :class:`OptimizationQuery` is applied to -an :class:`EquilibriumDB`, all :class:`LocalOptimizer`\s that match the query are -inserted into an :class:`EquilibriumOptimizer`, which is returned. If the -:class:`SequenceDB` contains :class:`OptimizationDatabase` instances, the -:class:`OptimizationQuery` will be passed to them as well and the -:class:`LocalOptimizer`\s they return will be put in their places -(note that as of yet no :class:`OptimizationDatabase` can produce :class:`LocalOptimizer` objects, so this +in the sequence. When a :class:`RewriteDatabaseQuery` is applied to a :class:`SequenceDB`, all :class:`Rewriter`\s whose +tags match the query are inserted in proper order in a :class:`SequenceRewriter`, which +is returned. If the :class:`SequenceDB` contains :class:`RewriteDatabase` +instances, the :class:`RewriteDatabaseQuery` will be passed to them as well and the +rewriters they return will be put in their places. + +An :class:`EquilibriumDB` contains :class:`NodeRewriter` or :class:`RewriteDatabase` objects. Each of them +has a name and an arbitrary number of tags. When a :class:`RewriteDatabaseQuery` is applied to +an :class:`EquilibriumDB`, all :class:`NodeRewriter`\s that match the query are +inserted into an :class:`EquilibriumGraphRewriter`, which is returned. If the +:class:`SequenceDB` contains :class:`RewriteDatabase` instances, the +:class:`RewriteDatabaseQuery` will be passed to them as well and the +:class:`NodeRewriter`\s they return will be put in their places +(note that as of yet no :class:`RewriteDatabase` can produce :class:`NodeRewriter` objects, so this is a moot point). -Aesara contains one principal :class:`OptimizationDatabase` object, :class:`optdb`, which -contains all of Aesara's optimizers with proper tags. It is -recommended to insert new :class:`Optimizer`\s in it. As mentioned previously, -optdb is a :class:`SequenceDB`, so, at the top level, Aesara applies a sequence -of global optimizations to the computation graphs. +Aesara contains one principal :class:`RewriteDatabase` object, :class:`optdb`, which +contains all of Aesara's rewriters with proper tags. It is +recommended to insert new :class:`Rewriter`\s in it. As mentioned previously, +:obj:`optdb` is a :class:`SequenceDB`, so, at the top level, Aesara applies a sequence +of graph rewrites to the graphs it compiles. -:class:`OptimizationQuery` --------------------------- +:class:`RewriteDatabaseQuery` +----------------------------- -A :class:`OptimizationQuery` is built by the following call: +A :class:`RewriteDatabaseQuery` is built by the following call: .. code-block:: python - aesara.graph.optdb.OptimizationQuery(include, require=None, exclude=None, subquery=None) + aesara.graph.rewriting.db.RewriteDatabaseQuery(include, require=None, exclude=None, subquery=None) -.. class:: OptimizationQuery +.. class:: RewriteDatabaseQuery .. attribute:: include A set of tags (a tag being a string) such that every - optimization obtained through this :class:`OptimizationQuery` must have **one** of the tags + rewrite obtained through this :class:`RewriteDatabaseQuery` must have **one** of the tags listed. This field is required and basically acts as a starting point for the search. .. attribute:: require - A set of tags such that every optimization obtained - through this :class:`OptimizationQuery` must have **all** of these tags. + A set of tags such that every rewrite obtained + through this :class:`RewriteDatabaseQuery` must have **all** of these tags. .. attribute:: exclude - A set of tags such that every optimization obtained - through this :class:`OptimizationQuery` must have **none** of these tags. + A set of tags such that every rewrite obtained + through this :class:`RewriteDatabaseQuery` must have **none** of these tags. .. attribute:: subquery :obj:`optdb` can contain sub-databases; subquery is a - dictionary mapping the name of a sub-database to a special :class:`OptimizationQuery`. - If no subquery is given for a sub-database, the original :class:`OptimizationQuery` will be + dictionary mapping the name of a sub-database to a special :class:`RewriteDatabaseQuery`. + If no subquery is given for a sub-database, the original :class:`RewriteDatabaseQuery` will be used again. -Furthermore, a :class:`OptimizationQuery` object includes three methods, :meth:`including`, -:meth:`requiring` and :meth:`excluding`, which each produce a new :class:`OptimizationQuery` object +Furthermore, a :class:`RewriteDatabaseQuery` object includes three methods, :meth:`including`, +:meth:`requiring` and :meth:`excluding`, which each produce a new :class:`RewriteDatabaseQuery` object with the include, require, and exclude sets refined to contain the new entries. Examples -------- -Here are a few examples of how to use a :class:`OptimizationQuery` on :obj:`optdb` to produce an -:class:`Optimizer`: +Here are a few examples of how to use a :class:`RewriteDatabaseQuery` on :obj:`optdb` to produce an +:class:`Rewriter`: .. testcode:: - from aesara.graph.optdb import OptimizationQuery + from aesara.graph.rewriting.db import RewriteDatabaseQuery from aesara.compile import optdb - # This is how the optimizer for the fast_run mode is defined - fast_run = optdb.query(OptimizationQuery(include=['fast_run'])) + # This is how the rewrites for the fast_run mode are defined + fast_run = optdb.query(RewriteDatabaseQuery(include=['fast_run'])) - # This is how the optimizer for the fast_compile mode is defined - fast_compile = optdb.query(OptimizationQuery(include=['fast_compile'])) + # This is how the rewrites for the fast_compile mode are defined + fast_compile = optdb.query(RewriteDatabaseQuery(include=['fast_compile'])) - # This is the same as fast_run but no optimizations will replace + # This is the same as fast_run but no rewrites will replace # any operation by an inplace version. This assumes, of course, # that all inplace operations are tagged as 'inplace' (as they # should!) - fast_run_no_inplace = optdb.query(OptimizationQuery(include=['fast_run'], + fast_run_no_inplace = optdb.query(RewriteDatabaseQuery(include=['fast_run'], exclude=['inplace'])) -Registering an :class:`Optimizer` +Registering a :class:`Rewriter` --------------------------------- -Let's say we have a global optimizer called ``simplify``. We can add +Let's say we have a graph rewriter called ``simplify``. We can add it to :obj:`optdb` as follows: .. testcode:: - # optdb.register(name, optimizer, order, *tags) optdb.register('simplify', simplify, 'fast_run', position=0.5) -Once this is done, the ``FAST_RUN`` mode will automatically include your -optimization (since you gave it the ``'fast_run'`` tag). Of course, -already-compiled functions will see no change. The 'order' parameter -(what it means and how to choose it) will be explained in -:ref:`optdb-structure` below. +Once this is done, the ``FAST_RUN`` mode will automatically include the +rewrite, since it was given the ``'fast_run'`` tag. Of course, +already-compiled functions will see no change. The ``position`` parameter +is specific to the type of rewrite database that :obj:`obtdb` is, and +is explained in :ref:`optdb-structure`. -Registering a :class:`LocalOptimizer` -------------------------------------- +Registering a :class:`NodeRewriter` +----------------------------------- -:class:`LocalOptimizer`\s may be registered in two ways: +:class:`NodeRewriter`\s may be registered in two ways: -* Wrap them in a :class:`NavigatorOptimizer` and insert them like a global optimizer +* Wrap them in a :class:`NodeProcessingGraphRewriter` and insert them like a graph rewriter (see previous section). * Put them in an :class:`EquilibriumDB`. -Aesara defines two :class:`EquilibriumDB`\s in which one can put local -optimizations: +Aesara defines two :class:`EquilibriumDB`\s in which one can put node +rewrites: .. function:: canonicalize - This contains optimizations that aim to *simplify* the graph: + This contains rewrites that aim to put graphs in a standard "canonical" form: * Replace rare or esoterical operations with their equivalents using elementary operations. - * Order operations in a canonical way (any sequence of - multiplications and divisions can be rewritten to contain at most - one division, for example; ``x*x`` can be rewritten ``x**2``; etc.) + * Order operations in a canonical way. + For example, any sequence of multiplications and divisions can be rewritten to contain at most + one division (e.g. ``x * x`` can be rewritten to ``x**2``, etc.) - * Fold constants (``Constant(2)*Constant(2)`` becomes ``Constant(4)``) + * Fold constants (e.g. ``Constant(2) * Constant(2)`` becomes ``Constant(4)``). .. function:: specialize - This contains optimizations that aim to *specialize* the graph: + This contains rewrites that aim to *specialize* the graph: * Replace a combination of operations with a special operation that does the same thing (but better). -For each group, all optimizations of the group that are selected by -the :class:`OptimizationQuery` will be applied on the graph over and over again until none -of them is applicable, so keep that in mind when designing it: check -carefully that your optimization leads to a fixpoint (a point where it -cannot apply anymore) at which point it returns ``False`` to indicate its -job is done. Also be careful not to undo the work of another local -optimizer in the group, because then the graph will oscillate between -two or more states and nothing will get done. +For each group, all rewrites of the group that are selected by +the :class:`RewriteDatabaseQuery` will be applied on the graph over and over +again until no changes are made. + +When using :class:`EquilibriumDB`, be sure to check carefully that your rewrite +leads to a fixed-point (i.e. a graph for which the rewrite cannot be applied +anymore), at which point it returns ``False`` to indicate its job is done. Also +be careful not to undo the work of another rewrites in the group, because the +graph will oscillate between two or more states and nothing will get done. .. _optdb-structure: @@ -747,7 +750,7 @@ two or more states and nothing will get done. :obj:`optdb` structure ---------------------- -:obj:`optdb` contains the following :class:`Optimizer`\s and sub-DBs, with the given +:obj:`optdb` contains the following :class:`Rewriters`\s and sub-DBs, with the given priorities and tags: +-------+---------------------+------------------------------+ @@ -761,13 +764,13 @@ priorities and tags: +-------+---------------------+------------------------------+ | 49 | merge2 | Second merge operation | +-------+---------------------+------------------------------+ -| 49.5 | add_destroy_handler | Enable inplace optimizations | +| 49.5 | add_destroy_handler | Enable inplace rewrites | +-------+---------------------+------------------------------+ | 100 | merge3 | Third merge operation | +-------+---------------------+------------------------------+ The merge operations are meant to put together parts of the graph that -represent the same computation. Since optimizations can modify the +represent the same computation. Since rewrites can modify the graph in such a way that two previously different-looking parts of the graph become similar, we merge at the beginning, in the middle and at the very end. Technically, we only really need to do it at the end, @@ -777,38 +780,40 @@ therefore increases the efficiency of the process. See previous section for more information about the canonicalize and specialize steps. -The ``add_destroy_handler`` step is not really an optimization. It is +The ``add_destroy_handler`` step is not really an rewrite. It is a marker. Basically: .. warning:: - Any optimization which inserts inplace operations in the + Any rewrite which inserts inplace operations in the computation graph must appear **after** the ``add_destroy_handler`` - "optimizer". In other words, the priority of any such optimization + "rewriter". In other words, the priority of any such rewrites must be **>= 50**. Failure to comply by this restriction can lead to the creation of incorrect computation graphs. The reason the destroy handler is not inserted at the beginning is -that it is costly to run. It is cheaper to run most optimizations +that it is costly to run. It is cheaper to run most rewrites under the assumption there are no inplace operations. -.. _navigator: +.. _node_processing_rewriter: + +:class:`NodeProcessingGraphRewriter` +------------------------------------ -:class:`NavigatorOptimizer` ---------------------------- +.. autoclass:: aesara.graph.rewriting.basic.NodeProcessingGraphRewriter + :noindex: -WRITEME -.. _profiling_opt: +.. _profiling_rewrite: -Profiling Aesara function compilation +Profiling Aesara Function Compilation ===================================== -You find that compiling an Aesara function is taking too much time? You -can get profiling information about Aesara optimization. The normal -:ref:`Aesara profiler ` will provide you with very -high-level information. The indentation shows the included in/subset +If one finds that compiling an Aesara function is taking too much time, +profiling information about each Aesara rewrite can be obtained. The normal +:ref:`Aesara profiler ` provides some +high-level performance information. The indentation shows the included in/subset relationship between sections. The top of its output look like this: .. code-block:: none @@ -819,7 +824,7 @@ relationship between sections. The top of its output look like this: Time in 0 calls to Function.__call__: 0.000000e+00s Total compile time: 1.131874e+01s Number of Apply nodes: 50 - Aesara Optimizer time: 1.152431e+00s + Aesara rewriter time: 1.152431e+00s Aesara validate time: 2.790451e-02s Aesara Linker time (includes C, CUDA code generation/compiling): 7.893991e-02s Import time 1.153541e-02s @@ -828,12 +833,12 @@ relationship between sections. The top of its output look like this: Explanations: * ``Total compile time: 1.131874e+01s`` gives the total time spent inside `aesara.function`. -* ``Number of Apply nodes: 50`` means that after optimization, there are 50 apply node in the graph. -* ``Aesara Optimizer time: 1.152431e+00s`` means that we spend 1.15s in the ``aesara.function`` phase where we optimize (modify) the graph to make it faster / more stable numerically /... -* ``Aesara validate time: 2.790451e-02s`` means that we spent 2.8e-2s in the *validate* subset of the optimization phase. -* ``Aesara Linker time (includes C code generation/compiling): 7.893991e-02s`` means that we spent 7.9e-2s in *linker* phase of ``aesara.function``. +* ``Number of Apply nodes: 50`` means that after rewriting, there are 50 apply node in the graph. +* ``Aesara rewrite time: 1.152431e+00s`` means that we spend 1.15s in the rewriting phase of `aesara.function`. +* ``Aesara validate time: 2.790451e-02s`` means that we spent 2.8e-2s in the validation phase of rewriting. +* ``Aesara Linker time (includes C code generation/compiling): 7.893991e-02s`` means that we spent 7.9e-2s in linker phase of `aesara.function`. * ``Import time 1.153541e-02s`` is a subset of the linker time where we import the compiled module. -* ``Time in all call to aesara.grad() 4.732513e-02s`` tells that we spent a total of 4.7e-2s in all calls to ``aesara.grad``. This is outside of the calls to ``aesara.function``. +* ``Time in all call to aesara.grad() 4.732513e-02s`` tells that we spent a total of 4.7e-2s in all calls to `aesara.grad`. This is outside of the calls to `aesara.function`. The *linker* phase includes the generation of the C code, the time spent by g++ to compile and the time needed by Aesara to build the object we @@ -841,11 +846,11 @@ return. The C code generation and compilation is cached, so the first time you compile a function and the following ones could take different amount of execution time. -Detailed profiling of Aesara optimizations ------------------------------------------- +Detailed Profiling of Aesara Rewrites +------------------------------------- You can get more detailed profiling information about the Aesara -optimizer phase by setting to ``True`` the Aesara flags +rewriting phase by setting to ``True`` the Aesara flags :attr:`config.profile_optimizer` (this requires ``config.profile`` to be ``True`` as well). @@ -853,33 +858,33 @@ This will output something like this: .. code-block:: none - Optimizer Profile - ----------------- - SeqOptimizer OPT_FAST_RUN time 1.152s for 123/50 nodes before/after optimization + Rewriter Profile + ---------------- + SequentialGraphRewriter OPT_FAST_RUN time 1.152s for 123/50 nodes before/after rewriting 0.028s for fgraph.validate() 0.131s for callback time - (name, class, index) - validate time - 0.751816s - ('canonicalize', 'EquilibriumOptimizer', 4) - 0.004s - EquilibriumOptimizer canonicalize + 0.751816s - ('canonicalize', 'EquilibriumGraphRewriter', 4) - 0.004s + EquilibriumGraphRewriter canonicalize time 0.751s for 14 passes nb nodes (start, end, max) 108 81 117 time io_toposort 0.029s - time in local optimizers 0.687s - time in global optimizers 0.010s - 0 - 0.050s 27 (0.000s in global opts, 0.002s io_toposort) - 108 nodes - ('local_dimshuffle_lift', 9) ('local_upcast_elemwise_constant_inputs', 5) ('local_shape_to_shape_i', 3) ('local_fill_sink', 3) ('local_fill_to_alloc', 2) ... - 1 - 0.288s 26 (0.002s in global opts, 0.002s io_toposort) - 117 nodes - ('local_dimshuffle_lift', 8) ('local_fill_sink', 4) ('constant_folding', 4) ('local_useless_elemwise', 3) ('local_subtensor_make_vector', 3) ... - 2 - 0.044s 13 (0.002s in global opts, 0.003s io_toposort) - 96 nodes - ('constant_folding', 4) ('local_dimshuffle_lift', 3) ('local_fill_sink', 3) ('local_useless_elemwise', 1) ('local_fill_to_alloc', 1) ... - 3 - 0.045s 11 (0.000s in global opts, 0.002s io_toposort) - 91 nodes - ('constant_folding', 3) ('local_fill_to_alloc', 2) ('local_dimshuffle_lift', 2) ('local_mul_canonizer', 2) ('MergeOptimizer', 1) ... - 4 - 0.035s 8 (0.002s in global opts, 0.002s io_toposort) - 93 nodes - ('local_fill_sink', 3) ('local_dimshuffle_lift', 2) ('local_fill_to_alloc', 1) ('MergeOptimizer', 1) ('constant_folding', 1) - 5 - 0.035s 6 (0.000s in global opts, 0.002s io_toposort) - 88 nodes - ('local_fill_sink', 2) ('local_dimshuffle_lift', 2) ('local_fill_to_alloc', 1) ('local_mul_canonizer', 1) - 6 - 0.038s 10 (0.001s in global opts, 0.002s io_toposort) - 95 nodes - ('local_fill_sink', 3) ('local_dimshuffle_lift', 3) ('constant_folding', 2) ('local_fill_to_alloc', 1) ('MergeOptimizer', 1) - 7 - 0.032s 5 (0.001s in global opts, 0.002s io_toposort) - 91 nodes - ('local_fill_sink', 3) ('MergeOptimizer', 1) ('local_dimshuffle_lift', 1) - 8 - 0.034s 5 (0.000s in global opts, 0.002s io_toposort) - 92 nodes - ('local_fill_sink', 3) ('MergeOptimizer', 1) ('local_greedy_distributor', 1) - 9 - 0.031s 6 (0.001s in global opts, 0.002s io_toposort) - 90 nodes - ('local_fill_sink', 2) ('local_fill_to_alloc', 1) ('MergeOptimizer', 1) ('local_dimshuffle_lift', 1) ('local_greedy_distributor', 1) - 10 - 0.032s 5 (0.000s in global opts, 0.002s io_toposort) - 89 nodes - ('local_dimshuffle_lift', 2) ('local_fill_to_alloc', 1) ('MergeOptimizer', 1) ('local_fill_sink', 1) - 11 - 0.030s 5 (0.000s in global opts, 0.002s io_toposort) - 88 nodes - ('local_dimshuffle_lift', 2) ('local_fill_to_alloc', 1) ('MergeOptimizer', 1) ('constant_folding', 1) - 12 - 0.026s 1 (0.000s in global opts, 0.003s io_toposort) - 81 nodes - ('MergeOptimizer', 1) - 13 - 0.031s 0 (0.000s in global opts, 0.003s io_toposort) - 81 nodes - + time in node rewriters 0.687s + time in graph rewriters 0.010s + 0 - 0.050s 27 (0.000s in global rewrites, 0.002s io_toposort) - 108 nodes - ('local_dimshuffle_lift', 9) ('local_upcast_elemwise_constant_inputs', 5) ('local_shape_to_shape_i', 3) ('local_fill_sink', 3) ('local_fill_to_alloc', 2) ... + 1 - 0.288s 26 (0.002s in global rewrites, 0.002s io_toposort) - 117 nodes - ('local_dimshuffle_lift', 8) ('local_fill_sink', 4) ('constant_folding', 4) ('local_useless_elemwise', 3) ('local_subtensor_make_vector', 3) ... + 2 - 0.044s 13 (0.002s in global rewrites, 0.003s io_toposort) - 96 nodes - ('constant_folding', 4) ('local_dimshuffle_lift', 3) ('local_fill_sink', 3) ('local_useless_elemwise', 1) ('local_fill_to_alloc', 1) ... + 3 - 0.045s 11 (0.000s in global rewrites, 0.002s io_toposort) - 91 nodes - ('constant_folding', 3) ('local_fill_to_alloc', 2) ('local_dimshuffle_lift', 2) ('local_mul_canonizer', 2) ('MergeOptimizer', 1) ... + 4 - 0.035s 8 (0.002s in global rewrites, 0.002s io_toposort) - 93 nodes - ('local_fill_sink', 3) ('local_dimshuffle_lift', 2) ('local_fill_to_alloc', 1) ('MergeOptimizer', 1) ('constant_folding', 1) + 5 - 0.035s 6 (0.000s in global rewrites, 0.002s io_toposort) - 88 nodes - ('local_fill_sink', 2) ('local_dimshuffle_lift', 2) ('local_fill_to_alloc', 1) ('local_mul_canonizer', 1) + 6 - 0.038s 10 (0.001s in global rewrites, 0.002s io_toposort) - 95 nodes - ('local_fill_sink', 3) ('local_dimshuffle_lift', 3) ('constant_folding', 2) ('local_fill_to_alloc', 1) ('MergeOptimizer', 1) + 7 - 0.032s 5 (0.001s in global rewrites, 0.002s io_toposort) - 91 nodes - ('local_fill_sink', 3) ('MergeOptimizer', 1) ('local_dimshuffle_lift', 1) + 8 - 0.034s 5 (0.000s in global rewrites, 0.002s io_toposort) - 92 nodes - ('local_fill_sink', 3) ('MergeOptimizer', 1) ('local_greedy_distributor', 1) + 9 - 0.031s 6 (0.001s in global rewrites, 0.002s io_toposort) - 90 nodes - ('local_fill_sink', 2) ('local_fill_to_alloc', 1) ('MergeOptimizer', 1) ('local_dimshuffle_lift', 1) ('local_greedy_distributor', 1) + 10 - 0.032s 5 (0.000s in global rewrites, 0.002s io_toposort) - 89 nodes - ('local_dimshuffle_lift', 2) ('local_fill_to_alloc', 1) ('MergeOptimizer', 1) ('local_fill_sink', 1) + 11 - 0.030s 5 (0.000s in global rewrites, 0.002s io_toposort) - 88 nodes - ('local_dimshuffle_lift', 2) ('local_fill_to_alloc', 1) ('MergeOptimizer', 1) ('constant_folding', 1) + 12 - 0.026s 1 (0.000s in global rewrites, 0.003s io_toposort) - 81 nodes - ('MergeOptimizer', 1) + 13 - 0.031s 0 (0.000s in global rewrites, 0.003s io_toposort) - 81 nodes - times - times applied - nb node created - name: 0.263s - 15 - 0 - constant_folding 0.096s - 2 - 14 - local_greedy_distributor @@ -896,7 +901,7 @@ This will output something like this: 0.002s - 2 - 4 - local_subtensor_lift 0.001s - 3 - 0 - local_subtensor_make_vector 0.001s - 1 - 1 - local_sum_all_to_none - 0.131s - in 62 optimization that where not used (display only those with a runtime > 0) + 0.131s - in 62 rewrite(s) that where not used (display only those with a runtime > 0) 0.050s - local_add_canonizer 0.018s - local_mul_zero 0.016s - local_one_minus_erf @@ -924,8 +929,8 @@ This will output something like this: 0.000s - local_subtensor_of_alloc 0.000s - local_subtensor_of_dot 0.000s - local_subtensor_merge - 0.101733s - ('elemwise_fusion', 'SeqOptimizer', 13) - 0.000s - SeqOptimizer elemwise_fusion time 0.102s for 78/50 nodes before/after optimization + 0.101733s - ('elemwise_fusion', 'SequentialGraphRewriter', 13) - 0.000s + SequentialGraphRewriter elemwise_fusion time 0.102s for 78/50 nodes before/after rewriting 0.000s for fgraph.validate() 0.004s for callback 0.095307s - ('composite_elemwise_fusion', 'FusionOptimizer', 1) - 0.000s @@ -944,9 +949,9 @@ This will output something like this: validate_time 6.43730163574e-05 callback_time 0.000783205032349 time_toposort 0.0035240650177 - 0.090089s - ('inplace_elemwise_optimizer', 'FromFunctionOptimizer', 30) - 0.019s - 0.048993s - ('BlasOpt', 'SeqOptimizer', 8) - 0.000s - SeqOptimizer BlasOpt time 0.049s for 81/80 nodes before/after optimization + 0.090089s - ('inplace_elemwise_optimizer', 'FromFunctionGraphRewriter', 30) - 0.019s + 0.048993s - ('BlasOpt', 'SequentialGraphRewriter', 8) - 0.000s + SequentialGraphRewriter BlasOpt time 0.049s for 81/80 nodes before/after rewriting 0.000s for fgraph.validate() 0.003s for callback 0.035997s - ('gemm_optimizer', 'GemmOptimizer', 1) - 0.000s @@ -962,54 +967,54 @@ This will output something like this: time_toposort 0.00311398506165 validate_time 4.60147857666e-05 callback_time 0.00174236297607 - 0.004569s - ('local_dot_to_dot22', 'TopoOptimizer', 0) - 0.000s - TopoOptimizer + 0.004569s - ('local_dot_to_dot22', 'WalkingGraphRewriter', 0) - 0.000s + WalkingGraphRewriter nb_node (start, end, changed) (81, 81, 5) init io_toposort 0.00139284133911 loop time 0.00312399864197 callback_time 0.00172805786133 - 0.002283s - ('local_dot22_to_dot22scalar', 'TopoOptimizer', 2) - 0.000s - TopoOptimizer + 0.002283s - ('local_dot22_to_dot22scalar', 'WalkingGraphRewriter', 2) - 0.000s + WalkingGraphRewriter nb_node (start, end, changed) (80, 80, 0) init io_toposort 0.00171804428101 loop time 0.000502109527588 callback_time 0.0 - 0.002257s - ('local_gemm_to_gemv', 'EquilibriumOptimizer', 3) - 0.000s - EquilibriumOptimizer local_gemm_to_gemv + 0.002257s - ('local_gemm_to_gemv', 'EquilibriumGraphRewriter', 3) - 0.000s + EquilibriumGraphRewriter local_gemm_to_gemv time 0.002s for 1 passes nb nodes (start, end, max) 80 80 80 time io_toposort 0.001s - time in local optimizers 0.000s - time in global optimizers 0.000s - 0 - 0.002s 0 (0.000s in global opts, 0.001s io_toposort) - 80 nodes - - 0.002227s - ('use_c_blas', 'TopoOptimizer', 4) - 0.000s - TopoOptimizer + time in node rewriters 0.000s + time in graph rewriters 0.000s + 0 - 0.002s 0 (0.000s in global rewrites, 0.001s io_toposort) - 80 nodes - + 0.002227s - ('use_c_blas', 'WalkingGraphRewriter', 4) - 0.000s + WalkingGraphRewriter nb_node (start, end, changed) (80, 80, 0) init io_toposort 0.0014750957489 loop time 0.00068998336792 callback_time 0.0 - 0.001632s - ('use_scipy_ger', 'TopoOptimizer', 5) - 0.000s - TopoOptimizer + 0.001632s - ('use_scipy_ger', 'WalkingGraphRewriter', 5) - 0.000s + WalkingGraphRewriter nb_node (start, end, changed) (80, 80, 0) init io_toposort 0.00138401985168 loop time 0.000202178955078 callback_time 0.0 - 0.031740s - ('specialize', 'EquilibriumOptimizer', 9) - 0.000s - EquilibriumOptimizer specialize + 0.031740s - ('specialize', 'EquilibriumGraphRewriter', 9) - 0.000s + EquilibriumGraphRewriter specialize time 0.031s for 2 passes nb nodes (start, end, max) 80 78 80 time io_toposort 0.003s - time in local optimizers 0.022s - time in global optimizers 0.004s - 0 - 0.017s 6 (0.002s in global opts, 0.001s io_toposort) - 80 nodes - ('constant_folding', 2) ('local_mul_to_sqr', 1) ('local_elemwise_alloc', 1) ('local_div_to_inv', 1) ('local_mul_specialize', 1) - 1 - 0.014s 0 (0.002s in global opts, 0.001s io_toposort) - 78 nodes - + time in node rewriters 0.022s + time in graph rewriters 0.004s + 0 - 0.017s 6 (0.002s in global rewrites, 0.001s io_toposort) - 80 nodes - ('constant_folding', 2) ('local_mul_to_sqr', 1) ('local_elemwise_alloc', 1) ('local_div_to_inv', 1) ('local_mul_specialize', 1) + 1 - 0.014s 0 (0.002s in global rewrites, 0.001s io_toposort) - 78 nodes - times - times applied - nb node created - name: 0.003s - 1 - 1 - local_mul_specialize 0.002s - 1 - 2 - local_elemwise_alloc 0.002s - 2 - 0 - constant_folding 0.001s - 1 - 1 - local_div_to_inv 0.001s - 1 - 1 - local_mul_to_sqr - 0.016s - in 69 optimization that where not used (display only those with a runtime > 0) + 0.016s - in 69 rewrite(s) that where not used (display only those with a runtime > 0) 0.004s - crossentropy_to_crossentropy_with_softmax_with_bias 0.002s - local_one_minus_erf 0.002s - Elemwise{sub,no_inplace}(z, Elemwise{mul,no_inplace}(alpha subject to at 0x7f475e4da050>, SparseDot(x, y))) -> Usmm{no_inplace}(Elemwise{neg,no_inplace}(alpha), x, y, z) @@ -1039,68 +1044,68 @@ This will output something like this: ... -To understand this profile here is some explanation of how optimizations work: +To understand this profile here is some explanation of how rewrites work: -* Optimizations are organized in an hierarchy. At the top level, there - is a :class:`SeqOptimizer`. It contains other optimizers, - and applies them in the order they were specified. Those sub-optimizers can be - of other types, but are all *global* optimizers. +* Rewrites are organized in a hierarchy. At the top level, there + is a :class:`SequentialGraphRewriter`. It contains other rewriters, + and applies them in the order they were specified. Those sub-rewriters can be + of other types, but are all **graph** rewriters. -* Each :class:`Optimizer` in the hierarchy will print some stats about +* Each :class:`Rewriter` in the hierarchy will print some stats about itself. The information that it prints depends of the type of the - optimizer. + rewriter. -* The :class:`SeqOptimizer` will print some stats at the start: +* The :class:`SequentialGraphRewriter` will print some stats at the start: .. code-block:: none - Optimizer Profile - ----------------- - SeqOptimizer OPT_FAST_RUN time 1.152s for 123/50 nodes before/after optimization + Rewriter Profile + ---------------- + SequentialGraphRewriter OPT_FAST_RUN time 1.152s for 123/50 nodes before/after rewriting 0.028s for fgraph.validate() 0.131s for callback time - (name, class, index) - validate time - Then it will print, with some additional indentation, each sub-optimizer's profile + Then it will print, with some additional indentation, each sub-rewriter's profile information. These sub-profiles are ordered by the time they took to execute, not by their execution order. - * ``OPT_FAST_RUN`` is the name of the optimizer - * 1.152s is the total time spent in that optimizer - * 123/50 means that before this optimization, there were 123 apply node in the function graph, and after only 50. + * ``OPT_FAST_RUN`` is the name of the rewriter + * 1.152s is the total time spent in that rewriter + * 123/50 means that before this rewriter, there were 123 apply node in the function graph, and after only 50. * 0.028s means it spent that time calls to ``fgraph.validate()`` * 0.131s means it spent that time for callbacks. This is a mechanism that can trigger other execution when there is a change to the FunctionGraph. - * ``time - (name, class, index) - validate time`` tells how the information for each sub-optimizer get printed. - * All other instances of :class:`SeqOptimizer` are described like this. In - particular, some sub-optimizer from ``OPT_FAST_RUN`` that are also - :class:`SeqOptimizer`. + * ``time - (name, class, index) - validate time`` tells how the information for each sub-rewriter get printed. + * All other instances of :class:`SequentialGraphRewriter` are described like this. In + particular, some sub-rewriter from ``OPT_FAST_RUN`` that are also + :class:`SequentialGraphRewriter`. -* The :class:`SeqOptimizer` will print some stats at the start: +* The :class:`SequentialGraphRewriter` will print some stats at the start: .. code-block:: none - 0.751816s - ('canonicalize', 'EquilibriumOptimizer', 4) - 0.004s - EquilibriumOptimizer canonicalize + 0.751816s - ('canonicalize', 'EquilibriumGraphRewriter', 4) - 0.004s + EquilibriumGraphRewriter canonicalize time 0.751s for 14 passes nb nodes (start, end, max) 108 81 117 time io_toposort 0.029s - time in local optimizers 0.687s - time in global optimizers 0.010s - 0 - 0.050s 27 (0.000s in global opts, 0.002s io_toposort) - 108 nodes - ('local_dimshuffle_lift', 9) ('local_upcast_elemwise_constant_inputs', 5) ('local_shape_to_shape_i', 3) ('local_fill_sink', 3) ('local_fill_to_alloc', 2) ... - 1 - 0.288s 26 (0.002s in global opts, 0.002s io_toposort) - 117 nodes - ('local_dimshuffle_lift', 8) ('local_fill_sink', 4) ('constant_folding', 4) ('local_useless_elemwise', 3) ('local_subtensor_make_vector', 3) ... - 2 - 0.044s 13 (0.002s in global opts, 0.003s io_toposort) - 96 nodes - ('constant_folding', 4) ('local_dimshuffle_lift', 3) ('local_fill_sink', 3) ('local_useless_elemwise', 1) ('local_fill_to_alloc', 1) ... - 3 - 0.045s 11 (0.000s in global opts, 0.002s io_toposort) - 91 nodes - ('constant_folding', 3) ('local_fill_to_alloc', 2) ('local_dimshuffle_lift', 2) ('local_mul_canonizer', 2) ('MergeOptimizer', 1) ... - 4 - 0.035s 8 (0.002s in global opts, 0.002s io_toposort) - 93 nodes - ('local_fill_sink', 3) ('local_dimshuffle_lift', 2) ('local_fill_to_alloc', 1) ('MergeOptimizer', 1) ('constant_folding', 1) - 5 - 0.035s 6 (0.000s in global opts, 0.002s io_toposort) - 88 nodes - ('local_fill_sink', 2) ('local_dimshuffle_lift', 2) ('local_fill_to_alloc', 1) ('local_mul_canonizer', 1) - 6 - 0.038s 10 (0.001s in global opts, 0.002s io_toposort) - 95 nodes - ('local_fill_sink', 3) ('local_dimshuffle_lift', 3) ('constant_folding', 2) ('local_fill_to_alloc', 1) ('MergeOptimizer', 1) - 7 - 0.032s 5 (0.001s in global opts, 0.002s io_toposort) - 91 nodes - ('local_fill_sink', 3) ('MergeOptimizer', 1) ('local_dimshuffle_lift', 1) - 8 - 0.034s 5 (0.000s in global opts, 0.002s io_toposort) - 92 nodes - ('local_fill_sink', 3) ('MergeOptimizer', 1) ('local_greedy_distributor', 1) - 9 - 0.031s 6 (0.001s in global opts, 0.002s io_toposort) - 90 nodes - ('local_fill_sink', 2) ('local_fill_to_alloc', 1) ('MergeOptimizer', 1) ('local_dimshuffle_lift', 1) ('local_greedy_distributor', 1) - 10 - 0.032s 5 (0.000s in global opts, 0.002s io_toposort) - 89 nodes - ('local_dimshuffle_lift', 2) ('local_fill_to_alloc', 1) ('MergeOptimizer', 1) ('local_fill_sink', 1) - 11 - 0.030s 5 (0.000s in global opts, 0.002s io_toposort) - 88 nodes - ('local_dimshuffle_lift', 2) ('local_fill_to_alloc', 1) ('MergeOptimizer', 1) ('constant_folding', 1) - 12 - 0.026s 1 (0.000s in global opts, 0.003s io_toposort) - 81 nodes - ('MergeOptimizer', 1) - 13 - 0.031s 0 (0.000s in global opts, 0.003s io_toposort) - 81 nodes - + time in node rewriters 0.687s + time in graph rewriters 0.010s + 0 - 0.050s 27 (0.000s in global rewrites, 0.002s io_toposort) - 108 nodes - ('local_dimshuffle_lift', 9) ('local_upcast_elemwise_constant_inputs', 5) ('local_shape_to_shape_i', 3) ('local_fill_sink', 3) ('local_fill_to_alloc', 2) ... + 1 - 0.288s 26 (0.002s in global rewrites, 0.002s io_toposort) - 117 nodes - ('local_dimshuffle_lift', 8) ('local_fill_sink', 4) ('constant_folding', 4) ('local_useless_elemwise', 3) ('local_subtensor_make_vector', 3) ... + 2 - 0.044s 13 (0.002s in global rewrites, 0.003s io_toposort) - 96 nodes - ('constant_folding', 4) ('local_dimshuffle_lift', 3) ('local_fill_sink', 3) ('local_useless_elemwise', 1) ('local_fill_to_alloc', 1) ... + 3 - 0.045s 11 (0.000s in global rewrites, 0.002s io_toposort) - 91 nodes - ('constant_folding', 3) ('local_fill_to_alloc', 2) ('local_dimshuffle_lift', 2) ('local_mul_canonizer', 2) ('MergeOptimizer', 1) ... + 4 - 0.035s 8 (0.002s in global rewrites, 0.002s io_toposort) - 93 nodes - ('local_fill_sink', 3) ('local_dimshuffle_lift', 2) ('local_fill_to_alloc', 1) ('MergeOptimizer', 1) ('constant_folding', 1) + 5 - 0.035s 6 (0.000s in global rewrites, 0.002s io_toposort) - 88 nodes - ('local_fill_sink', 2) ('local_dimshuffle_lift', 2) ('local_fill_to_alloc', 1) ('local_mul_canonizer', 1) + 6 - 0.038s 10 (0.001s in global rewrites, 0.002s io_toposort) - 95 nodes - ('local_fill_sink', 3) ('local_dimshuffle_lift', 3) ('constant_folding', 2) ('local_fill_to_alloc', 1) ('MergeOptimizer', 1) + 7 - 0.032s 5 (0.001s in global rewrites, 0.002s io_toposort) - 91 nodes - ('local_fill_sink', 3) ('MergeOptimizer', 1) ('local_dimshuffle_lift', 1) + 8 - 0.034s 5 (0.000s in global rewrites, 0.002s io_toposort) - 92 nodes - ('local_fill_sink', 3) ('MergeOptimizer', 1) ('local_greedy_distributor', 1) + 9 - 0.031s 6 (0.001s in global rewrites, 0.002s io_toposort) - 90 nodes - ('local_fill_sink', 2) ('local_fill_to_alloc', 1) ('MergeOptimizer', 1) ('local_dimshuffle_lift', 1) ('local_greedy_distributor', 1) + 10 - 0.032s 5 (0.000s in global rewrites, 0.002s io_toposort) - 89 nodes - ('local_dimshuffle_lift', 2) ('local_fill_to_alloc', 1) ('MergeOptimizer', 1) ('local_fill_sink', 1) + 11 - 0.030s 5 (0.000s in global rewrites, 0.002s io_toposort) - 88 nodes - ('local_dimshuffle_lift', 2) ('local_fill_to_alloc', 1) ('MergeOptimizer', 1) ('constant_folding', 1) + 12 - 0.026s 1 (0.000s in global rewrites, 0.003s io_toposort) - 81 nodes - ('MergeOptimizer', 1) + 13 - 0.031s 0 (0.000s in global rewrites, 0.003s io_toposort) - 81 nodes - times - times applied - nb node created - name: 0.263s - 15 - 0 - constant_folding 0.096s - 2 - 14 - local_greedy_distributor @@ -1117,7 +1122,7 @@ To understand this profile here is some explanation of how optimizations work: 0.002s - 2 - 4 - local_subtensor_lift 0.001s - 3 - 0 - local_subtensor_make_vector 0.001s - 1 - 1 - local_sum_all_to_none - 0.131s - in 62 optimization that where not used (display only those with a runtime > 0) + 0.131s - in 62 rewrite(s) that where not used (display only those with a runtime > 0) 0.050s - local_add_canonizer 0.018s - local_mul_zero 0.016s - local_one_minus_erf @@ -1146,22 +1151,22 @@ To understand this profile here is some explanation of how optimizations work: 0.000s - local_subtensor_of_dot 0.000s - local_subtensor_merge - * ``0.751816s - ('canonicalize', 'EquilibriumOptimizer', 4) - 0.004s`` - This line is from :class:`SeqOptimizer`, and indicates information related - to a sub-optimizer. It means that this sub-optimizer took + * ``0.751816s - ('canonicalize', 'EquilibriumGraphRewriter', 4) - 0.004s`` + This line is from :class:`SequentialGraphRewriter`, and indicates information related + to a sub-rewriter. It means that this sub-rewriter took a total of .7s. Its name is ``'canonicalize'``. It is an - :class:`EquilibriumOptimizer`. It was executed at index 4 by the - :class:`SeqOptimizer`. It spent 0.004s in the *validate* phase. - * All other lines are from the profiler of the :class:`EquilibriumOptimizer`. - - * An :class:`EquilibriumOptimizer` does multiple passes on the Apply nodes from - the graph, trying to apply local and global optimizations. - Conceptually, it tries to execute all global optimizations, - and to apply all local optimizations on all - nodes in the graph. If no optimization got applied during a pass, it - stops. So it tries to find an equilibrium state where none of the - optimizations get applied. This is useful when we do not know a fixed order for - the execution of the optimization. + :class:`EquilibriumGraphRewriter`. It was executed at index 4 by the + :class:`SequentialGraphRewriter`. It spent 0.004s in the *validate* phase. + * All other lines are from the profiler of the :class:`EquilibriumGraphRewriter`. + + * An :class:`EquilibriumGraphRewriter` does multiple passes on the Apply nodes from + the graph, trying to apply local and graph rewriters. + Conceptually, it tries to execute all graph rewriters, + and to apply all node rewriters on all + nodes in the graph. If no rewrites got applied during a pass, it + stops. So it tries to find an equilibrium state where no further rewrites + can be applied. This is useful when we do not know a fixed order for the + execution of rewrites. * ``time 0.751s for 14 passes`` means that it took .7s and did 14 passes over the graph. * ``nb nodes (start, end, max) 108 81 117`` means that at the start, @@ -1169,30 +1174,30 @@ To understand this profile here is some explanation of how optimizations work: was 117. * Then it prints some global timing information: it spent 0.029s in - :func:`io_toposort`, all local optimizers took 0.687s together for all - passes, and global optimizers took a total of 0.010s. + :func:`io_toposort`, all node rewriters took 0.687s together for all + passes, and graph rewriters took a total of 0.010s. - * Then we print the timing for each pass, the optimization that + * Then we print the timing for each pass, the rewrite that got applied, and the number of time they got applied. For example, - in pass 0, the :func:`local_dimshuffle_lift` optimizer changed the graph 9 - time. + in pass zero, the :func:`local_dimshuffle_lift` rewrite changed the graph + nine time. - * Then we print the time spent in each optimizer, the number of times + * Then we print the time spent in each rewriter, the number of times they changed the graph and the number of nodes they introduced in the graph. - * Optimizations with that pattern :func:`local_op_lift` means that a node - with that op will be replaced by another node, with the same op, - but will do computation closer to the inputs of the graph. - For instance, ``local_op(f(x))`` getting replaced by ``f(local_op(x))``. - - * Optimization with that pattern :func:`local_op_sink` is the opposite of - "lift". For instance ``f(local_op(x))`` getting replaced by ``local_op(f(x))``. - - * Local optimizers can replace any arbitrary node in the graph, not - only the node it received as input. For this, it must return a - ``dict``. The keys being nodes to replace and the - values being the corresponding replacement. - - This is useful to replace a client of the node received as - parameter. + * Rewrites with that pattern :func:`local_op_lift` indicate that a node + with that `Op` will be replaced by another node with the same `Op`, + but will do computation closer to the inputs of the graph: i.e. a "lift" of + the `Op`. + For instance, in ``local_op(f(x))``, ``local_op`` is lifted through ``f`` to + produce ``f(local_op(x))``. + + * Rewrites with that pattern :func:`local_op_sink` is the opposite of + lifting. For instance, in ``f(local_op(x))``, ``local_op`` is sunk through + ``f`` to produce ``local_op(f(x))``. + + * Local rewriters can replace any arbitrary node in the graph, not + only the nodes they receive as input. In this case, the local rewrite returns a + ``dict``, where the keys are `Variable`\s to be replaced and the + values are the corresponding replacements. diff --git a/doc/extending/graphstructures.rst b/doc/extending/graphstructures.rst index 64fabc9706..41f16504f5 100644 --- a/doc/extending/graphstructures.rst +++ b/doc/extending/graphstructures.rst @@ -205,7 +205,7 @@ structures, code going like ``def f(x): ...`` would produce an :class:`Op` for A :class:`Type` in Aesara provides static information (or constraints) about data objects in a graph. The information provided by :class:`Type`\s allows -Aesara to perform optimizations and produce more efficient compiled code. +Aesara to perform rewrites and produce more efficient compiled code. Every symbolic :class:`Variable` in an Aesara graph has an associated :class:`Type` instance, and :class:`Type`\s also serve as a means of @@ -306,7 +306,7 @@ When used in a computation graph as the input of an will *always* take the value contained in the :class:`Constant`'s data field. Furthermore, it is assumed that the :class:`Op` will not under any circumstances modify the input. This means that a :class:`Constant` is -eligible to participate in numerous optimizations: constant in-lining +eligible to participate in numerous rewrites: constant in-lining in C code, constant folding, etc. Automatic Differentiation @@ -327,26 +327,26 @@ gradient of the graph's output with respect to the graph's inputs. A following section of this tutorial will examine the topic of :ref:`differentiation` in greater detail. -Optimizations -============= +Rewrites +======== When compiling an Aesara graph using :func:`aesara.function`, a graph is necessarily provided. While this graph structure shows how to compute the output from the input, it also offers the possibility to improve the way this -computation is carried out. The way optimizations work in Aesara is by +computation is carried out. The way rewrites work in Aesara is by identifying and replacing certain patterns in the graph with other specialized patterns that produce the same results but are either faster or more -stable. Optimizations can also detect identical subgraphs and ensure that the +stable. Rewrites can also detect identical subgraphs and ensure that the same values are not computed twice. -For example, one (simple) optimization that Aesara uses is to replace +For example, one simple rewrite that Aesara uses is to replace the pattern :math:`\frac{xy}{y}` by :math:`x`. See :ref:`graph_rewriting` and :ref:`optimizations` for more information. **Example** -Consider the following example of optimization: +Consider the following example of rewrites: >>> import aesara >>> a = aesara.tensor.vector("a") # declare symbolic variable @@ -354,13 +354,13 @@ Consider the following example of optimization: >>> f = aesara.function([a], b) # compile function >>> print(f([0, 1, 2])) # prints `array([0,2,1026])` [ 0. 2. 1026.] ->>> aesara.printing.pydotprint(b, outfile="./pics/symbolic_graph_unopt.png", var_with_name_simple=True) # doctest: +SKIP -The output file is available at ./pics/symbolic_graph_unopt.png ->>> aesara.printing.pydotprint(f, outfile="./pics/symbolic_graph_opt.png", var_with_name_simple=True) # doctest: +SKIP -The output file is available at ./pics/symbolic_graph_opt.png +>>> aesara.printing.pydotprint(b, outfile="./pics/symbolic_graph_no_rewrite.png", var_with_name_simple=True) # doctest: +SKIP +The output file is available at ./pics/symbolic_graph_no_rewrite.png +>>> aesara.printing.pydotprint(f, outfile="./pics/symbolic_graph_rewite.png", var_with_name_simple=True) # doctest: +SKIP +The output file is available at ./pics/symbolic_graph_rewrite.png -We used :func:`aesara.printing.pydotprint` to visualize the optimized graph -(right), which is much more compact than the unoptimized graph (left). +We used :func:`aesara.printing.pydotprint` to visualize the rewritten graph +(right), which is much more compact than the un-rewritten graph (left). .. |g1| image:: ./pics/symbolic_graph_unopt.png :width: 500 px @@ -368,7 +368,7 @@ We used :func:`aesara.printing.pydotprint` to visualize the optimized graph :width: 500 px ================================ ====================== ================================ - Unoptimized graph Optimized graph + Un-rewritten graph Rewritten graph ================================ ====================== ================================ |g1| |g2| ================================ ====================== ================================ diff --git a/doc/extending/index.rst b/doc/extending/index.rst index 55c2fb5864..321c1ad8dd 100644 --- a/doc/extending/index.rst +++ b/doc/extending/index.rst @@ -6,14 +6,14 @@ Extending Aesara ================ This advanced tutorial is for users who want to extend Aesara with new :class:`Type`\s, -new Operations (:Class:`Op`\S), and new graph optimizations. This first page of the -tutorial mainly focuses on the Python implementation of an :Class:`Op` and then +new operations (i.e. :class:`Op`\s), and new graph rewrites. This first page of the +tutorial mainly focuses on the Python implementation of an :class:`Op` and then proposes an overview of the most important methods that define an :class:`Op`. The second page of the tutorial (:ref:`creating_a_c_op`) provides then -information on the C implementation of an :Class:`Op`. The rest of the tutorial -goes more in depth on advanced topics related to :Class:`Op`\s, such as how to write -efficient code for an :Class:`Op` and how to write an optimization to speed up the -execution of an :Class:`Op`. +information on the C implementation of an :class:`Op`. The rest of the tutorial +goes more in depth on advanced topics related to :class:`Op`\s, such as how to write +efficient code for an :class:`Op` and how to write an rewrite to speed up the +execution of an :class:`Op`. Along the way, this tutorial also introduces many aspects of how Aesara works, so it is also good for you if you are interested in getting more under the hood @@ -23,11 +23,11 @@ with Aesara itself. Before tackling this more advanced presentation, it is highly recommended to read the introductory :ref:`Tutorial`, especially the sections - that introduce the Aesara Graphs, as providing a novel Aesara :class:`Op` requires a - basic understanting of the Aesara Graphs. + that introduce the Aesara graphs, as providing a novel Aesara :class:`Op` requires a + basic understanting of the Aesara graphs. See also the :ref:`dev_start_guide` for information regarding the - versioning framework, namely about *git* and *GitHub*, regarding the + versioning framework, namely about Git and GitHub, regarding the development workflow and how to make a quality contribution. .. toctree:: diff --git a/doc/extending/inplace.rst b/doc/extending/inplace.rst index 04fd9a7551..f3a73037fc 100644 --- a/doc/extending/inplace.rst +++ b/doc/extending/inplace.rst @@ -5,11 +5,11 @@ Views and inplace operations ============================ -Aesara allows the definition of ``Op``\s which return a :term:`view` on one +Aesara allows the definition of :class:`Op`\s which return a :term:`view` on one of their inputs or operate :term:`inplace` on one or several -inputs. This allows more efficient operations on NumPy's ``ndarray`` +inputs. This allows more efficient operations on NumPy's :class:`ndarray` data type than would be possible otherwise. -However, in order to work correctly, these ``Op``\s need to +However, in order to work correctly, these :class:`Op`\s need to implement an additional interface. Aesara recognizes views and inplace operations specially. It ensures @@ -23,7 +23,7 @@ Views A "view" on an object ``x`` is an object ``y`` which shares memory with ``x`` in some way. In other words, changing ``x`` might also -change ``y`` and vice versa. For example, imagine a ``vector`` structure +change ``y`` and vice versa. For example, imagine a `vector` structure which contains two fields: an integer length and a pointer to a memory buffer. Suppose we have: @@ -44,9 +44,9 @@ range ``0xDEADBEFF - 0xDEADBFDF`` and z the range ``0xCAFEBABE - 0xCAFEBBBE``. Since the ranges for ``x`` and ``y`` overlap, ``y`` is considered to be a view of ``x`` and vice versa. -Suppose you had an ``Op`` which took ``x`` as input and returned +Suppose you had an :class:`Op` which took ``x`` as input and returned ``y``. You would need to tell Aesara that ``y`` is a view of ``x``. For this -purpose, you would set the ``view_map`` field as follows: +purpose, you would set the :class:`Op.view_map` field as follows: .. testsetup:: @@ -88,15 +88,15 @@ Inplace operations An inplace operation is one that modifies one or more of its inputs. For example, the expression ``x += y`` where ``x`` and ``y`` -are ``numpy.ndarray`` instances would normally represent an inplace +are :class:`numpy.ndarray` instances would normally represent an inplace operation on ``x``. .. note:: Inplace operations in Aesara still work in a functional setting: they need to return the modified input. Symbolically, Aesara - requires one Variable standing for the input *before* being modified - and *another* Variable representing the input *after* being + requires one :class:`Variable` standing for the input before being modified + and another :class:`Variable` representing the input after being modified. Therefore, code using inplace operations would look like this: @@ -121,29 +121,29 @@ operation on ``x``. Needless to say, this goes for user-defined inplace operations as well; the modified input must figure in the list of outputs you - give to ``Apply`` in the definition of ``make_node``. + give to :class:`Apply` in the definition of :meth:`Apply.make_node`. Also, for technical reasons but also because they are slightly confusing to use as evidenced by the previous code, Aesara does not allow the end user to use inplace operations by default. However, - it does allow *optimizations* to substitute them in in a later + it does allow rewrites to substitute them in in a later phase. Therefore, typically, if you define an inplace operation, - you will define a pure equivalent and an optimization which + you will define a pure equivalent and a rewrite which substitutes one for the other. Aesara will automatically verify if it is possible to do so and will refuse the substitution if it introduces inconsistencies. -Take the previous definitions of ``x``, ``y`` and ``z`` and suppose an ``Op`` which +Take the previous definitions of ``x``, ``y`` and ``z`` and suppose an :class:`Op` which adds one to every byte of its input. If we give ``x`` as an input to -that ``Op``, it can either allocate a new buffer of the same size as ``x`` +that :class:`Op`, it can either allocate a new buffer of the same size as ``x`` (that could be ``z``) and set that new buffer's bytes to the variable of -the addition. That would be a normal, :term:`pure` ``Op``. Alternatively, -it could add one to each byte *in* the buffer ``x``, therefore -changing it. That would be an inplace ``Op``. +the addition. That would be a normal, :term:`pure`\ :class:`Op`. Alternatively, +it could add one to each byte in the buffer ``x``, therefore +changing it. That would be an inplace :class:`Op`. Aesara needs to be notified of this fact. The syntax is similar to -that of ``view_map``: +that of :attr:`Op.view_map`: .. testcode:: @@ -171,27 +171,27 @@ first input (position 0). # unlike for views, the previous line is legal and supported .. note:: - ``DestroyHandler`` provides a hackish means of specifying that a variable cannot be + :class:`DestroyHandler` provides a hackish means of specifying that a variable cannot be "destroyed" by an in-place operation: ``var.tag.indestructible = True``. Destructive Operations ====================== While some operations will operate inplace on their inputs, some might -simply destroy or corrupt them. For example, an ``Op`` could do temporary +simply destroy or corrupt them. For example, an :class:`Op` could do temporary calculations right in its inputs. If that is the case, Aesara also needs to be notified. The way to notify Aesara is to assume that some output operated inplace on whatever inputs are changed or corrupted by -the ``Op`` (even if the output does not technically reuse any of the +the :class:`Op` (even if the output does not technically reuse any of the input(s)'s memory). From there, go to the previous section. .. warning:: Failure to correctly mark down views and inplace operations using - ``view_map`` and ``destroy_map`` can lead to nasty bugs. In the + :attr:`Op.view_map` and :attr:`Op.destroy_map` can lead to nasty bugs. In the absence of this information, Aesara might assume that it is safe to - execute an inplace operation on some inputs *before* doing other - calculations on the *previous* values of the inputs. For example, + execute an inplace operation on some inputs before doing other + calculations on the previous values of the inputs. For example, in the code: ``y = log(x); x2 = add_inplace(x, z)`` it is imperative to do the logarithm before the addition (because after the addition, the original x that we wanted to take the logarithm @@ -199,25 +199,28 @@ input(s)'s memory). From there, go to the previous section. the value of ``x`` it might invert the order and that will certainly lead to erroneous computations. - You can often identify an incorrect ``view_map`` or ``destroy_map`` - by using :ref:`DebugMode`. *Be sure to use ``DebugMode`` when developing - a new ``Op`` that uses ``view_map`` and/or ``destroy_map``.* + You can often identify an incorrect `Op.view_map` or :attr:`Op.destroy_map` + by using :ref:`DebugMode`. -Inplace optimization and DebugMode -================================== +.. note:: + Consider using :class:`DebugMode` when developing + a new :class:`Op` that uses :attr:`Op.view_map` and/or :attr:`Op.destroy_map`. + +Inplace Rewriting and `DebugMode` +================================= -It is recommended that during the graph construction, all ``Op``\s are not inplace. -Then an optimization replaces them with inplace ones. Currently ``DebugMode`` checks -all optimizations that were tried even if they got rejected. One reason an inplace -optimization can get rejected is when there is another ``Op`` that is already being applied -inplace on the same input. Another reason to reject an inplace optimization is +It is recommended that during the graph construction, all :class:`Op`\s are not inplace. +Then a rewrite replaces them with inplace ones. Currently :class:`DebugMode` checks +all rewrites that were tried even if they got rejected. One reason an inplace +rewrite can get rejected is when there is another :class:`Op` that is already being applied +inplace on the same input. Another reason to reject an inplace rewrite is if it would introduce a cycle into the graph. -The problem with ``DebugMode`` is that it will trigger a useless error when -checking a rejected inplace optimization, since it will lead to wrong results. -In order to be able to use ``DebugMode`` in more situations, your inplace -optimization can pre-check whether it will get rejected by using the -``aesara.graph.destroyhandler.fast_inplace_check()`` function, that will tell -which ``Op``\s can be performed inplace. You may then skip the optimization if it is -incompatible with this check. Note however that this check does not cover all -cases where an optimization may be rejected (it will not detect cycles). +The problem with `DebugMode` is that it will trigger a useless error when +checking a rejected inplace rewrite, since it will lead to wrong results. +In order to be able to use `DebugMode` in more situations, your inplace +rewrite can pre-check whether it will get rejected by using the +:func:`aesara.graph.destroyhandler.fast_inplace_check` function, that will tell +which :class:`Op`\s can be performed inplace. You may then skip the rewrite if it is +incompatible with this check. Note, however, that this check does not cover all +cases where a rewrite may be rejected (it will not detect cycles). diff --git a/doc/extending/op.rst b/doc/extending/op.rst index fb242d28ad..5f860d9849 100644 --- a/doc/extending/op.rst +++ b/doc/extending/op.rst @@ -77,12 +77,12 @@ It has to define the following methods. ``other`` is also an :class:`Op`. - Returning ``True`` here is a promise to the optimization system + Returning ``True`` here is a promise to the rewrite system that the other :class:`Op` will produce exactly the same graph effects - (from perform) as this one, given identical inputs. This means it + (e.g. from its :meth:`Op.perform`) as this one, given identical inputs. This means it will produce the same output values, it will destroy the same - inputs (same ``destroy_map``), and will alias outputs to the same - inputs (same ``view_map``). For more details, see + inputs (same :attr:`Op.destroy_map`), and will alias outputs to the same + inputs (same :attr:`Op.view_map`). For more details, see :ref:`views_and_inplace`. .. note:: @@ -99,9 +99,9 @@ It has to define the following methods. lifetime of self. :class:`Op` instances should be immutable in this sense. - .. note:: +.. note:: - If you set `__props__`, this will be automatically generated. + If you set :attr:`Op.__props__`, this will be automatically generated. .. op_optional: @@ -110,7 +110,7 @@ Optional methods or attributes .. attribute:: __props__ - *Default:* Undefined + Default: Undefined Must be a tuple. Lists the name of the attributes which influence the computation performed. This will also enable the automatic @@ -122,7 +122,7 @@ Optional methods or attributes .. attribute:: default_output - *Default:* None + Default: None If this member variable is an integer, then the default implementation of ``__call__`` will return @@ -177,7 +177,7 @@ Optional methods or attributes .. function:: infer_shape(fgraph, node, shapes) - This function is needed for shape optimization. ``shapes`` is a + This function is needed for shape rewrites. ``shapes`` is a list with one tuple for each input of the :class:`Apply` node (which corresponds to the inputs of the :class:`Op`). Each tuple contains as many elements as the number of dimensions of the corresponding input. The value of each element @@ -216,9 +216,9 @@ Optional methods or attributes .. function:: do_constant_folding(fgraph, node) - *Default:* Return True + Default: Return ``True`` - By default when optimizations are enabled, we remove during + By default when rewrites are enabled, we remove during function compilation :class:`Apply` nodes whose inputs are all constants. We replace the :class:`Apply` node with an Aesara constant variable. This way, the :class:`Apply` node is not executed at each function diff --git a/doc/extending/pipeline.rst b/doc/extending/pipeline.rst index 9e37e1a513..fb98e8752d 100644 --- a/doc/extending/pipeline.rst +++ b/doc/extending/pipeline.rst @@ -35,21 +35,20 @@ Some relevant :ref:`Features ` are typically added t rewrites from operating in-place on inputs declared as immutable. -Step 2 - Perform graph optimizations -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Step 2 - Perform graph rewrites +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Once the :class:`FunctionGraph` is constructed, an :term:`optimizer` is produced by -the :term:`mode` passed to :func:`function` (the :class:`Mode` basically has two -important fields, :attr:`linker` and :attr:`optimizer`). That optimizer is -applied on the :class:`FunctionGraph` using its :meth:`Optimizer.optimize` method. +Once the :class:`FunctionGraph` is constructed, a :term:`rewriter` is produced by +the :term:`mode` passed to :func:`function`. That rewrite is +applied to the :class:`FunctionGraph` using its :meth:`GraphRewriter.rewrite` method. -The optimizer is typically obtained through :attr:`optdb`. +The rewriter is typically obtained through a query on :attr:`optdb`. Step 3 - Execute linker to obtain a thunk ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Once the computation graph is optimized, the :term:`linker` is +Once the computation graph is rewritten, the :term:`linker` is extracted from the :class:`Mode`. It is then called with the :class:`FunctionGraph` as argument to produce a ``thunk``, which is a function with no arguments that returns nothing. Along with the thunk, one list of input containers (a @@ -61,9 +60,9 @@ the inputs must be placed in the input containers, the thunk must be called, and the outputs must be retrieved from the output containers where the thunk put them. -Typically, the linker calls the ``toposort`` method in order to obtain +Typically, the linker calls the :meth:`FunctionGraph.toposort` method in order to obtain a linear sequence of operations to perform. How they are linked -together depends on the Linker used. The :class:`CLinker` produces a single +together depends on the :class:`Linker` class used. For example, the :class:`CLinker` produces a single block of C code for the whole computation, whereas the :class:`OpWiseCLinker` produces one thunk for each individual operation and calls them in sequence. diff --git a/doc/extending/scan.rst b/doc/extending/scan.rst index ce5b89406c..69721d7118 100644 --- a/doc/extending/scan.rst +++ b/doc/extending/scan.rst @@ -36,7 +36,7 @@ The following sections assumes the reader is familiar with the following : 2. The interface and usage of Aesara's :ref:`scan ` function -Additionally, the :ref:`scan_internals_optimizations` section below assumes +Additionally, the :ref:`scan_internals_rewrites` section below assumes knowledge of: 3. Aesara's :ref:`graph rewriting ` @@ -63,7 +63,7 @@ deal with, are : * ``views.py`` contains different views of the `Scan` `Op` that have simpler and easier signatures to be used in specific cases. -* ``opt.py`` contains the list of all Aesara graph optimizations for the +* ``opt.py`` contains the list of all Aesara graph rewrites for the `Scan` operator. @@ -155,15 +155,15 @@ Multiply-recurrent multiple outputs (MITMOT) Initial values for =========================================================== ======================================================= ============================================================ ============================================================= ========================================================= ====================================================== -.. _scan_internals_optimizations: +.. _scan_internals_rewrites: -Optimizations -============= +Rewrites +======== `remove_constants_and_unused_inputs_scan` ----------------------------------------- -This optimization serves two purposes, The first is to remove a `Scan` `Op`'s +This rewrite serves two purposes, The first is to remove a :class:`Scan`\ `Op`'s unused inputs. The second is to take a `Scan` `Op`'s constant inputs and remove them, instead injecting the constants directly into the graph or the `Scan` `Op`'s inner function. This will allow constant folding to happen inside the @@ -173,31 +173,31 @@ inner function. `PushOutNonSeqScan` ------------------- -This optimizations pushes, out of `Scan`'s inner function and into the outer -function, computation that depends only on non-sequence inputs. Such -computation ends up being done every iteration on the same values so moving -it to the outer function to be executed only once, before the `Scan` `Op`, -reduces the amount of computation that needs to be performed. +This rewrite pushes sub-graphs that depends only on non-sequence inputs out of +`Scan`'s inner function and into the outer function. Such computation ends up +being done every iteration on the same values so moving it to the outer function +to be executed only once, before the `Scan`\ `Op`, reduces the amount of +computation that needs to be performed. `PushOutSeqScan` ---------------- -This optimization resembles `PushOutNonSeqScan` but it tries to push, out of +This rewrite resembles `PushOutNonSeqScan` but it tries to push, out of the inner function, the computation that only relies on sequence and -non-sequence inputs. The idea behind this optimization is that, when it is +non-sequence inputs. The idea behind this rewrite is that, when it is possible to do so, it is generally more computationally efficient to perform a single operation on a large tensor rather then perform that same operation -many times on many smaller tensors. In many cases, this optimization can +many times on many smaller tensors. In many cases, this rewrite can increase memory usage but, in some specific cases, it can also decrease it. `PushOutScanOutput` ------------------- -This optimizations attempts to push out some of the computation at the end +This rewrite attempts to push out some of the computation at the end of the inner function to the outer function, to be executed after the `Scan` -node. Like `PushOutSeqScan`, this optimization aims to replace many operations +node. Like `PushOutSeqScan`, this rewrite aims to replace many operations on small tensors by few operations on large tensors. It can also lead to increased memory usage. @@ -205,23 +205,23 @@ increased memory usage. `PushOutDot1` ------------- -This is another optimization that attempts to detect certain patterns of -computation in a `Scan` `Op`'s inner function and move this computation to the +This is another rewrite that attempts to detect certain patterns of +computation in a `Scan`\ `Op`'s inner function and move this computation to the outer graph. `ScanInplaceOptimizer` ---------------------- -This optimization attempts to make `Scan` compute its recurrent outputs inplace -on the input tensors that contain their initial states. This optimization can +This rewrite attempts to make `Scan` compute its recurrent outputs inplace +on the input tensors that contain their initial states. This rewrite can improve runtime performance as well as reduce memory usage. `ScanSaveMem` ------------- -This optimizations attempts to determine if a `Scan` node, during its execution, +This rewrite attempts to determine if a `Scan` node, during its execution, for any of its outputs, can get away with allocating a memory buffer that is large enough to contain some of the computed timesteps of that output but not all of them. @@ -233,7 +233,7 @@ need to store the most recent ``N`` values, not all of them. For instance, if a `Scan` node has a SITSOT output (last computed value is fed back as an input at the next iteration) and only the last timestep of -that output is ever used in the outer function, the `ScanSaveMem` optimization +that output is ever used in the outer function, the `ScanSaveMem` rewrite could determine that there is no need to store all computed timesteps for that SITSOT output. Only the most recently computed timestep ever needs to be kept in memory. @@ -242,11 +242,11 @@ be kept in memory. `ScanMerge` ----------- -This optimization attempts to fuse distinct `Scan` `Op`s into a single `Scan` `Op` -that performs all the computation. The main advantage of merging `Scan` `Op`\s -together comes from the possibility of both original `Op`\s having some +This rewrite attempts to fuse distinct `Scan` nodes into a single `Scan` node +that performs all the computation. The main advantage of merging `Scan` nodes +together comes from the possibility of both original `Scan`\ `Op`\s having some computation in common. In such a setting, this computation ends up being done -twice. The fused `Scan` `Op`, however, would only need to do it once and could +twice. The fused `Scan`\s, however, would only need to do it once and could therefore be more computationally efficient. Also, since every `Scan` node involves a certain overhead, at runtime, reducing the number of `Scan` nodes in the graph can improve performance. @@ -255,7 +255,7 @@ the graph can improve performance. `scan_merge_inouts` ------------------- -This optimization attempts to merge a `Scan` `Op`'s identical outer inputs as well +This rewrite attempts to merge a `Scan`\s identical outer inputs as well as merge its identical outer outputs (outputs that perform the same computation on the same inputs). This can reduce the amount of computation as well as result in a simpler graph for both the inner function and the outer @@ -267,7 +267,7 @@ Helper classes and functions Because of the complexity involved in dealing with `Scan`, a large number of helper classes and functions have been developed over time to implement -operations commonly needed when dealing with the `Scan` `Op`. The `Scan` `Op` +operations commonly needed when dealing with the `Scan`\ `Op`. The `Scan`\ `Op` itself defines a large number of them and others can be found in the file ``utils.py``. This sections aims to point out the most useful ones sorted by usage. diff --git a/doc/extending/tips.rst b/doc/extending/tips.rst index 8043364d7b..95ce9494c1 100644 --- a/doc/extending/tips.rst +++ b/doc/extending/tips.rst @@ -25,7 +25,7 @@ simple function: def sum_square_difference(a, b): return at.sum((a - b)**2) -Even without taking Aesara's optimizations into account, it is likely +Even without taking Aesara's rewrites into account, it is likely to work just as well as a custom implementation. It also supports all data types, tensors of all dimensions as well as broadcasting, whereas a custom implementation would probably only bother to support diff --git a/doc/extending/type.rst b/doc/extending/type.rst index 74d407052b..39accfc687 100644 --- a/doc/extending/type.rst +++ b/doc/extending/type.rst @@ -5,7 +5,7 @@ =============== The :class:`Type` class is used to provide "static" information about the types of -:class:`Variable`\s in an Aesara graph. This information is used for graph optimizations +:class:`Variable`\s in an Aesara graph. This information is used for graph rewrites and compilation to languages with typing that's stricter than Python's. The types handled by Aesara naturally overlap a lot with NumPy, but @@ -311,7 +311,7 @@ default values. Optional. Only needed to profile the memory of this :class:`Type` of object. - :param shape_info: the output of the call to get_shape_info() + :param shape_info: the output of the call to `get_shape_info` :return: the number of bytes taken by the object described by ``shape_info``. @@ -324,8 +324,8 @@ For certain mechanisms, you can register functions and other such things to plus your type into aesara's mechanisms. These are optional but will allow people to use you type with familiar interfaces. -`transfer()` -~~~~~~~~~~~~ +`transfer` +~~~~~~~~~~ To plug in additional options for the transfer target, define a function which takes an Aesara variable and a target argument and @@ -388,7 +388,7 @@ when ``allow_downcast`` is False, i.e. no precision loss is allowed. The second method we define is ``values_eq_approx``. This method allows approximate comparison between two values respecting our :class:`Type`'s -constraints. It might happen that an optimization changes the computation +constraints. It might happen that a rewrite changes the computation graph in such a way that it produces slightly different variables, for example because of numerical instability like rounding errors at the end of the mantissa. For instance, ``a + a + a + a + a + a`` might not diff --git a/doc/extending/unittest.rst b/doc/extending/unittest.rst index fba7bf0869..8cae140f5f 100644 --- a/doc/extending/unittest.rst +++ b/doc/extending/unittest.rst @@ -13,7 +13,7 @@ stressed enough! Unit Testing revolves around the following principles: * ensuring correctness: making sure that your :class:`Op`, :class:`Type` or - optimization works in the way you intended it to work. It is important for + rewrites works in the way you intended it to work. It is important for this testing to be as thorough as possible: test not only the obvious cases, but more importantly the corner cases which are more likely to trigger bugs down the line. diff --git a/doc/faq.rst b/doc/faq.rst index 54ae02acee..a274c0bb60 100644 --- a/doc/faq.rst +++ b/doc/faq.rst @@ -46,28 +46,28 @@ Faster Aesara Function Compilation Aesara function compilation can be time consuming. It can be sped up by setting the flag ``mode=FAST_COMPILE`` which instructs Aesara to skip most -optimizations and disables the generation of any c/cuda code. This is useful +rewrites and disables the generation of any c/cuda code. This is useful for quickly testing a simple idea. If C code is necessary, the flag ``optimizer=fast_compile`` can be used instead. It instructs Aesara to -skip time consuming optimizations but still generate C code. +skip time consuming rewrites but still generate C code. Similarly using the flag ``optimizer_excluding=inplace`` will speed up -compilation by preventing optimizations that replace operations with a +compilation by preventing rewrites that replace operations with a version that reuses memory where it will not negatively impact the -integrity of the operation. Such optimizations can be time +integrity of the operation. Such rewrites can be time consuming. However using this flag will result in greater memory usage because space must be allocated for the results which would be unnecessary otherwise. In short, using this flag will speed up compilation but it will also use more memory because -``optimizer_excluding=inplace`` excludes inplace optimizations +``optimizer_excluding=inplace`` excludes inplace rewrites resulting in a trade off between speed of compilation and memory usage. Alternatively, if the graph is big, using the flag ``cycle_detection=fast`` will speedup the computations by removing some of the inplace -optimizations. This would allow aesara to skip a time consuming cycle +rewrites. This would allow aesara to skip a time consuming cycle detection algorithm. If the graph is big enough,we suggest that you use this flag instead of ``optimizer_excluding=inplace``. It will result in a computation time that is in between fast compile and fast run. @@ -82,23 +82,23 @@ garbage collection will keep all intermediate results' memory space to allow to reuse them during the next call to the same Aesara function, if they are of the correct shape. The shape could change if the shapes of the inputs change. -.. _unsafe_optimization: +.. _unsafe_rewrites: -Unsafe optimization -=================== +Unsafe Rewrites +=============== -Some Aesara optimizations make the assumption that the user inputs are +Some Aesara rewrites make the assumption that the user inputs are valid. What this means is that if the user provides invalid values (like incompatible shapes or indexing values that are out of bounds) and -the optimizations are applied, the user error will get lost. Most of the +the rewrites are applied, the user error will get lost. Most of the time, the assumption is that the user inputs are valid. So it is good -to have the optimization being applied, but losing the error is bad. -The newest optimization in Aesara with such assumption will add an +to have the rewrite applied, but losing the error is bad. +The newest rewrite in Aesara with such an assumption will add an assertion in the graph to keep the user error message. Computing these assertions could take some time. If you are sure everything is valid -in your graph and want the fastest possible Aesara, you can enable an -optimization that will remove those assertions with: +in your graph and want the fastest possible Aesara, you can enable a +rewrite that will remove the assertions with: ``optimizer_including=local_remove_all_assert`` diff --git a/doc/glossary.rst b/doc/glossary.rst index 757c96b0bf..b10b8c8626 100644 --- a/doc/glossary.rst +++ b/doc/glossary.rst @@ -68,13 +68,13 @@ Glossary :term:`Type`, or read more about :ref:`graphstructures`. Destructive - An :term:`Op` is destructive (of particular input[s]) if its + An :term:`Op` is destructive--of particular input(s)--if its computation requires that one or more inputs be overwritten or otherwise invalidated. For example, :term:`inplace`\ :class:`Op`\s are destructive. Destructive :class:`Op`\s can sometimes be faster than non-destructive alternatives. Aesara encourages users not to put destructive :class:`Op`\s into graphs that are given to :term:`aesara.function`, - but instead to trust the optimizations to insert destructive ops + but instead to trust the rewrites to insert destructive :class:`Op`\s judiciously. Destructive :class:`Op`\s are indicated via a :attr:`Op.destroy_map` attribute. (See @@ -90,14 +90,16 @@ Glossary every element, this is an inplace operation because when you are done, the original input has been overwritten. :class:`Op`\s representing inplace computations are :term:`destructive`, and by default these can only be - inserted by optimizations, not user code. + inserted by rewrites, not user code. Linker - Part of a function :term:`Mode` -- an object responsible for 'running' - the compiled function. Among other things, the linker determines whether computations are carried out with C or Python code. + A :class:`Linker` instance responsible for "running" the compiled + function. Among other things, the linker determines whether + computations are carried out with + C or Python code. Mode - An object providing an :term:`optimizer` and a :term:`linker` that is + A :class:`Mode` instance specifying an :term:`optimizer` and a :term:`linker` that is passed to :term:`aesara.function`. It parametrizes how an expression graph is converted to a callable object. @@ -120,12 +122,6 @@ Glossary An instance of a :term:`rewriter` that has the capacity to provide an improvement to the performance of a graph. - Optimization - A :term:`graph` transformation applied by an :term:`optimizer` during - the compilation of a :term:`graph` by :term:`aesara.function`. These - are graph rewrites that are intended to improve the performance of - a compiled :term:`Graph`. - Pure An :term:`Op` is *pure* if it has no :term:`destructive` side-effects. diff --git a/doc/index.rst b/doc/index.rst index 7ccc37522a..e487df0118 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -2,14 +2,20 @@ Welcome ======= -Aesara is a Python library that allows you to define, optimize, and -evaluate mathematical expressions involving multi-dimensional -arrays efficiently. Aesara features: - -* **Tight integration with NumPy** -- Use ``numpy.ndarray`` in Aesara-compiled functions. -* **Efficient symbolic differentiation** -- Aesara does your derivatives for functions with one or many inputs. -* **Speed and stability optimizations** -- Get the right answer for ``log(1+x)`` even when ``x`` is really tiny. -* **Dynamic C/JAX/Numba code generation** -- Evaluate expressions faster. +Aesara is a Python library that allows you to define, optimize/rewrite, and +evaluate mathematical expressions involving multi-dimensional arrays +efficiently. + +Some of Aesara's features are: + +* **Tight integration with NumPy** + - Use `numpy.ndarray` in Aesara-compiled functions +* **Efficient symbolic differentiation** + - Aesara efficiently computes your derivatives for functions with one or many inputs +* **Speed and stability optimizations** + - Get the right answer for ``log(1 + x)`` even when ``x`` is near zero +* **Dynamic C/JAX/Numba code generation** + - Evaluate expressions faster Aesara is based on `Theano`_, which has been powering large-scale computationally intensive scientific investigations since 2007. diff --git a/doc/introduction.rst b/doc/introduction.rst index 123d488a6e..d896367ede 100644 --- a/doc/introduction.rst +++ b/doc/introduction.rst @@ -5,28 +5,28 @@ Aesara at a Glance ================== -Aesara is a Python library that lets you define, optimize, and evaluate -mathematical expressions, especially ones involving multi-dimensional arrays -(e.g. :class:`numpy.ndarray`\s). Using Aesara it is -possible to attain speeds rivaling hand-crafted C implementations for problems -involving large amounts of data. +Aesara is a Python library that allows one to define, optimize/rewrite, and +evaluate mathematical expressions, especially ones involving multi-dimensional +arrays (e.g. :class:`numpy.ndarray`\s). Using Aesara, it is possible to attain +speeds rivaling hand-crafted C implementations for problems involving large +amounts of data. Aesara combines aspects of a computer algebra system (CAS) with aspects of an -optimizing compiler. It can also generate customized C code for many -mathematical operations. This combination of CAS with optimizing compilation +optimizing compiler. It can also generate customized code for multiple compiled +languages and/or their Python-based interfaces, such as C, Numba, and JAX. This +combination of CAS features with optimizing compilation and transpilation is particularly useful for tasks in which complicated mathematical expressions are evaluated repeatedly and evaluation speed is critical. For situations where many different expressions are each evaluated once, Aesara can minimize -the amount of compilation/analysis overhead, but still provide symbolic +the amount of compilation and analysis overhead, but still provide symbolic features such as automatic differentiation. -Aesara's compiler applies many optimizations of varying complexity to -these symbolic expressions. These optimizations include, but are not -limited to: +Aesara's compiler applies many default optimizations of varying +complexity. These optimizations include, but are not limited to: * constant folding -* merging of similar subgraphs, to avoid redundant calculation -* arithmetic simplification (e.g. ``x*y/x -> y``, ``--x -> x``) +* merging of similar sub-graphs, to avoid redundant calculations +* arithmetic simplifications (e.g. ``x * y / x -> y``, ``-(-x) -> x``) * inserting efficient BLAS_ operations (e.g. ``GEMM``) in a variety of contexts * using memory aliasing to avoid unnecessary calculations @@ -37,7 +37,7 @@ limited to: For more information see :ref:`optimizations`. Theano ------------------ +------ The library that Aesara is based on, Theano, was written at the LISA lab to support rapid development of efficient machine learning algorithms but while Theano was commonly referred to as a "deep learning" (DL) library, Aesara is not a DL library. diff --git a/doc/library/compile/debugmode.rst b/doc/library/compile/debugmode.rst index dcf6224067..ddd4c1895b 100644 --- a/doc/library/compile/debugmode.rst +++ b/doc/library/compile/debugmode.rst @@ -1,9 +1,9 @@ .. _debugmode: -================= +================ :mod:`debugmode` -================= +================ .. module:: aesara.compile.debugmode :platform: Unix, Windows @@ -14,16 +14,16 @@ Guide ===== -The DebugMode evaluation mode includes a number of self-checks and assertions +The `DebugMode` evaluation mode includes a number of self-checks and assertions that can help to diagnose several kinds of programmer errors that can lead to incorrect output. -It is much slower to evaluate a function or method with DebugMode than +It is much slower to evaluate a function or method with `DebugMode` than it would be in ``'FAST_RUN'`` or even ``'FAST_COMPILE'``. We recommended you use -DebugMode during development, but not when you launch 1000 processes on +`DebugMode` during development, but not when you launch 1000 processes on a cluster. -DebugMode can be used as follows: +`DebugMode` can be used as follows: .. testcode:: @@ -44,7 +44,7 @@ or passing a `DebugMode` instance, as in >>> f = aesara.function([x], 10*x, mode=DebugMode(check_c_code=False)) -If any problem is detected, DebugMode will raise an exception according to +If any problem is detected, `DebugMode` will raise an exception according to what went wrong, either at call time (``f(5)``) or compile time ( ``f = aesara.function(x, 10*x, mode='DebugMode')``). These exceptions should *not* be ignored; talk to your local Aesara guru or email the @@ -52,35 +52,35 @@ users list if you cannot make the exception go away. Some kinds of errors can only be detected for certain input value combinations. In the example above, there is no way to guarantee that a future call to say, -``f(-1)`` won't cause a problem. DebugMode is not a silver bullet. +``f(-1)`` won't cause a problem. `DebugMode` is not a silver bullet. If you use `DebugMode` by constructing a `DebugMode` object explicitly, rather than using the keyword ``mode="DebugMode"``, you can configure its behaviour via constructor arguments. Reference -========== +========= .. class:: DebugMode(Mode) - Evaluation Mode that detects internal aesara errors. + Evaluation :class:`Mode` that detects internal Aesara errors. This mode catches several kinds of internal error: - - inconsistent outputs when calling the same Op twice with the same - inputs, for instance if c_code and perform implementations, are + - inconsistent outputs when calling the same :class:`Op` twice with the same + inputs, for instance if :meth:`COp.c_code` and perform implementations, are inconsistent, or in case of incorrect handling of output memory (see `BadThunkOutput`) - a variable replacing another when their runtime values don't match. This is a symptom of - an incorrect optimization step, or faulty Op implementation (raises `BadOptimization`) + an incorrect rewrite step, or faulty :class:`Op` implementation (raises `BadOptimization`) - - stochastic optimization ordering (raises `StochasticOrder`) + - stochastic rewrite ordering (raises `StochasticOrder`) - - incomplete `destroy_map` specification (raises `BadDestroyMap`) + - incomplete :attr:`Op.destroy_map` specification (raises `BadDestroyMap`) - - an op that returns an illegal value not matching the output Variable Type (raises - InvalidValueError) + - an :class:`Op` that returns an illegal value not matching the output :class:`Variable`\ :class:`Type` (raises + :class:`InvalidValueError`) Each of these exceptions inherits from the more generic `DebugModeError`. @@ -91,7 +91,7 @@ Reference .. attribute:: stability_patience = config.DebugMode__patience - When checking for the stability of optimization, recompile the graph this many times. + When checking the stability of rewrites, recompile the graph this many times. Default 10. .. attribute:: check_c_code = config.DebugMode__check_c @@ -132,15 +132,15 @@ Reference Initialize member variables. - If any of these arguments (except optimizer) is not None, it overrides the class default. + If any of these arguments (except `optimizer`) is not ``None``, it overrides the class default. The linker arguments is not used. It is set there to allow - Mode.requiring() and some other functions to work with DebugMode too. + :meth:`Mode.requiring` and some other functions to work with `DebugMode` too. -The keyword version of DebugMode (which you get by using ``mode='DebugMode``) -is quite strict, and can raise several different Exception types. -There following are DebugMode exceptions you might encounter: +The keyword version of `DebugMode` (which you get by using ``mode='DebugMode``) +is quite strict, and can raise several different `Exception` types. +There following are `DebugMode` exceptions you might encounter: .. class:: DebugModeError(Exception) @@ -148,18 +148,18 @@ There following are DebugMode exceptions you might encounter: This is a generic error. All the other exceptions inherit from this one. This error is typically not raised directly. However, you can use ``except DebugModeError: ...`` to catch any of the more - specific types of Exception. + specific types of `Exception`\s. .. class:: BadThunkOutput(DebugModeError) - This exception means that different calls to the same Op with the same + This exception means that different calls to the same `Op` with the same inputs did not compute the same thing like they were supposed to. - For instance, it can happen if the python (``perform``) and c (``c_code``) - implementations of the Op are inconsistent (the problem might be a bug in - either ``perform`` or ``c_code`` (or both)). It can also happen if - ``perform`` or ``c_code`` does not handle correctly output memory that + For instance, it can happen if the Python (i.e. :meth:`Op.perform`) and C (i.e. :meth:`COp.c_code`) + implementations of the `Op` are inconsistent. The problem might be a bug in + either :meth:`Op.perform` or :meth:`COp.c_code` (or both). It can also happen if + :meth:`Op.perform` or :meth:`COp.c_code` does not handle correctly output memory that has been preallocated (for instance, if it did not clear the memory before accumulating into it, or if it assumed the memory layout was C-contiguous even if it is not). @@ -168,54 +168,54 @@ There following are DebugMode exceptions you might encounter: .. class:: BadOptimization(DebugModeError) - This exception indicates that an Optimization replaced one variable (say V1) - with another one (say V2) but at runtime, the values for V1 and V2 were - different. This is something that optimizations are not supposed to do. + This exception indicates that a rewrite replaced one variable (say ``V1``) + with another one (say ``V2``) but at runtime, the values for ``V1`` and ``V2`` were + different. This is something that rewrites are not supposed to do. - It can be tricky to identify the one-true-cause of an optimization error, but + It can be tricky to identify the one-true-cause of a rewrite error, but this exception provides a lot of guidance. Most of the time, the - exception object will indicate which optimization was at fault. + exception object will indicate which rewrite was at fault. The exception object also contains information such as a snapshot of the - before/after graph where the optimization introduced the error. + before/after graph where the rewrite introduced the error. .. class:: BadDestroyMap(DebugModeError) - This happens when an Op's ``perform()`` or ``c_code()`` modifies an input that it wasn't - supposed to. If either the ``perform`` or ``c_code`` implementation of an Op - might modify any input, it has to advertise that fact via the ``destroy_map`` - attribute. + This happens when an :meth:`Op.perform` or :meth:`COp.c_code` modifies an + input that it wasn't supposed to. If either the :meth:`Op.perform` or + :meth:`COp.c_code` implementation of an :class:`Op` might modify any input, it has + to advertise that fact via the :attr:`Op.destroy_map` attribute. - For detailed documentation on the ``destroy_map`` attribute, see :ref:`inplace`. + For detailed documentation on the :attr:`Op.destroy_map` attribute, see :ref:`inplace`. .. class:: BadViewMap(DebugModeError) - This happens when an Op's perform() or c_code() creates an alias or alias-like - dependency between an input and an output... and it didn't warn the - optimization system via the ``view_map`` attribute. + This happens when an :meth:`Op.perform` or :meth:`COp.c_code` creates an + alias or alias-like dependency between an input and an output, and it didn't + warn the rewrite system via the :attr:`Op.view_map` attribute. - For detailed documentation on the ``view_map`` attribute, see :ref:`views`. + For detailed documentation on the :attr:`Op.view_map` attribute, see :ref:`views`. .. class:: StochasticOrder(DebugModeError) - This happens when an optimization does not perform the same graph operations + This happens when an rewrite does not perform the same graph operations in the same order when run several times in a row. This can happen if any steps are ordered by ``id(object)`` somehow, such as via the default object - hash function. A Stochastic optimization invalidates the pattern of work - whereby we debug in DebugMode and then run the full-size jobs in FAST_RUN. + hash function. A stochastic rewrite invalidates the pattern of work + whereby we debug in `DebugMode` and then run the full-size jobs in FAST_RUN. .. class:: InvalidValueError(DebugModeError) - This happens when some Op's ``perform`` or ``c_code`` implementation computes + This happens when some :meth:`Op.perform` or :meth:`COp.c_code` implementation computes an output that is invalid with respect to the type of the corresponding output variable. Like if it returned a complex-valued ndarray for a ``dscalar`` - Type. + :class:`Type`. This can also be triggered when floating-point values such as NaN and Inf are - introduced into the computations. It indicates which Op created the first + introduced into the computations. It indicates which :class:`Op` created the first NaN. These floating-point values can be allowed by passing the - ``check_isfinite=False`` argument to DebugMode. + ``check_isfinite=False`` argument to `DebugMode`. diff --git a/doc/library/compile/function.rst b/doc/library/compile/function.rst index 8a7835517f..bf7f409798 100644 --- a/doc/library/compile/function.rst +++ b/doc/library/compile/function.rst @@ -181,7 +181,7 @@ Reference and update the implicit function arguments according to the `updates`. - Inputs can be given as variables or In instances. + Inputs can be given as variables or :class:`In` instances. :class:`In` instances also have a variable, but they attach some extra information about how call-time arguments corresponding to that variable should be used. Similarly, :class:`Out` instances can attach information @@ -189,28 +189,28 @@ Reference The default is typically 'FAST_RUN' but this can be changed in :doc:`aesara.config <../config>`. The mode - argument controls the sort of optimizations that will be applied to the - graph, and the way the optimized graph will be evaluated. + argument controls the sort of rewrites that will be applied to the + graph, and the way the rewritten graph will be evaluated. After each function evaluation, the `updates` mechanism can replace the - value of any SharedVariable [implicit] inputs with new values computed + value of any (implicit) `SharedVariable` inputs with new values computed from the expressions in the `updates` list. An exception will be raised - if you give two update expressions for the same SharedVariable input (that + if you give two update expressions for the same `SharedVariable` input (that doesn't make sense). - If a SharedVariable is not given an update expression, but has a - ``default_update`` member containing an expression, this expression + If a `SharedVariable` is not given an update expression, but has a + :attr:`Variable.default_update` member containing an expression, this expression will be used as the update expression for this variable. Passing ``no_default_updates=True`` to ``function`` disables this behavior entirely, passing ``no_default_updates=[sharedvar1, sharedvar2]`` disables it for the mentioned variables. Regarding givens: Be careful to make sure that these substitutions are - independent, because behaviour when Var1 of one pair appears in the graph leading - to Var2 in another expression is undefined (e.g. with ``{a: x, b: a + 1}``). - Replacements specified with - givens are different from optimizations in that Var2 is not expected to be - equivalent to Var1. + independent, because behaviour when ``Var1`` of one pair appears in the graph leading + to ``Var2`` in another expression is undefined (e.g. with ``{a: x, b: a + 1}``). + Replacements specified with givens are different from replacements that + occur during normal rewriting, in that ``Var2`` is not expected to be + equivalent to ``Var1``. .. autofunction:: aesara.compile.function.function_dump diff --git a/doc/library/compile/mode.rst b/doc/library/compile/mode.rst index 9af851cb89..8799f620a5 100644 --- a/doc/library/compile/mode.rst +++ b/doc/library/compile/mode.rst @@ -18,8 +18,8 @@ inputs-to-outputs graph is transformed into a callable object. Aesara defines the following modes by name: -- ``'FAST_COMPILE'``: Apply just a few graph optimizations and only use Python implementations. -- ``'FAST_RUN'``: Apply all optimizations, and use C implementations where possible. +- ``'FAST_COMPILE'``: Apply just a few graph rewrites and only use Python implementations. +- ``'FAST_RUN'``: Apply all rewrites, and use C implementations where possible. - ``'DebugMode'``: A mode for debugging. See :ref:`DebugMode ` for details. - ``'NanGuardMode``: :ref:`Nan detector ` - ``'DEBUG_MODE'``: Deprecated. Use the string DebugMode. @@ -30,7 +30,7 @@ overridden by passing the keyword argument to :func:`aesara.function`. .. TODO:: - For a finer level of control over which optimizations are applied, and whether + For a finer level of control over which rewrites are applied, and whether C or Python implementations are used, read.... what exactly? @@ -43,9 +43,9 @@ Reference .. class:: Mode(object) - Compilation is controlled by two attributes: the `optimizer` controls how - an expression graph will be transformed; the `linker` controls how the - optimized expression graph will be evaluated. + Compilation is controlled by two attributes: the :attr:`optimizer` controls how + an expression graph will be transformed; the :attr:`linker` controls how the + rewritten expression graph will be evaluated. .. attribute:: optimizer @@ -57,15 +57,15 @@ Reference .. method:: including(*tags) - Return a new Mode instance like this one, but with an - optimizer modified by including the given tags. + Return a new :class:`Mode` instance like this one, but with its + :attr:`optimizer` modified by including the given tags. .. method:: excluding(*tags) - Return a new Mode instance like this one, but with an - optimizer modified by excluding the given tags. + Return a new :class:`Mode` instance like this one, but with an + :attr:`optimizer` modified by excluding the given tags. .. method:: requiring(*tags) - Return a new Mode instance like this one, but with an - optimizer modified by requiring the given tags. + Return a new :class:`Mode` instance like this one, but with an + :attr:`optimizer` modified by requiring the given tags. diff --git a/doc/library/compile/opfromgraph.rst b/doc/library/compile/opfromgraph.rst index 20a26c0be3..9bdc1049a8 100644 --- a/doc/library/compile/opfromgraph.rst +++ b/doc/library/compile/opfromgraph.rst @@ -2,22 +2,22 @@ .. _opfromgraph: -=========== -OpFromGraph -=========== +============ +`OpFromGraph` +============ This page describes :class:`aesara.compile.builders.OpFromGraph -`, an Op that allows to -encapsulate an Aesara graph in an op. +`, an `Op` constructor that allows one to +encapsulate an Aesara graph in a single `Op`. This can be used to encapsulate some functionality in one block. It is useful to scale Aesara compilation for regular bigger graphs when we reuse that encapsulated functionality with different inputs many -times. Due to this encapsulation, it can make Aesara compilation phase +times. Due to this encapsulation, it can make Aesara's compilation phase faster for graphs with many nodes. Using this for small graphs is not recommended as it disables -optimizations between what is inside the encapsulation and outside of it. +rewrites between what is inside the encapsulation and outside of it. .. note: diff --git a/doc/library/config.rst b/doc/library/config.rst index 2ae8d27304..a6782b9679 100644 --- a/doc/library/config.rst +++ b/doc/library/config.rst @@ -170,8 +170,8 @@ import ``aesara`` and print the config variable, as in: Default: ``True`` - This enables, or disables, an optimization in :class:`Scan` that tries to - pre-allocate memory for its outputs. Enabling the optimization can give a + This enables, or disables, a rewrite in :class:`Scan` that tries to + pre-allocate memory for its outputs. Enabling the rewrite can give a significant speed up at the cost of slightly increased memory usage. .. attribute:: config.scan__allow_gc @@ -202,10 +202,10 @@ import ``aesara`` and print the config variable, as in: Default: ``off`` - This is a flag for checking the stack trace during graph optimization. + This is a flag for checking stack traces during graph rewriting. If :attr:`check_stack_trace` is set to ``off``, no check is performed on the stack trace. If :attr:`check_stack_trace` is set to ``log`` or ``warn``, a - dummy stack trace is inserted that indicates which optimization inserted the + dummy stack trace is inserted that indicates which rewrite inserted the variable that had an empty stack trace, but, when ``warn`` is set, a warning is also printed. If :attr:`check_stack_trace` is set to ``raise``, an exception is raised if a @@ -315,7 +315,7 @@ import ``aesara`` and print the config variable, as in: Default: ``False`` - When ``True``, the VM and CVM linkers profile the optimization phase when + When ``True``, the :class:`VM` and :class:`CVM` linkers profile the rewriting phase when compiling an Aesara function. This only works when ``profile=True``. .. attribute:: config.profiling__n_apply @@ -398,7 +398,7 @@ import ``aesara`` and print the config variable, as in: Default: ``'fast_run'`` - When the mode is ``'Mode'``, it sets the default optimizer used. + When the mode is ``'Mode'``, it sets the default rewrites used during compilation. .. attribute:: on_opt_error @@ -406,8 +406,8 @@ import ``aesara`` and print the config variable, as in: Default: ``'warn'`` - When a crash occurs while trying to apply an optimization, either warn the - user and skip the optimization (i.e. ``'warn'``), raise the exception + When a crash occurs while trying to apply a rewrite, either warn the + user and skip the rewrite (i.e. ``'warn'``), raise the exception (i.e. ``'raise'``), drop into the ``pdb`` debugger (i.e. ``'pdb'``), or ignore it (i.e. ``'ignore'``). We suggest never using ``'ignore'`` except during testing. @@ -503,9 +503,9 @@ import ``aesara`` and print the config variable, as in: When ``True``, add asserts that highlight shape errors. - Without such asserts, the underlying optimization could hide errors in user + Without such asserts, the underlying rewrite could hide errors in user code. Aesara adds the asserts only if it cannot infer that the shapes are - equivalent. When it can determine equivalence, this optimization does not + equivalent. When it can determine equivalence, this rewrite does not introduce an assert. Removing these asserts can speed up execution. @@ -653,11 +653,11 @@ import ``aesara`` and print the config variable, as in: Default: ``""`` - A list of optimizer tags that shouldn't be included in the default ``Mode``. + A list of rewriter tags that shouldn't be included in the default ``Mode``. If multiple tags are provided, separate them by ``':'``. - For example, to remove the ``Elemwise`` in-place optimizations, + For example, to remove the ``Elemwise`` in-place rewrites, use the flags: ``optimizer_excluding:inplace_opt``, where - ``inplace_opt`` is the name of the optimization group. + ``inplace_opt`` is the name of the rewrite group. This flag's value cannot be modified during the program execution. @@ -665,7 +665,7 @@ import ``aesara`` and print the config variable, as in: Default: ``""`` - A list of optimizer tags to be included in the default ``Mode``. + A list of rewriter tags to be included in the default ``Mode``. If multiple tags are provided, separate them by ``':'``. This flag's value cannot be modified during the program execution. @@ -674,7 +674,7 @@ import ``aesara`` and print the config variable, as in: Default: ``""`` - A list of optimizer tags that are required for optimization in the default + A list of rewriter tags that are required for rewriting in the default ``Mode``. If multiple tags are provided, separate them by ``':'``. @@ -686,7 +686,7 @@ import ``aesara`` and print the config variable, as in: Default: ``False`` - When ``True``, print the optimizations applied to stdout. + When ``True``, print the rewrites applied to stdout. .. attribute:: nocleanup @@ -792,7 +792,7 @@ import ``aesara`` and print the config variable, as in: Setting this attribute to something other than ``'off'`` activates a debugging mechanism, for which Aesara executes the graph on-the-fly, as it is being built. This allows the user to spot errors early on (such as - dimension mis-matches) **before** optimizations are applied. + dimension mis-matches) **before** rewrites are applied. Aesara will execute the graph using constants and/or shared variables provided by the user. Purely symbolic variables (e.g. ``x = @@ -809,8 +809,8 @@ import ``aesara`` and print the config variable, as in: .. attribute:: compute_test_value_opt As ``compute_test_value``, but it is the value used during Aesara's - optimization phase. This is used to help debug shape errors in Aesara's - optimizations. + rewriting phase. This is used to help debug shape errors in Aesara's + rewrites. .. attribute:: print_test_value @@ -898,21 +898,21 @@ import ``aesara`` and print the config variable, as in: Int value, default: 0 - The verbosity level of the meta-optimizer: ``0`` for silent, ``1`` to only - warn when Aesara cannot meta-optimize an :class:`Op`, ``2`` for full output (e.g. - timings and the optimizations selected). + The verbosity level of the meta-rewriter: ``0`` for silent, ``1`` to only + warn when Aesara cannot meta-rewrite an :class:`Op`, ``2`` for full output (e.g. + timings and the rewrites selected). .. attribute:: config.metaopt__optimizer_excluding Default: ``""`` - A list of optimizer tags that we don't want included in the meta-optimizer. + A list of rewrite tags that we don't want included in the meta-rewriter. Multiple tags are separate by ``':'``. .. attribute:: config.metaopt__optimizer_including Default: ``""`` - A list of optimizer tags to be included during meta-optimization. + A list of rewriter tags to be included during meta-rewriting. Multiple tags are separate by ``':'``. diff --git a/doc/library/printing.rst b/doc/library/printing.rst index 945633634b..7157c4beda 100644 --- a/doc/library/printing.rst +++ b/doc/library/printing.rst @@ -33,7 +33,7 @@ hello world __str__ = [ 1. 2. 3.] If you print more than one thing in a function like `f`, they will not necessarily be printed in the order that you think. The order might even depend -on which graph optimizations are applied. Strictly speaking, the order of +on which graph rewrites are applied. Strictly speaking, the order of printing is not completely defined by the interface -- the only hard rule is that if the input of some print output `a` is ultimately used as an input to some other print input `b` (so that `b` depends on `a`), @@ -56,7 +56,7 @@ Aesara also provides :func:`aesara.printing.pydotprint` that creates a png image >>> x = at.dscalar('x') >>> y = x ** 2 >>> gy = grad(y, x) ->>> pp(gy) # print out the gradient prior to optimization +>>> pp(gy) # print out the gradient prior to rewriting '((fill((x ** TensorConstant{2}), TensorConstant{1.0}) * TensorConstant{2}) * (x ** (TensorConstant{2} - TensorConstant{1})))' >>> f = function([x], gy) >>> pp(f.maker.fgraph.outputs[0]) diff --git a/doc/library/scan.rst b/doc/library/scan.rst index 24bca1da29..0abcbcd6ed 100644 --- a/doc/library/scan.rst +++ b/doc/library/scan.rst @@ -81,7 +81,7 @@ Scan returns a tuple containing our result (``result``) and a dictionary of updates (empty in this case). Note that the result is not a matrix, but a 3D tensor containing the value of ``A**k`` for each step. We want the last value (after ``k`` steps) so we compile -a function to return just that. Note that there is an optimization, that +a function to return just that. Note that there is a rewrite that at compile time will detect that you are using just the last value of the result and ensure that scan does not store all the intermediate values that are used. So do not worry if ``A`` and ``k`` are large. @@ -254,40 +254,35 @@ Another useful feature of scan, is that it can handle shared variables. For example, if we want to implement a Gibbs chain of length 10 we would do the following: -.. testsetup:: scan1 - - import aesara - import numpy - W_values = numpy.random.random((2, 2)) - bvis_values = numpy.random.random((2,)) - bhid_values = numpy.random.random((2,)) - .. testcode:: scan1 - import aesara - from aesara import tensor as at + import aesara + import aesara.tensor as at + import numpy as np - W = aesara.shared(W_values) # we assume that ``W_values`` contains the - # initial values of your weight matrix + rng = np.random.default_rng(203940) + W_values = rng.uniform(size=(2, 2)) + bvis_values = rng.uniform(size=(2,)) + bhid_values = rng.uniform(size=(2,)) - bvis = aesara.shared(bvis_values) - bhid = aesara.shared(bhid_values) + W = aesara.shared(W_values) + bvis = aesara.shared(bvis_values) + bhid = aesara.shared(bhid_values) - trng = aesara.tensor.random.utils.RandomStream(1234) + srng = at.random.RandomStream(1234) - def OneStep(vsample) : - hmean = at.sigmoid(aesara.dot(vsample, W) + bhid) - hsample = trng.binomial(size=hmean.shape, n=1, p=hmean) - vmean = at.sigmoid(aesara.dot(hsample, W.T) + bvis) - return trng.binomial(size=vsample.shape, n=1, p=vmean, - dtype=aesara.config.floatX) + def one_step(vsample): + hmean = at.sigmoid(at.dot(vsample, W) + bhid) + hsample = srng.binomial(1, hmean, size=hmean.shape) + vmean = at.sigmoid(at.dot(hsample, W.T) + bvis) - sample = aesara.tensor.vector() + return srng.binomial(1, vmean, size=vsample.shape) - values, updates = aesara.scan(OneStep, outputs_info=sample, n_steps=10) + sample = at.lvector() - gibbs10 = aesara.function([sample], values[-1], updates=updates) + values, updates = aesara.scan(one_step, outputs_info=sample, n_steps=10) + gibbs10 = aesara.function([sample], values[-1], updates=updates) The first, and probably most crucial observation is that the updates dictionary becomes important in this case. It links a shared variable @@ -341,7 +336,7 @@ function applied at each step) you do not need to pass them as arguments. Scan will find them on its own and add them to the graph. However, passing them to the scan function is a good practice, as it avoids Scan Op calling any earlier (external) Op over and over. This results in a -simpler computational graph, which speeds up the optimization and the +simpler computational graph, which speeds up the rewriting and the execution. To pass the shared variables to Scan you need to put them in a list and give it to the ``non_sequences`` argument. Here is the Gibbs sampling code updated: @@ -381,7 +376,7 @@ Using shared variables - the strict flag ---------------------------------------- As we just saw, passing the shared variables to scan may result in a simpler -computational graph, which speeds up the optimization and the execution. A +computational graph, which speeds up the rewriting and the execution. A good way to remember to pass every shared variable used during scan is to use the ``strict`` flag. When set to true, scan checks that all the necessary shared variables in ``fn`` are passed as explicit arguments to ``fn``. This has to be @@ -599,8 +594,8 @@ about 6x slower than the forward, a ~20% slowdown is expected. Apart from the is similar to the classic ``scan`` function. -Optimizing Scan's performance ------------------------------ +Improving Scan's performance +---------------------------- This section covers some ways to improve performance of an Aesara function using Scan. @@ -645,29 +640,29 @@ is not provided for this argument, the value of the flag ``config.scan__allow_gc`` is used). -Graph optimizations -^^^^^^^^^^^^^^^^^^^ +Graph Rewrites +^^^^^^^^^^^^^^ This one is simple but still worth pointing out. Aesara is able to -automatically recognize and optimize many computation patterns. However, there -are patterns that Aesara doesn't optimize because doing so would change the +automatically recognize and rewrite many computation patterns. However, there +are patterns that Aesara doesn't rewrite because doing so would change the user interface (such as merging shared variables together into a single one, for instance). Additionally, Aesara doesn't catch every case that it could -optimize and so it remains useful for performance that the user defines an +rewrite and so it remains useful for performance that the user defines an efficient graph in the first place. This is also the case, and sometimes even more so, for the graph inside of Scan. This is because it will be executed many times for every execution of the Aesara function that contains it. The `LSTM tutorial `_ on -`DeepLearning.net `_ provides an example of an -optimization that Aesara cannot perform. Instead of performing many matrix +`DeepLearning.net `_ provides an example of a +rewrite that Aesara cannot perform. Instead of performing many matrix multiplications between matrix :math:`x_t` and each of the shared matrices :math:`W_i`, :math:`W_c`, :math:`W_f` and :math:`W_o`, the matrices :math:`W_*`, are merged into a single shared matrix :math:`W` and the graph performs a single larger matrix multiplication between :math:`W` and :math:`x_t`. The resulting matrix is then sliced to obtain the results of that the small individual matrix multiplications would have produced. This -optimization replaces several small and inefficient matrix multiplications by +rewrite replaces several small and inefficient matrix multiplications by a single larger one and thus improves performance at the cost of a potentially higher memory usage. diff --git a/doc/library/sparse/index.rst b/doc/library/sparse/index.rst index 02943e17d8..4f0ce8dd5c 100644 --- a/doc/library/sparse/index.rst +++ b/doc/library/sparse/index.rst @@ -231,18 +231,18 @@ List of Implemented Operations - :func:`sampling_dot `. - Both inputs must be dense. - - The grad implemented is structured for `p`. + - The grad implemented is structured for ``p``. - Sample of the dot and sample of the gradient. - C code for perform but not for grad. - Returns sparse for perform and grad. - :func:`usmm `. - You *shouldn't* insert this op yourself! - - There is an optimization that transform a - :func:`dot ` to ``Usmm`` when possible. + - There is a rewrite that transforms a + :func:`dot ` to :class:`Usmm` when possible. - - This op is the equivalent of gemm for sparse dot. - - There is no grad implemented for this op. + - This :class:`Op` is the equivalent of gemm for sparse dot. + - There is no grad implemented for this :class:`Op`. - One of the inputs must be sparse, the other sparse or dense. - Returns a dense from perform. diff --git a/doc/library/tensor/basic.rst b/doc/library/tensor/basic.rst index 0eb73771f0..b9e52e6684 100644 --- a/doc/library/tensor/basic.rst +++ b/doc/library/tensor/basic.rst @@ -1199,7 +1199,7 @@ Bitwise Inplace ------- -In-place operators are *not* supported. Aesara's graph-optimizations +In-place operators are *not* supported. Aesara's graph rewrites will determine which intermediate values to use for in-place computations. If you would like to update the value of a :term:`shared variable`, consider using the ``updates`` argument to diff --git a/doc/library/tensor/basic_opt.rst b/doc/library/tensor/basic_opt.rst index 4c4f4624b4..1f871e8b69 100644 --- a/doc/library/tensor/basic_opt.rst +++ b/doc/library/tensor/basic_opt.rst @@ -1,11 +1,11 @@ -=================================================================== -:mod:`tensor.basic_opt` -- Tensor Optimizations -=================================================================== +================================================ +:mod:`tensor.rewriting.basic` -- Tensor Rewrites +================================================ -.. module:: tensor.basic_opt +.. module:: tensor.rewriting.basic :platform: Unix, Windows - :synopsis: Tensor Optimizations + :synopsis: Tensor Rewrites .. moduleauthor:: LISA, PyMC Developers, Aesara Developers -.. automodule:: aesara.tensor.basic_opt +.. automodule:: aesara.tensor.rewriting.basic :members: diff --git a/doc/library/tensor/index.rst b/doc/library/tensor/index.rst index 7374585177..427c9e4c2d 100644 --- a/doc/library/tensor/index.rst +++ b/doc/library/tensor/index.rst @@ -25,8 +25,8 @@ They are grouped into the following sections: elemwise extra_ops io - basic_opt slinalg nlinalg fft math_opt + basic_opt diff --git a/doc/library/tensor/math_opt.rst b/doc/library/tensor/math_opt.rst index 5675da77fb..ed0000b496 100644 --- a/doc/library/tensor/math_opt.rst +++ b/doc/library/tensor/math_opt.rst @@ -1,11 +1,11 @@ =================================================================== -:mod:`tensor.math_opt` -- Tensor Optimizations for Math Operations +:mod:`tensor.rewriting.math` -- Tensor Rewrites for Math Operations =================================================================== -.. module:: tensor.math_opt +.. module:: tensor.rewriting.math :platform: Unix, Windows - :synopsis: Tensor Optimizations for Math Operations + :synopsis: Tensor Rewrites for Math Operations .. moduleauthor:: LISA, PyMC Developers, Aesara Developers -.. automodule:: aesara.tensor.math_opt +.. automodule:: aesara.tensor.rewriting.math :members: diff --git a/doc/library/tensor/nnet/basic.rst b/doc/library/tensor/nnet/basic.rst index 4443947087..c9a71d97b5 100644 --- a/doc/library/tensor/nnet/basic.rst +++ b/doc/library/tensor/nnet/basic.rst @@ -61,45 +61,44 @@ .. function:: ultra_fast_sigmoid(x) - Returns the *approximated* standard :func:`sigmoid` nonlinearity applied to x. - :Parameters: *x* - symbolic Tensor (or compatible) - :Return type: same as x + Returns an approximate standard :func:`sigmoid` nonlinearity applied to ``x``. + :Parameters: ``x`` - symbolic Tensor (or compatible) + :Return type: same as ``x`` :Returns: approximated element-wise sigmoid: :math:`sigmoid(x) = \frac{1}{1 + \exp(-x)}`. - :note: To automatically change all :func:`sigmoid` ops to this version, use - the Aesara optimization ``local_ultra_fast_sigmoid``. This can be done + :note: To automatically change all :func:`sigmoid`\ :class:`Op`\s to this version, use + the Aesara rewrite `local_ultra_fast_sigmoid`. This can be done with the Aesara flag ``optimizer_including=local_ultra_fast_sigmoid``. - This optimization is done late, so it should not affect - stabilization optimization. + This rewrite is done late, so it should not affect stabilization rewrites. .. note:: The underlying code will return 0.00247262315663 as the minimum value and 0.997527376843 as the maximum value. So it never returns 0 or 1. - .. note:: Using directly the ultra_fast_sigmoid in the graph will - disable stabilization optimization associated with it. But - using the optimization to insert them won't disable the - stability optimization. + .. note:: Using directly the `ultra_fast_sigmoid` in the graph will + disable stabilization rewrites associated with it. But + using the rewrite to insert them won't disable the + stability rewrites. .. function:: hard_sigmoid(x) - Returns the *approximated* standard :func:`sigmoid` nonlinearity applied to x. - :Parameters: *x* - symbolic Tensor (or compatible) - :Return type: same as x + Returns an approximate standard :func:`sigmoid` nonlinearity applied to `1x1`. + :Parameters: ``x`` - symbolic Tensor (or compatible) + :Return type: same as ``x`` :Returns: approximated element-wise sigmoid: :math:`sigmoid(x) = \frac{1}{1 + \exp(-x)}`. - :note: To automatically change all :func:`sigmoid` ops to this version, use - the Aesara optimization ``local_hard_sigmoid``. This can be done + :note: To automatically change all :func:`sigmoid`\ :class:`Op`\s to this version, use + the Aesara rewrite `local_hard_sigmoid`. This can be done with the Aesara flag ``optimizer_including=local_hard_sigmoid``. - This optimization is done late, so it should not affect - stabilization optimization. + This rewrite is done late, so it should not affect + stabilization rewrites. .. note:: The underlying code will return an exact 0 or 1 if an - element of x is too small or too big. + element of ``x`` is too small or too big. - .. note:: Using directly the ultra_fast_sigmoid in the graph will - disable stabilization optimization associated with it. But - using the optimization to insert them won't disable the - stability optimization. + .. note:: Using directly the `ultra_fast_sigmoid` in the graph will + disable stabilization rewrites associated with it. But + using the rewrites to insert them won't disable the + stability rewrites. .. function:: softplus(x) diff --git a/doc/library/tensor/random/basic.rst b/doc/library/tensor/random/basic.rst index 6a86b4bd29..bc88aaf900 100644 --- a/doc/library/tensor/random/basic.rst +++ b/doc/library/tensor/random/basic.rst @@ -49,5 +49,113 @@ Distributions Aesara can produce :class:`RandomVariable`\s that draw samples from many different statistical distributions, using the following :class:`Op`\s. +.. autoclass:: aesara.tensor.random.basic.UniformRV + :members: __call__ + +.. autoclass:: aesara.tensor.random.basic.RandIntRV + :members: __call__ + +.. autoclass:: aesara.tensor.random.basic.IntegersRV + :members: __call__ + +.. autoclass:: aesara.tensor.random.basic.ChoiceRV + :members: __call__ + +.. autoclass:: aesara.tensor.random.basic.PermutationRV + :members: __call__ + +.. autoclass:: aesara.tensor.random.basic.BernoulliRV + :members: __call__ + +.. autoclass:: aesara.tensor.random.basic.BetaRV + :members: __call__ + +.. autoclass:: aesara.tensor.random.basic.BetaBinomialRV + :members: __call__ + +.. autoclass:: aesara.tensor.random.basic.BinomialRV + :members: __call__ + +.. autoclass:: aesara.tensor.random.basic.CauchyRV + :members: __call__ + +.. autoclass:: aesara.tensor.random.basic.CategoricalRV + :members: __call__ + +.. autoclass:: aesara.tensor.random.basic.ChiSquareRV + :members: __call__ + +.. autoclass:: aesara.tensor.random.basic.DirichletRV + :members: __call__ + +.. autoclass:: aesara.tensor.random.basic.ExponentialRV + :members: __call__ + .. autoclass:: aesara.tensor.random.basic.GammaRV :members: __call__ + +.. autoclass:: aesara.tensor.random.basic.GenGammaRV + :members: __call__ + +.. autoclass:: aesara.tensor.random.basic.GeometricRV + :members: __call__ + +.. autoclass:: aesara.tensor.random.basic.GumbelRV + :members: __call__ + +.. autoclass:: aesara.tensor.random.basic.HalfCauchyRV + :members: __call__ + +.. autoclass:: aesara.tensor.random.basic.HalfNormalRV + :members: __call__ + +.. autoclass:: aesara.tensor.random.basic.HyperGeometricRV + :members: __call__ + +.. autoclass:: aesara.tensor.random.basic.InvGammaRV + :members: __call__ + +.. autoclass:: aesara.tensor.random.basic.LaplaceRV + :members: __call__ + +.. autoclass:: aesara.tensor.random.basic.LogisticRV + :members: __call__ + +.. autoclass:: aesara.tensor.random.basic.LogNormalRV + :members: __call__ + +.. autoclass:: aesara.tensor.random.basic.MultinomialRV + :members: __call__ + +.. autoclass:: aesara.tensor.random.basic.MvNormalRV + :members: __call__ + +.. autoclass:: aesara.tensor.random.basic.NegBinomialRV + :members: __call__ + +.. autoclass:: aesara.tensor.random.basic.NormalRV + :members: __call__ + +.. autoclass:: aesara.tensor.random.basic.ParetoRV + :members: __call__ + +.. autoclass:: aesara.tensor.random.basic.PoissonRV + :members: __call__ + +.. autoclass:: aesara.tensor.random.basic.StandardNormalRV + :members: __call__ + +.. autoclass:: aesara.tensor.random.basic.TriangularRV + :members: __call__ + +.. autoclass:: aesara.tensor.random.basic.TruncExponentialRV + :members: __call__ + +.. autoclass:: aesara.tensor.random.basic.VonMisesRV + :members: __call__ + +.. autoclass:: aesara.tensor.random.basic.WaldRV + :members: __call__ + +.. autoclass:: aesara.tensor.random.basic.WeibullRV + :members: __call__ diff --git a/doc/optimizations.rst b/doc/optimizations.rst index e987509170..8989dafbc9 100644 --- a/doc/optimizations.rst +++ b/doc/optimizations.rst @@ -4,7 +4,7 @@ Optimizations ============== -Aesara applies many kinds of graph optimizations, with different objectives: +Aesara applies many kinds of graph rewrites, some of which can be considered "optimizations": * simplifying and standardizing the form of the expression graph (e.g. :term:`merge`, :term:`add canonicalization` ), * reducing the maximum memory footprint (e.g. :term:`inplace_elemwise`), * increasing execution speed (e.g. :term:`constant folding`). @@ -17,9 +17,12 @@ If you would like to add an additional optimization, see :ref:`graph_rewriting`. When compiling, we can make a tradeoff between compile-time and run-time. Faster compile times will result in fewer optimizations being applied, hence generally slower run-times. -For making this tradeoff when compiling, we provide a set of 4 optimization modes, 'o1' to 'o4', where 'o1' leads to fastest compile-time and 'o4' leads to fastest run-time in general. -For an even faster run-time, we could disable assertions (which could be time consuming) for valid user inputs, using the optimization mode 'unsafe', but this is, as the name suggests, unsafe. -(Also see note at :ref:`unsafe_optimization`.) +For making this tradeoff when compiling, we provide a set of 4 optimization +modes, 'o1' to 'o4', where 'o1' leads to fastest compile-time and 'o4' leads to +fastest run-time in general. +For an even faster run-time, we could disable assertions (which could be time +consuming) for valid user inputs, using the optimization mode 'unsafe', but this +is, as the name suggests, unsafe. See :ref:`unsafe_rewrites`. .. note:: @@ -263,4 +266,4 @@ Optimization o4 o3 o2 remove all assertions in the graph for checking user inputs are valid. Use this optimization if you are sure everything is valid in your graph. - See :ref:`unsafe_optimization` + See :ref:`unsafe_rewrites` diff --git a/doc/sandbox/elemwise_compiler.rst b/doc/sandbox/elemwise_compiler.rst index 3dcd4d43c8..8c7825b7c4 100644 --- a/doc/sandbox/elemwise_compiler.rst +++ b/doc/sandbox/elemwise_compiler.rst @@ -71,7 +71,7 @@ When {{{order == f}}}, the iterators ''ideally'' (but not necessarily) iterate i {{{order}}} does __not__ represent the {{{C/F_CONTIGUOUS}}} flags of the inputs or outputs. Depending on combinations of those parameters, different loops will be used. If {{{order == f and C_CONTIGUOUS(array)}}}, for example, the loop will be on {{{dim1..dimN}}} and the matrices of lesser rank will need to be looped over several times. -An Optimizer should look at the operations in the graph and figure out whether to allocate C_CONTIGUOUS (ideal for {{{order == c}}}) or F_CONTIGUOUS (ideal for {{{order == f}}}) arrays. +An rewrite should look at the operations in the graph and figure out whether to allocate C_CONTIGUOUS (ideal for {{{order == c}}}) or F_CONTIGUOUS (ideal for {{{order == f}}}) arrays. Gradient ======== diff --git a/doc/sandbox/how_to_make_ops.rst b/doc/sandbox/how_to_make_ops.rst index a3a3d05f69..962ad6a705 100644 --- a/doc/sandbox/how_to_make_ops.rst +++ b/doc/sandbox/how_to_make_ops.rst @@ -33,27 +33,47 @@ Ideas: ``__eq__``, ``__ne__`` and ``__hash__`` --------------------------------------------- -In order for certain optimizations to apply (such as the merging of duplicate calculations by ``MergeOptimizer``), it is necessary for Ops that do the same thing to compare equal. If ``Op`` instances are generated by a function call (for example) then it can happen that several different ``Op`` instances do the same thing; in that case you will have to override ``__eq__``, ``__ne__``, and ``__hash__`` for the ``MergeOptimizer`` to recognize them as equal. +In order for certain rewrites to apply (such as the merging of duplicate +calculations by `MergeOptimizer`), it is necessary for `Op`\s that do the same +thing to compare equal. If `Op` instances are generated by a function call +(for example) then it can happen that several different `Op` instances do the +same thing; in that case you will have to override `Op.__eq__`, `Op.__ne__`, and +`Op.__hash__` for the `MergeOptimizer` to recognize them as equal. -Recall: the contract for ``__hash__`` is that ``a == b`` implies ``hash(a) == hash(b)``. +Recall: the contract for any ``__hash__`` is that ``a == b`` implies ``hash(a) == hash(b)``. -make_node -========= +:meth:`Op.make_node` +==================== -The ``make_node`` method is expected to have the following signature: +The :meth:`Op.make_node` method is expected to have the following signature: .. code-block:: python make_node(self, *inputs) -``inputs`` may be a list of anything that the user wants to provide as symbolic input (symbolic: standing for the actual values that will be passed when the graph is compiled into an executable function). [*The Aesara intro should describe symbolic in greater depth, and we should link to that from here.*] This may or may not include Variable instances (but if you want the inputs of this Op to sometimes be outputs of another Op, then the inputs should be Variable instances). [*What else could they be? Constant, Values, ...*] The return value should be an instance of [GraphStructures Apply] (see the example below). Here are the tasks typically handled in ``make_node``. +``inputs`` may be a list of anything that the user wants to provide as symbolic +input (symbolic: standing for the actual values that will be passed when the +graph is compiled into an executable function). [*The Aesara intro should +describe symbolic in greater depth, and we should link to that from here.*] This +may or may not include Variable instances (but if you want the inputs of this Op +to sometimes be outputs of another Op, then the inputs should be Variable +instances). [*What else could they be? Constant, Values, ...*] The return value +should be an instance of [GraphStructures Apply] (see the example below). Here +are the tasks typically handled in ``make_node``. * Check that the inputs are valid (type checking, etc.). [*Since we don't actually have values, what can we do besides type checking?*] * If needed, wrap the inputs in Variable instances with the proper type. * Make the Variable instances that will serve as the outputs of the node. * ``return Apply(self, , )`` -The ``inputs`` and ``outputs`` arguments to ``Apply`` must be lists of ``Variable`` instances (or instances of subclasses of ``Variable``). The inputs given to ``Apply`` do not have to be the same as the inputs passed to ``make_node``, but it is recommended that the order corresponds. [*why?*] The behavior of ``make_node`` should not depend on the structure of the graph of [*or?*] its inputs: it may look at the type and type fields of its inputs, but not at their owner field, because modifications to the graph structure do not use ``make_node``. [*???*] +The ``inputs`` and ``outputs`` arguments to ``Apply`` must be lists of +`Variable` instances (or instances of subclasses of ``Variable``). The inputs +given to `Apply` do not have to be the same as the inputs passed to +`make_node`, but it is recommended that the order corresponds. [*why?*] The +behavior of `make_node` should not depend on the structure of the graph of +[*or?*] its inputs: it may look at the type and type fields of its inputs, but +not at their owner field, because modifications to the graph structure do not +use `make_node`. Example: @@ -167,7 +187,7 @@ Advanced note: for an Op with multiple outputs, it is possible that some of them grad ==== -``grad`` is an Aesara-specific [*as opposed to?*] function - it does not interface with core optimization and compilation facilities, but it provides a useful interface to differentiation. Its expected signature is: +``grad`` is an Aesara-specific [*as opposed to?*] function - it does not interface with core rewrite and compilation facilities, but it provides a useful interface to differentiation. Its expected signature is: .. code-block:: python diff --git a/doc/sandbox/performance.rst b/doc/sandbox/performance.rst index 1d44126ebc..ffdde87f3e 100644 --- a/doc/sandbox/performance.rst +++ b/doc/sandbox/performance.rst @@ -15,7 +15,7 @@ speed improvements over basic numpy by using aesara. With a little work, Aesara could also implement more sophisticated -optimizations: +rewrites: * automatic ordering of matrix multiplications * profile-based memory layout decisions (e.g. row-major vs. col-major) diff --git a/doc/troubleshooting.rst b/doc/troubleshooting.rst index 38674bf18c..a15f0a75c5 100644 --- a/doc/troubleshooting.rst +++ b/doc/troubleshooting.rst @@ -130,12 +130,12 @@ Could lower the memory usage, but raise computation time: - Use :func:`batch_normalization() `. It use less memory then building a corresponding Aesara graph. -- Disable one or scan more optimizations: +- Disable one or scan more rewrites: - ``optimizer_excluding=scan_pushout_seqs_ops`` - ``optimizer_excluding=scan_pushout_dot1`` - ``optimizer_excluding=scan_pushout_add`` -- Disable all optimization tagged as raising memory usage: - ``optimizer_excluding=more_mem`` (currently only the 3 scan optimizations above) +- Disable all rewrites tagged as raising memory usage: + ``optimizer_excluding=more_mem`` (currently only the 3 scan rewrites above) - `float16 `_. If you want to analyze the memory usage during computation, the diff --git a/doc/tutorial/aliasing.rst b/doc/tutorial/aliasing.rst index 22f6ab6102..d2bed06039 100644 --- a/doc/tutorial/aliasing.rst +++ b/doc/tutorial/aliasing.rst @@ -21,24 +21,24 @@ There are some simple principles that guide Aesara's handling of memory. The main idea is that there is a pool of memory managed by Aesara, and Aesara tracks changes to values in that pool. -* Aesara manages its own memory space, which typically does not overlap with +- Aesara manages its own memory space, which typically does not overlap with the memory of normal Python variables that non-Aesara code creates. -* Aesara functions only modify buffers that are in Aesara's memory space. +- Aesara functions only modify buffers that are in Aesara's memory space. -* Aesara's memory space includes the buffers allocated to store ``shared`` +- Aesara's memory space includes the buffers allocated to store ``shared`` variables and the temporaries used to evaluate functions. -* Physically, Aesara's memory space may be spread across the host, a GPU +- Physically, Aesara's memory space may be spread across the host, a GPU device(s), and in the future may even include objects on a remote machine. -* The memory allocated for a ``shared`` variable buffer is unique: it is never +- The memory allocated for a ``shared`` variable buffer is unique: it is never aliased to another ``shared`` variable. -* Aesara's managed memory is constant while Aesara functions are not running +- Aesara's managed memory is constant while Aesara functions are not running and Aesara's library code is not running. -* The default behaviour of a function is to return user-space values for +- The default behaviour of a function is to return user-space values for outputs, and to expect user-space values for inputs. The distinction between Aesara-managed memory and user-managed memory can be @@ -64,9 +64,9 @@ A ``borrow`` argument can be provided to the shared-variable constructor. s_false = aesara.shared(np_array, borrow=False) s_true = aesara.shared(np_array, borrow=True) -By default (*s_default*) and when explicitly setting ``borrow=False``, the -shared variable we construct gets a [deep] copy of *np_array*. So changes we -subsequently make to *np_array* have no effect on our shared variable. +By default (``s_default``) and when explicitly setting ``borrow=False``, the +shared variable we construct gets a (deep) copy of ``np_array``. So changes we +subsequently make to ``np_array`` have no effect on our shared variable. .. testcode:: borrow @@ -84,22 +84,22 @@ subsequently make to *np_array* have no effect on our shared variable. If we are running this with the CPU as the device, -then changes we make to *np_array* *right away* will show up in +then changes we make to ``np_array`` right away will show up in ``s_true.get_value()`` -because NumPy arrays are mutable, and *s_true* is using the *np_array* +because NumPy arrays are mutable, and ``s_true`` is using the ``np_array`` object as it's internal buffer. -However, this aliasing of *np_array* and *s_true* is not guaranteed to occur, +However, this aliasing of ``np_array`` and ``s_true`` is not guaranteed to occur, and may occur only temporarily even if it occurs at all. It is not guaranteed to occur because if Aesara is using a GPU device, then the ``borrow`` flag has no effect. It may occur only temporarily because -if we call an Aesara function that updates the value of *s_true* the aliasing -relationship *may* or *may not* be broken (the function is allowed to +if we call an Aesara function that updates the value of ``s_true`` the aliasing +relationship may or may not be broken (the function is allowed to update the ``shared`` variable by modifying its buffer, which will preserve the aliasing, or by changing which buffer the variable points to, which will terminate the aliasing). -*Take home message:* +Take home message: It is a safe practice (and a good idea) to use ``borrow=True`` in a ``shared`` variable constructor when the ``shared`` variable stands for a large object (in @@ -131,7 +131,7 @@ retrieved. When ``borrow=False`` is passed to ``get_value``, it means that the return value may not be aliased to any part of Aesara's internal memory. When ``borrow=True`` is passed to ``get_value``, it means that the return value -*might* be aliased to some of Aesara's internal memory. +might be aliased to some of Aesara's internal memory. But both of these calls might create copies of the internal memory. The reason that ``borrow=True`` might still make a copy is that the internal @@ -140,7 +140,7 @@ create a ``shared`` variable by passing a NumPy array for example, then ``get_va must return a NumPy array too. That's how Aesara can make the GPU use transparent. But when you are using a GPU (or in the future perhaps a remote machine), then the numpy.ndarray is not the internal representation of your data. -If you really want Aesara to return its internal representation *and never copy it* +If you really want Aesara to return its internal representation and never copy it then you should use the ``return_internal_type=True`` argument to ``get_value``. It will never cast the internal object (always return in constant time), but might return various datatypes depending on contextual @@ -154,17 +154,17 @@ It is possible to use ``borrow=False`` in conjunction with ``return_internal_type=True``, which will return a deep copy of the internal object. This is primarily for internal debugging, not for typical use. -For the transparent use of different type of optimization Aesara can make, -there is the policy that ``get_value()`` always return by default the same object type -it received when the ``shared`` variable was created. So if you created manually data on -the gpu and create a ``shared`` variable on the gpu with this data, ``get_value`` will always -return gpu data even when ``return_internal_type=False``. +For the transparent use rewrites, there is the policy that ``get_value()`` +always return by default the same object type it received when the ``shared`` +variable was created. So if you created manually data on the gpu and create a +``shared`` variable on the gpu with this data, ``get_value`` will always return +gpu data even when ``return_internal_type=False``. -*Take home message:* +Take home message: It is safe (and sometimes much faster) to use ``get_value(borrow=True)`` when -your code does not modify the return value. *Do not use this to modify a ``shared`` -variable by side-effect* because it will make your code device-dependent. +your code does not modify the return value. Do not use this to modify a ``shared`` +variable by side-effect because it will make your code device-dependent. Modification of GPU variables through this sort of side-effect is impossible. Assigning @@ -173,7 +173,7 @@ Assigning ``Shared`` variables also have a ``set_value`` method that can accept an optional ``borrow=True`` argument. The semantics are similar to those of creating a new ``shared`` variable - ``borrow=False`` is the default and ``borrow=True`` means -that Aesara *may* reuse the buffer you provide as the internal storage for the variable. +that Aesara may reuse the buffer you provide as the internal storage for the variable. A standard pattern for manually updating the value of a ``shared`` variable is as follows: @@ -250,26 +250,25 @@ output buffer every time you call the function. It will possibly reuse the same on a previous call, and overwrite the old content. Consequently, it may overwrite old return values through side-effect. Those return values may also be overwritten in -the course of evaluating *another compiled function* (for example, the output +the course of evaluating another compiled function (for example, the output may be aliased to a ``shared`` variable). So be careful to use a borrowed return value right away before calling any more Aesara functions. -The default is of course to *not borrow* internal results. +The default is of course to not borrow internal results. It is also possible to pass a ``return_internal_type=True`` flag to the ``Out`` variable which has the same interpretation as the ``return_internal_type`` flag to the ``shared`` variable's ``get_value`` function. Unlike ``get_value()``, the combination of ``return_internal_type=True`` and ``borrow=True`` arguments to ``Out()`` are not guaranteed to avoid copying an output value. They are just -hints that give more flexibility to the compilation and optimization of the +hints that give more flexibility to the compilation and rewriting of the graph. -*Take home message:* +Take home message: -When an input *x* to a function is not needed after the function +When an input ``x`` to a function is not needed after the function returns and you would like to make it available to Aesara as -additional workspace, then consider marking it with ``In(x, -borrow=True)``. It may make the function faster and reduce its memory -requirement. When a return value *y* is large (in terms of memory -footprint), and you only need to read from it once, right away when -it's returned, then consider marking it with an ``Out(y, -borrow=True)``. +additional workspace, then consider marking it with ``In(x, borrow=True)``. It +may make the function faster and reduce its memory requirement. When a return +value ``y`` is large (in terms of memory footprint), and you only need to read +from it once, right away when it's returned, then consider marking it with an +``Out(y, borrow=True)``. diff --git a/doc/tutorial/conditions.rst b/doc/tutorial/conditions.rst index 26020062e0..7d078dd67c 100644 --- a/doc/tutorial/conditions.rst +++ b/doc/tutorial/conditions.rst @@ -72,7 +72,7 @@ Unless ``linker='vm'`` or ``linker='cvm'`` are used, ``ifelse`` will compute bot variables and take the same computation time as ``switch``. Although the linker is not currently set by default to ``cvm``, it will be in the near future. -There is no automatic optimization replacing a ``switch`` with a +There is no automatic rewrite replacing a ``switch`` with a broadcasted scalar to an ``ifelse``, as this is not always faster. See this `ticket `_. diff --git a/doc/tutorial/debug_faq.rst b/doc/tutorial/debug_faq.rst index befb048ea7..7e4d6a12e9 100644 --- a/doc/tutorial/debug_faq.rst +++ b/doc/tutorial/debug_faq.rst @@ -15,7 +15,7 @@ Isolating the Problem/Testing Aesara Compiler --------------------------------------------- You can run your Aesara function in a :ref:`DebugMode`. -This tests the Aesara optimizations and helps to find where NaN, inf and other problems come from. +This tests the Aesara rewrites and helps to find where NaN, inf and other problems come from. Interpreting Error Messages --------------------------- @@ -50,7 +50,7 @@ Running the code above we see: Inputs strides: [(8,), (8,), (8,)] Inputs scalar values: ['not scalar', 'not scalar', 'not scalar'] - HINT: Re-running with most Aesara optimization disabled could give you a back-traces when this node was created. This can be done with by setting the Aesara flags 'optimizer=fast_compile'. If that does not work, Aesara optimization can be disabled with 'optimizer=None'. + HINT: Re-running with most Aesara optimizations disabled could give you a back-traces when this node was created. This can be done with by setting the Aesara flags 'optimizer=fast_compile'. If that does not work, Aesara optimizations can be disabled with 'optimizer=None'. HINT: Use the Aesara flag 'exception_verbosity=high' for a debugprint of this apply node. Arguably the most useful information is approximately half-way through @@ -250,7 +250,7 @@ Running the code above returns the following output: x -"How do I Print an Intermediate Value in a Function?" +"How do I print an intermediate value in a function?" ----------------------------------------------------- Aesara provides a :class:`Print`\ :class:`Op` to do this. @@ -278,62 +278,60 @@ Aesara provides a :class:`Print`\ :class:`Op` to do this. this is a very important value __str__ = [ 1. 2. 3.] Since Aesara runs your program in a topological order, you won't have precise -control over the order in which multiple ``Print()`` ops are evaluated. For a more +control over the order in which multiple :class:`Print`\ `Op`\s are evaluated. For a more precise inspection of what's being computed where, when, and how, see the discussion :ref:`faq_monitormode`. .. warning:: - Using this ``Print`` Aesara Op can prevent some Aesara - optimization from being applied. This can also happen with - stability optimization. So if you use this Print and have nan, try - to remove them to know if this is the cause or not. + Using this :class:`Print`\ `Op` can prevent some Aesara rewrites from being + applied. So, if you use `Print` and the graph now returns NaNs for example, + try removing the `Print`\s to see if they're the cause or not. -"How do I Print a Graph?" (before or after compilation) +"How do I print a graph (before or after compilation)?" ------------------------------------------------------- .. TODO: dead links in the next paragraph -Aesara provides two functions (:func:`aesara.pp` and -:func:`aesara.printing.debugprint`) to print a graph to the terminal before or after +Aesara provides two functions, :func:`aesara.pp` and +:func:`aesara.printing.debugprint`, to print a graph to the terminal before or after compilation. These two functions print expression graphs in different ways: -:func:`pp` is more compact and math-like, :func:`debugprint` is more verbose. -Aesara also provides :func:`aesara.printing.pydotprint` that creates a png image of the function. +:func:`pp` is more compact and somewhat math-like, and :func:`debugprint` is more verbose and true to +the underlying graph objects being printed. +Aesara also provides :func:`aesara.printing.pydotprint` that creates a PNG image of the graph. You can read about them in :ref:`libdoc_printing`. - - -"The Function I Compiled is Too Slow, what's up?" +"The function I compiled is too slow; what's up?" ------------------------------------------------- First, make sure you're running in ``FAST_RUN`` mode. Even though ``FAST_RUN`` is the default mode, insist by passing ``mode='FAST_RUN'`` -to ``aesara.function`` (or ``aesara.make``) or by setting :attr:`config.mode` +to `aesara.function` or by setting :attr:`config.mode` to ``FAST_RUN``. Second, try the Aesara :ref:`profiling `. This will tell you which -``Apply`` nodes, and which ops are eating up your CPU cycles. +:class:`Apply` nodes, and which :class:`Op`\s are eating up your CPU cycles. Tips: -* Use the flags ``floatX=float32`` to require type *float32* instead of *float64*; - Use the Aesara constructors matrix(),vector(),... instead of dmatrix(), dvector(),... - since they respectively involve the default types *float32* and *float64*. -* Check in the ``profile`` mode that there is no ``Dot`` op in the post-compilation - graph while you are multiplying two matrices of the same type. ``Dot`` should be +* Use the flags ``floatX=float32`` to require type float32 instead of float64. + Use the Aesara constructors `matrix`, `vector`, etc., instead of `dmatrix`, `dvector`, etc., + since the latter use the default detected precision and the former use only float64. +* Check in the ``profile`` mode that there is no `Dot`\ `Op` in the post-compilation + graph while you are multiplying two matrices of the same type. `Dot` should be optimized to ``dot22`` when the inputs are matrices and of the same type. This can still happen when using ``floatX=float32`` when one of the inputs of the graph is - of type *float64*. + of type float64. .. _faq_monitormode: -"How do I Step through a Compiled Function?" +"How do I step through a compiled function?" -------------------------------------------- -You can use ``MonitorMode`` to inspect the inputs and outputs of each +You can use `MonitorMode` to inspect the inputs and outputs of each node being executed when the function is called. The code snipped below shows how to print all inputs and outputs: @@ -360,9 +358,9 @@ shows how to print all inputs and outputs: 0 Elemwise{mul,no_inplace}(TensorConstant{5.0}, x) input(s) value(s): [array(5.0), array(3.0)] output(s) value(s): [array(15.0)] When using these ``inspect_inputs`` and ``inspect_outputs`` functions -with ``MonitorMode``, you should see (potentially a lot of) printed output. -Every ``Apply`` node will be printed out, along with its position in the graph, -the arguments to the functions ``perform`` or ``c_code`` and the output it +with `MonitorMode`, you should see (potentially a lot of) printed output. +Every `Apply` node will be printed out, along with its position in the graph, +the arguments to the functions `Op.perform` or `COp.c_code` and the output it computed. Admittedly, this may be a huge amount of output to read through if you are using large tensors, but you can choose to add logic that would, for instance, print @@ -411,14 +409,13 @@ computations, which can be achieved as follows: Outputs: [array(nan)] To help understand what is happening in your graph, you can -disable the ``local_elemwise_fusion`` and all ``inplace`` -optimizations. The first is a speed optimization that merges elemwise +disable the `local_elemwise_fusion` and all in-place +rewrites. The first is a speed optimization that merges elemwise operations together. This makes it harder to know which particular -elemwise causes the problem. The second optimization makes some ops' -outputs overwrite their inputs. So, if an op creates a bad output, you +elemwise causes the problem. The second makes some `Op`\s' +outputs overwrite their inputs. So, if an `Op` creates a bad output, you will not be able to see the input that was overwritten in the ``post_func`` -function. To disable those optimizations (with an Aesara version after -0.6rc3), define the MonitorMode like this: +function. To disable those rewrites, define the `MonitorMode` like this: .. testcode:: compiled @@ -430,9 +427,9 @@ function. To disable those optimizations (with an Aesara version after .. note:: The Aesara flags ``optimizer_including``, ``optimizer_excluding`` - and ``optimizer_requiring`` aren't used by the MonitorMode, they + and ``optimizer_requiring`` aren't used by the `MonitorMode`, they are used only by the ``default`` mode. You can't use the ``default`` - mode with MonitorMode, as you need to define what you monitor. + mode with `MonitorMode`, as you need to define what you monitor. To be sure all inputs of the node are available during the call to ``post_func``, you must also disable the garbage collector. Otherwise, @@ -448,8 +445,8 @@ flag: .. TODO: documentation for link.WrapLinkerMany -How to Use pdb --------------- +How to Use ``pdb`` +------------------ In the majority of cases, you won't be executing from the interactive shell but from a set of Python scripts. In such cases, the use of the Python @@ -515,8 +512,8 @@ The call stack contains some useful information to trace back the source of the error. There's the script where the compiled function was called -- but if you're using (improperly parameterized) prebuilt modules, the error might originate from `Op`\s in these modules, not this script. The last line -tells us about the `Op` that caused the exception. In this case it's a "mul" -involving variables with names "a" and "b". But suppose we instead had an +tells us about the `Op` that caused the exception. In this case it's a ``mul`` +involving variables with names ``a`` and ``b``. But suppose we instead had an intermediate result to which we hadn't given a name. After learning a few things about the graph structure in Aesara, we can use @@ -525,7 +522,7 @@ information about the error. Matrix dimensions, especially, are useful to pinpoint the source of the error. In the printout, there are also two of the four dimensions of the matrices involved, but for the sake of example say we'd need the other dimensions to pinpoint the error. First, we re-launch with the -debugger module and run the program with "c": +debugger module and run the program with ``c``: .. code-block:: text diff --git a/doc/tutorial/faq_tutorial.rst b/doc/tutorial/faq_tutorial.rst index a686c68561..1a8b7ba46c 100644 --- a/doc/tutorial/faq_tutorial.rst +++ b/doc/tutorial/faq_tutorial.rst @@ -44,7 +44,7 @@ Defining cost which depends only on subset and not the entire lookup_table There are two ways for updating the parameters: Either use inc_subtensor or set_subtensor. It is recommended to use -inc_subtensor. Some aesara optimizations do the conversion between +inc_subtensor. Some aesara rewrites do the conversion between the two functions, but not in all cases. .. code-block:: python diff --git a/doc/tutorial/gradients.rst b/doc/tutorial/gradients.rst index d13d5714af..ecaac80812 100644 --- a/doc/tutorial/gradients.rst +++ b/doc/tutorial/gradients.rst @@ -10,13 +10,10 @@ Computing Gradients =================== Now let's use Aesara for a slightly more sophisticated task: create a -function which computes the derivative of some expression *y* with -respect to its parameter *x*. To do this we will use the macro ``at.grad``. -For instance, we can compute the -gradient of :math:`x^2` with respect to :math:`x`. Note that: -:math:`d(x^2)/dx = 2 \cdot x`. - -.. TODO: fix the vertical positioning of the expressions in the preceding paragraph +function which computes the derivative of some expression ``y`` with +respect to its parameter ``x``. To do this we will use the macro `at.grad`. +For instance, we can compute the gradient of :math:`x^2` with respect to +:math:`x`. Note that: :math:`d(x^2)/dx = 2 \cdot x`. Here is the code to compute this gradient: @@ -40,11 +37,11 @@ True In this example, we can see from ``pp(gy)`` that we are computing the correct symbolic gradient. -``fill((x ** 2), 1.0)`` means to make a matrix of the same shape as -*x* ** *2* and fill it with *1.0*. +``fill((x**2), 1.0)`` means to make a matrix of the same shape as +``x**2`` and fill it with ``1.0``. .. note:: - The optimizer simplifies the symbolic gradient expression. You can see + Aesara's rewrites simplify the symbolic gradient expression. You can see this by digging inside the internal properties of the compiled function. .. testcode:: @@ -52,8 +49,7 @@ the correct symbolic gradient. pp(f.maker.fgraph.outputs[0]) '(2.0 * x)' - After optimization there is only one Apply node left in the graph, which - doubles the input. + After rewriting, there is only one `Apply` node left in the graph. We can also compute the gradient of complex expressions such as the logistic function defined above. It turns out that the derivative of the @@ -61,8 +57,8 @@ logistic is: :math:`ds(x)/dx = s(x) \cdot (1 - s(x))`. .. figure:: dlogistic.png - A plot of the gradient of the logistic function, with *x* on the x-axis - and :math:`ds(x)/dx` on the y-axis. + A plot of the gradient of the logistic function, with :math:`x` on the x-axis + and :math:`ds(x)/dx` on the :math:`y`-axis. .. If you modify this code, also change : @@ -76,22 +72,22 @@ logistic is: :math:`ds(x)/dx = s(x) \cdot (1 - s(x))`. array([[ 0.25 , 0.19661193], [ 0.19661193, 0.10499359]]) -In general, for any **scalar** expression *s*, ``at.grad(s, w)`` provides +In general, for any **scalar** expression ``s``, ``at.grad(s, w)`` provides the Aesara expression for computing :math:`\frac{\partial s}{\partial w}`. In this way Aesara can be used for doing **efficient** symbolic differentiation -(as the expression returned by ``at.grad`` will be optimized during compilation), even for +(as the expression returned by `at.grad` will be optimized during compilation), even for function with many inputs. (see `automatic differentiation `_ for a description of symbolic differentiation). .. note:: - The second argument of ``at.grad`` can be a list, in which case the + The second argument of `at.grad` can be a list, in which case the output is also a list. The order in both lists is important: element - *i* of the output list is the gradient of the first argument of - ``at.grad`` with respect to the *i*-th element of the list given as second argument. - The first argument of ``at.grad`` has to be a scalar (a tensor + ``i`` of the output list is the gradient of the first argument of + `at.grad` with respect to the ``i``-th element of the list given as second argument. + The first argument of `at.grad` has to be a scalar (a tensor of size 1). For more information on the semantics of the arguments of - ``at.grad`` and details about the implementation, see + `at.grad` and details about the implementation, see :ref:`this` section of the library. Additional information on the inner workings of differentiation may also be @@ -100,24 +96,24 @@ of symbolic differentiation). Computing the Jacobian ====================== -In Aesara's parlance, the term *Jacobian* designates the tensor comprising the +In Aesara's parlance, the term **Jacobian** designates the tensor comprising the first partial derivatives of the output of a function with respect to its inputs. (This is a generalization of to the so-called Jacobian matrix in Mathematics.) Aesara implements the :func:`aesara.gradient.jacobian` macro that does all that is needed to compute the Jacobian. The following text explains how to do it manually. -In order to manually compute the Jacobian of some function *y* with -respect to some parameter *x* we need to use ``scan``. What we -do is to loop over the entries in *y* and compute the gradient of -*y[i]* with respect to *x*. +In order to manually compute the Jacobian of some function ``y`` with +respect to some parameter ``x`` we need to use `scan`. What we +do is to loop over the entries in ``y`` and compute the gradient of +``y[i]`` with respect to ``x``. .. note:: - ``scan`` is a generic op in Aesara that allows writing in a symbolic + `scan` is a generic op in Aesara that allows writing in a symbolic manner all kinds of recurrent equations. While creating symbolic loops (and optimizing them for performance) is a hard task, - effort is being done for improving the performance of ``scan``. We + effort is being done for improving the performance of `scan`. We shall return to :ref:`scan` later in this tutorial. >>> import aesara @@ -130,25 +126,25 @@ do is to loop over the entries in *y* and compute the gradient of array([[ 8., 0.], [ 0., 8.]]) -What we do in this code is to generate a sequence of *ints* from *0* to -``y.shape[0]`` using ``at.arange``. Then we loop through this sequence, and -at each step, we compute the gradient of element *y[i]* with respect to -*x*. ``scan`` automatically concatenates all these rows, generating a +What we do in this code is to generate a sequence of integers from ``0`` to +``y.shape[0]`` using `at.arange`. Then we loop through this sequence, and +at each step, we compute the gradient of element ``y[i]`` with respect to +``x``. `scan` automatically concatenates all these rows, generating a matrix which corresponds to the Jacobian. .. note:: - There are some pitfalls to be aware of regarding ``at.grad``. One of them is that you + There are some pitfalls to be aware of regarding `at.grad`. One of them is that you cannot re-write the above expression of the Jacobian as - ``aesara.scan(lambda y_i,x: at.grad(y_i,x), sequences=y, - non_sequences=x)``, even though from the documentation of scan this - seems possible. The reason is that *y_i* will not be a function of - *x* anymore, while *y[i]* still is. + ``aesara.scan(lambda y_i,x: at.grad(y_i,x), sequences=y, non_sequences=x)``, + even though from the documentation of scan this + seems possible. The reason is that ``y_i`` will not be a function of + ``x`` anymore, while ``y[i]`` still is. Computing the Hessian ===================== -In Aesara, the term *Hessian* has the usual mathematical meaning: It is the +In Aesara, the term **Hessian** has the usual mathematical meaning: It is the matrix comprising the second order partial derivative of a function with scalar output and vector input. Aesara implements :func:`aesara.gradient.hessian` macro that does all that is needed to compute the Hessian. The following text explains how @@ -156,7 +152,7 @@ to do it manually. You can compute the Hessian manually similarly to the Jacobian. The only difference is that now, instead of computing the Jacobian of some expression -*y*, we compute the Jacobian of ``at.grad(cost,x)``, where *cost* is some +``y``, we compute the Jacobian of ``at.grad(cost,x)``, where ``cost`` is some scalar. >>> x = at.dvector('x') @@ -179,8 +175,7 @@ doing the product, there are methods that compute the desired results while avoiding actual evaluation of the Jacobian. This can bring about significant performance gains. A description of one such algorithm can be found here: -* Barak A. Pearlmutter, "Fast Exact Multiplication by the Hessian", *Neural - Computation, 1994* +- Barak A. Pearlmutter, "Fast Exact Multiplication by the Hessian", Neural Computation, 1994 While in principle we would want Aesara to identify these patterns automatically for us, in practice, implementing such optimizations in a generic manner is extremely @@ -190,14 +185,14 @@ difficult. Therefore, we provide special functions dedicated to these tasks. R-operator ---------- -The *R operator* is built to evaluate the product between a Jacobian and a +The **R operator** is built to evaluate the product between a Jacobian and a vector, namely :math:`\frac{\partial f(x)}{\partial x} v`. The formulation -can be extended even for *x* being a matrix, or a tensor in general, case in +can be extended even for :math:`x` being a matrix, or a tensor in general, case in which also the Jacobian becomes a tensor and the product becomes some kind of tensor product. Because in practice we end up needing to compute such expressions in terms of weight matrices, Aesara supports this more generic -form of the operation. In order to evaluate the *R-operation* of -expression *y*, with respect to *x*, multiplying the Jacobian with *v* +form of the operation. In order to evaluate the R-operation of +expression ``y``, with respect to ``x``, multiplying the Jacobian with ``V`` you need to do something similar to this: >>> W = at.dmatrix('W') @@ -214,9 +209,9 @@ array([ 2., 2.]) L-operator ---------- -In similitude to the *R-operator*, the *L-operator* would compute a *row* vector times +In similitude to the R-operator, the **L-operator** would compute a row vector times the Jacobian. The mathematical formula would be :math:`v \frac{\partial -f(x)}{\partial x}`. The *L-operator* is also supported for generic tensors +f(x)}{\partial x}`. The L-operator is also supported for generic tensors (not only for vectors). Similarly, it can be implemented as follows: >>> W = at.dmatrix('W') @@ -231,12 +226,12 @@ array([[ 0., 0.], .. note:: - `v`, the *point of evaluation*, differs between the *L-operator* and the *R-operator*. - For the *L-operator*, the point of evaluation needs to have the same shape - as the output, whereas for the *R-operator* this point should + ``v``, the point of evaluation, differs between the L-operator and the R-operator. + For the L-operator, the point of evaluation needs to have the same shape + as the output, whereas for the R-operator this point should have the same shape as the input parameter. Furthermore, the results of these two - operations differ. The result of the *L-operator* is of the same shape - as the input parameter, while the result of the *R-operator* has a shape similar + operations differ. The result of the L-operator is of the same shape + as the input parameter, while the result of the R-operator has a shape similar to that of the output. :ref:`List of op with r op support `. @@ -244,7 +239,7 @@ array([[ 0., 0.], Hessian times a Vector ====================== -If you need to compute the *Hessian times a vector*, you can make use of the +If you need to compute the Hessian times a vector, you can make use of the above-defined operators to do it more efficiently than actually computing the exact Hessian and then performing the product. Due to the symmetry of the Hessian matrix, you have two options that will @@ -261,7 +256,7 @@ Hence, we suggest profiling the methods before using either one of the two: array([ 4., 4.]) -or, making use of the *R-operator*: +or, making use of the R-operator: >>> x = at.dvector('x') >>> v = at.dvector('v') @@ -277,13 +272,13 @@ Final Pointers ============== -* The ``grad`` function works symbolically: it receives and returns Aesara variables. +- The `grad` function works symbolically: it receives and returns Aesara variables. -* ``grad`` can be compared to a macro since it can be applied repeatedly. +- `grad` can be compared to a macro since it can be applied repeatedly. -* Scalar costs only can be directly handled by ``grad``. Arrays are handled through repeated applications. +- Scalar costs only can be directly handled by `grad`. Arrays are handled through repeated applications. -* Built-in functions allow to compute efficiently *vector times Jacobian* and *vector times Hessian*. +- Built-in functions allow to compute efficiently vector times Jacobian and vector times Hessian. -* Work is in progress on the optimizations required to compute efficiently the full - Jacobian and the Hessian matrix as well as the *Jacobian times vector*. +- Work is in progress on the optimizations required to compute efficiently the full + Jacobian and the Hessian matrix as well as the Jacobian times vector. diff --git a/doc/tutorial/modes.rst b/doc/tutorial/modes.rst index ca3929bf4f..795e23ba28 100644 --- a/doc/tutorial/modes.rst +++ b/doc/tutorial/modes.rst @@ -129,12 +129,12 @@ as it will be useful later on. ------------------------------------------- -Mode -==== +Default Modes +============= Every time :func:`aesara.function ` is called, the symbolic relationships between the input and output Aesara *variables* -are optimized and compiled. The way this compilation occurs +are rewritten and compiled. The way this compilation occurs is controlled by the value of the ``mode`` parameter. Aesara defines the following modes by name: @@ -166,10 +166,10 @@ short name Full constructor see :ref:`the debugging FAQ` for details. -Linkers -======= +Default Linkers +=============== -A mode is composed of 2 things: an optimizer and a linker. Some modes, +A :class:`Mode` object is composed of two things: an optimizer and a linker. Some modes, like `NanGuardMode` and `DebugMode`, add logic around the optimizer and linker. `DebugMode` uses its own linker. @@ -201,12 +201,14 @@ For more detail, see :ref:`Mode` in the library. .. _optimizers: -Optimizers -========== +Default Optimizers +================== -Aesara allows compilations with a number of predefined optimizers. -An optimizer consists of a particular set of optimizations, that speed -up execution of Aesara programs. +Aesara allows compilations with a number of predefined rewrites that are +expected to improve graph evaluation performance on average. +An optimizer is technically just a :class:`Rewriter`, or an object that +indicates a particular set of rewrites (e.g. a string used to query `optdb` for +a :class:`Rewriter`). The optimizers Aesara provides are summarized below to indicate the trade-offs one might make between compilation time and execution time. @@ -217,35 +219,33 @@ or per call to aesara functions with ``function(...mode=Mode(optimizer="name"))` ================= ============ ============== ================================================== optimizer Compile time Execution time Description ================= ============ ============== ================================================== -None "++++++" "+" Applies none of Aesara's opts -o1 (fast_compile) "+++++" "++" Applies only basic opts -o2 "++++" "+++" Applies few basic opts and some that compile fast -o3 "+++" "++++" Applies all opts except ones that compile slower -o4 (fast_run) "++" "+++++" Applies all opts -unsafe "+" "++++++" Applies all opts, and removes safety checks -stabilize "+++++" "++" Only applies stability opts +None "++++++" "+" Applies none of Aesara's rewrites +o1 (fast_compile) "+++++" "++" Applies only basic rewrites +o2 "++++" "+++" Applies few basic rewrites and some that compile fast +o3 "+++" "++++" Applies all rewrites except ones that compile slower +o4 (fast_run) "++" "+++++" Applies all rewrites +unsafe "+" "++++++" Applies all rewrites, and removes safety checks +stabilize "+++++" "++" Only applies stability rewrites ================= ============ ============== ================================================== -For a detailed list of the specific optimizations applied for each of these -optimizers, see :ref:`optimizations`. Also, see :ref:`unsafe_optimization` and +For a detailed list of the specific rewrites applied for each of these +optimizers, see :ref:`optimizations`. Also, see :ref:`unsafe_rewrites` and :ref:`faster-aesara-function-compilation` for other trade-off. .. _using_debugmode: -Using DebugMode -=============== +Using :class:`DebugMode` +======================== While normally you should use the ``FAST_RUN`` or ``FAST_COMPILE`` mode, -it is useful at first (especially when you are defining new kinds of -expressions or new optimizations) to run your code using the `DebugMode` +it is useful at first--especially when you are defining new kinds of +expressions or new rewrites--to run your code using the `DebugMode` (available via ``mode='DebugMode``). The `DebugMode` is designed to run several self-checks and assertions that can help diagnose possible programming errors leading to incorrect output. Note that -``DebugMode`` is much slower than ``FAST_RUN`` or ``FAST_COMPILE`` so -use it only during development (not when you launch 1000 processes on a -cluster!). - +`DebugMode` is much slower than ``FAST_RUN`` or ``FAST_COMPILE``, so +use it only during development. .. If you modify this code, also change : .. tests/test_tutorial.py:T_modes.test_modes_1 diff --git a/doc/tutorial/nan_tutorial.rst b/doc/tutorial/nan_tutorial.rst index 831527a3eb..e23c8f6300 100644 --- a/doc/tutorial/nan_tutorial.rst +++ b/doc/tutorial/nan_tutorial.rst @@ -85,6 +85,15 @@ this flag while debugging NaN. NaN Introduced by AllocEmpty ----------------------------------------------- -AllocEmpty is used by many operation such as scan to allocate some memory without properly clearing it. The reason for that is that the allocated memory will subsequently be overwritten. However, this can sometimes introduce NaN depending on the operation and what was previously stored in the memory it is working on. For instance, trying to zero out memory using a multiplication before applying an operation could cause NaN if NaN is already present in the memory, since `0 * NaN => NaN`. - -Using ``optimizer_including=alloc_empty_to_zeros`` replaces `AllocEmpty` by `Alloc{0}`, which is helpful to diagnose where NaNs come from. Please note that when running in `NanGuardMode`, this optimizer is not included by default. Therefore, it might be helpful to use them both together. +AllocEmpty is used by many operation such as scan to allocate some memory +without properly clearing it. The reason for that is that the allocated memory +will subsequently be overwritten. However, this can sometimes introduce NaN +depending on the operation and what was previously stored in the memory it is +working on. For instance, trying to zero out memory using a multiplication +before applying an operation could cause NaN if NaN is already present in the +memory, since `0 * NaN => NaN`. + +Using ``optimizer_including=alloc_empty_to_zeros`` replaces `AllocEmpty` by +`Alloc{0}`, which is helpful to diagnose where NaNs come from. Please note that +when running in `NanGuardMode`, this rewrite is not included by +default. Therefore, it might be helpful to use them both together. diff --git a/doc/tutorial/printing_drawing.rst b/doc/tutorial/printing_drawing.rst index 408821cf6d..2323ee0560 100644 --- a/doc/tutorial/printing_drawing.rst +++ b/doc/tutorial/printing_drawing.rst @@ -15,9 +15,8 @@ that creates an image of the function. You can read about them in .. note:: - When printing Aesara functions, they can sometimes be hard to - read. To help with this, you can disable some Aesara optimizations + read. To help with this, you can disable some Aesara rewrites by using the Aesara flag: ``optimizer_excluding=fusion:inplace``. Do not use this during real job execution, as this will make the graph slower and use more diff --git a/doc/tutorial/profiling.rst b/doc/tutorial/profiling.rst index 353df7d722..3e9c3f732c 100644 --- a/doc/tutorial/profiling.rst +++ b/doc/tutorial/profiling.rst @@ -17,7 +17,7 @@ of the following two options: 1. Use the Aesara flag :attr:`config.profile` to enable profiling. - To enable the memory profiler use the Aesara flag: :attr:`config.profile_memory` in addition to :attr:`config.profile`. - - Moreover, to enable the profiling of Aesara optimization phases, + - Moreover, to enable the profiling of Aesara rewrite phases, use the Aesara flag: :attr:`config.profile_optimizer` in addition to :attr:`config.profile`. - You can also use the Aesara flags :attr:`profiling__n_apply`, @@ -55,7 +55,7 @@ calls. The time spent in :meth:`Function.vm.__call__` and in thunks is useful to understand Aesara's overhead. Also, we see the time spent in the two parts of the compilation process: -optimization (i.e. modifying the graph to make it more stable/faster) and the +rewriting (i.e. modifying the graph to make it more stable/faster) and the linking (i.e. compile C code and make the Python callable returned by :func:`aesara.function`). @@ -73,10 +73,10 @@ implementation. Developers wishing to optimize the performance of their graph should focus on the worst offending `Op`\s and `Apply` nodes--either by optimizing an implementation, providing a missing C implementation, or by writing -a graph optimization that eliminates the offending `Op` altogether. +a graph rewrite that eliminates the offending `Op` altogether. -Here is some example output when some Aesara optimizations are disabled. With -all optimizations enabled, there would be only one `Op` left in the graph. +Here is some example output when Aesara's rewrites are disabled. With all +rewrites enabled, there would be only one `Op` remaining in the graph. To run the example: diff --git a/doc/tutorial/profiling_example_out.prof b/doc/tutorial/profiling_example_out.prof index 4bd526577d..2d7c292c38 100644 --- a/doc/tutorial/profiling_example_out.prof +++ b/doc/tutorial/profiling_example_out.prof @@ -5,7 +5,7 @@ Function profiling Time in Function.vm.__call__: 1.192093e-05s (20.921%) Time in thunks: 6.198883e-06s (10.879%) Total compile time: 3.642474e+00s - Aesara Optimizer time: 7.326508e-02s + Aesara rewrite time: 7.326508e-02s Aesara validate time: 3.712177e-04s Aesara Linker time (includes C, CUDA code generation/compiling): 9.584920e-01s diff --git a/doc/tutorial/shape_info.rst b/doc/tutorial/shape_info.rst index 8c189662b4..002a48c2fc 100644 --- a/doc/tutorial/shape_info.rst +++ b/doc/tutorial/shape_info.rst @@ -112,13 +112,11 @@ this example, the computation of the shape of the output of ``join`` is done onl based on the first input Aesara variable, which leads to an error. This might happen with other `Op`\s such as :class:`Elemwise` and :class:`Dot`, for example. -Indeed, to perform some optimizations (for speed or stability, for instance), +Indeed, to perform some optimizations/rewrites (for speed or stability, for instance), Aesara assumes that the computation is correct and consistent in the first place, as it does here. -You can detect those problems by running the code without this -optimization, using the Aesara flag -``optimizer_excluding=local_shape_to_shape_i``. You can also obtain the -same effect by running in the modes ``FAST_COMPILE`` (it will not apply this -optimization, nor most other optimizations) or :class:`DebugMode` (it will test -before and after all optimizations). +You can detect those problems by running the code without this optimization, +using the Aesara flag ``optimizer_excluding=local_shape_to_shape_i``. You can +also obtain the same effect by running in the modes ``FAST_COMPILE`` or +:class:`DebugMode`. diff --git a/readthedocs.yml b/readthedocs.yml index a75dc548a2..4cb32ad57d 100644 --- a/readthedocs.yml +++ b/readthedocs.yml @@ -3,3 +3,7 @@ sphinx: configuration: doc/conf.py conda: environment: doc/environment.yml +build: + os: "ubuntu-20.04" + tools: + python: "mambaforge-4.10" diff --git a/setup.cfg b/setup.cfg index 68d01e2545..e7351fea96 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,3 +1,91 @@ +[metadata] +name = aesara +author = aesara-devs +author_email = aesara.devs@gmail.com +description = Optimizing compiler for evaluating mathematical expressions on CPUs and GPUs. +long_description = file: DESCRIPTION.txt +long_description_content_type = text/x-rst +url = https://github.com/aesara-devs/aesara +license = BSD +platforms = + Windows + Linux + Solaris + Mac OS-X + Unix +classifiers = + Development Status :: 6 - Mature + Intended Audience :: Education + Intended Audience :: Science/Research + Intended Audience :: Developers + License :: OSI Approved :: BSD License + Programming Language :: Python + Topic :: Software Development :: Code Generators + Topic :: Software Development :: Compilers + Topic :: Scientific/Engineering :: Mathematics + Operating System :: Microsoft :: Windows + Operating System :: POSIX + Operating System :: Unix + Operating System :: MacOS + Programming Language :: Python :: 3 + Programming Language :: Python :: 3.7 + Programming Language :: Python :: 3.8 + Programming Language :: Python :: 3.9 +keywords = + aesara + math + numerical + symbolic + blas + numpy + autodiff + differentiation + +[options] +packages = find: +python_requires = >=3.7 +install_requires = + numpy >=1.17.0 + scipy >=0.14 + filelock + etuples + logical-unification + miniKanren + cons + typing_extensions + setuptools >=48.0.0 + +[options.packages.find] +exclude = + tests + tests.* + +[options.entry_points] +console_scripts = + aesara-cache = bin.aesara_cache:main + +[options.package_data] +* = + *.txt + *.rst + *.cu + *.cuh + *.c + *.sh + *.pkl + *.h + *.cpp + ChangeLog + c_code/* +aesara = + py.typed +aesara.misc = + *.sh +aesara.d3viz = + html/* + css/* + js/* + [flake8] select = C,E,F,W ignore = E203,E231,E501,E741,W503,W504,C901 @@ -17,7 +105,7 @@ per-file-ignores = tests/sparse/test_utils.py:E402,F401 tests/sparse/sandbox/test_sp.py:E402,F401 tests/scalar/test_basic_sympy.py:E402 - aesara/graph/unify.py:F811 + aesara/graph/rewriting/unify.py:F811 exclude = versioneer.py doc/ @@ -28,6 +116,11 @@ omit = aesara/_version.py tests/* aesara/assert_op.py + aesara/graph/opt.py + aesara/graph/opt_utils.py + aesara/graph/optdb.py + aesara/graph/kanren.py + aesara/graph/unify.py aesara/link/jax/jax_linker.py aesara/link/jax/jax_dispatch.py aesara/graph/toolbox.py @@ -67,7 +160,8 @@ lines_after_imports = 2 lines_between_sections = 1 honor_noqa = True skip_gitignore = True -skip = aesara/version.py, **/__init__.py +skip = aesara/version.py +skip_glob = **/*.pyx [mypy] ignore_missing_imports = True @@ -152,10 +246,16 @@ check_untyped_defs = False ignore_errors = True check_untyped_defs = False -[mypy-aesara.tensor.basic_opt] +[mypy-aesara.tensor.rewriting.basic] ignore_errors = True check_untyped_defs = False +[mypy-aesara.tensor.rewriting.shape] +warn_unused_ignores = False + +[mypy-aesara.tensor.rewriting.elemwise] +warn_unused_ignores = False + [mypy-aesara.tensor.subtensor] ignore_errors = True check_untyped_defs = False @@ -188,7 +288,7 @@ check_untyped_defs = False ignore_errors = True check_untyped_defs = False -[mypy-aesara.tensor.math_opt] +[mypy-aesara.tensor.rewriting.math] ignore_errors = True check_untyped_defs = False diff --git a/setup.py b/setup.py index ab7821bb5d..51d986c97e 100755 --- a/setup.py +++ b/setup.py @@ -1,65 +1,20 @@ #!/usr/bin/env python import os -from setuptools import find_packages, setup +from setuptools import setup +from setuptools.dist import Distribution import versioneer -def read_file(filename): - with open(filename, "rt") as buff: - return buff.read() +dist = Distribution() +dist.parse_config_files() -NAME = "aesara" -MAINTAINER = "Aesara developers" -MAINTAINER_EMAIL = "aesara.devs@gmail.com" -DESCRIPTION = ( - "Optimizing compiler for evaluating mathematical expressions on CPUs and GPUs." -) -LONG_DESCRIPTION = read_file("DESCRIPTION.txt") -URL = "https://github.com/aesara-devs/aesara" -LICENSE = "BSD" -AUTHOR = "aesara-devs" -AUTHOR_EMAIL = "aesara.devs@gmail.com" -PLATFORMS = ["Windows", "Linux", "Solaris", "Mac OS-X", "Unix"] -CLASSIFIERS = """\ -Development Status :: 6 - Mature -Intended Audience :: Education -Intended Audience :: Science/Research -Intended Audience :: Developers -License :: OSI Approved :: BSD License -Programming Language :: Python -Topic :: Software Development :: Code Generators -Topic :: Software Development :: Compilers -Topic :: Scientific/Engineering :: Mathematics -Operating System :: Microsoft :: Windows -Operating System :: POSIX -Operating System :: Unix -Operating System :: MacOS -Programming Language :: Python :: 3 -Programming Language :: Python :: 3.7 -Programming Language :: Python :: 3.8 -Programming Language :: Python :: 3.9 -""" -CLASSIFIERS = [_f for _f in CLASSIFIERS.split("\n") if _f] - -install_requires = [ - "numpy>=1.17.0", - "scipy>=0.14", - "filelock", - "etuples", - "logical-unification", - "miniKanren", - "cons", - "typing_extensions", - "setuptools>=45.0.0", -] - +NAME: str = dist.get_name() # type: ignore # Handle builds of nightly release if "BUILD_AESARA_NIGHTLY" in os.environ: - nightly = True NAME += "-nightly" from versioneer import get_versions as original_get_versions @@ -80,50 +35,4 @@ def get_versions(): name=NAME, version=versioneer.get_version(), cmdclass=versioneer.get_cmdclass(), - description=DESCRIPTION, - long_description=LONG_DESCRIPTION, - long_description_content_type="text/x-rst", - classifiers=CLASSIFIERS, - author=AUTHOR, - author_email=AUTHOR_EMAIL, - url=URL, - license=LICENSE, - platforms=PLATFORMS, - packages=find_packages(exclude=["tests", "tests.*"]), - install_requires=install_requires, - package_data={ - "": [ - "*.txt", - "*.rst", - "*.cu", - "*.cuh", - "*.c", - "*.sh", - "*.pkl", - "*.h", - "*.cpp", - "ChangeLog", - "c_code/*", - ], - "aesara": ["py.typed"], - "aesara.misc": ["*.sh"], - "aesara.d3viz": ["html/*", "css/*", "js/*"], - }, - entry_points={ - "console_scripts": [ - "aesara-cache = bin.aesara_cache:main", - ] - }, - keywords=" ".join( - [ - "aesara", - "math", - "numerical", - "symbolic", - "blas", - "numpy", - "autodiff", - "differentiation", - ] - ), ) diff --git a/tests/compile/function/test_types.py b/tests/compile/function/test_types.py index 185bcbaa05..e103e92b9b 100644 --- a/tests/compile/function/test_types.py +++ b/tests/compile/function/test_types.py @@ -13,7 +13,7 @@ from aesara.compile.mode import Mode, get_default_mode from aesara.configdefaults import config from aesara.graph.basic import Constant -from aesara.graph.opt import OpKeyOptimizer, PatternSub +from aesara.graph.rewriting.basic import OpKeyGraphRewriter, PatternNodeRewriter from aesara.graph.utils import MissingInputError from aesara.link.vm import VMLinker from aesara.tensor.math import dot @@ -35,7 +35,7 @@ def PatternOptimizer(p1, p2, ign=True): - return OpKeyOptimizer(PatternSub(p1, p2), ignore_newtrees=ign) + return OpKeyGraphRewriter(PatternNodeRewriter(p1, p2), ignore_newtrees=ign) class TestFunction: diff --git a/tests/compile/test_builders.py b/tests/compile/test_builders.py index 2038b93666..b770121134 100644 --- a/tests/compile/test_builders.py +++ b/tests/compile/test_builders.py @@ -12,16 +12,16 @@ from aesara.graph.basic import equal_computations from aesara.graph.fg import FunctionGraph from aesara.graph.null_type import NullType -from aesara.graph.opt_utils import optimize_graph +from aesara.graph.rewriting.utils import rewrite_graph from aesara.graph.utils import MissingInputError from aesara.printing import debugprint from aesara.tensor.basic import as_tensor -from aesara.tensor.basic_opt import ShapeOptimizer from aesara.tensor.math import dot, exp from aesara.tensor.math import round as at_round from aesara.tensor.math import sigmoid from aesara.tensor.math import sum as at_sum from aesara.tensor.random.utils import RandomStream +from aesara.tensor.rewriting.shape import ShapeOptimizer from aesara.tensor.shape import specify_shape from aesara.tensor.type import TensorType, matrices, matrix, scalar, vector, vectors from tests import unittest_tools @@ -455,7 +455,7 @@ def test_infer_shape(self): op_var = op_graph(x, y, z) fg = FunctionGraph(outputs=[op_var[1]], clone=False) - opt_res = optimize_graph(fg, custom_opt=ShapeOptimizer()) + opt_res = rewrite_graph(fg, custom_rewrite=ShapeOptimizer()) assert opt_res.shape_feature.shape_of[x] is None assert opt_res.shape_feature.shape_of[z][0].data == 2 diff --git a/tests/compile/test_debugmode.py b/tests/compile/test_debugmode.py index afd6a51b3f..e724cfd50a 100644 --- a/tests/compile/test_debugmode.py +++ b/tests/compile/test_debugmode.py @@ -18,8 +18,8 @@ from aesara.graph.basic import Apply, Variable from aesara.graph.features import BadOptimization from aesara.graph.op import Op -from aesara.graph.opt import local_optimizer -from aesara.graph.optdb import EquilibriumDB +from aesara.graph.rewriting.basic import node_rewriter +from aesara.graph.rewriting.db import EquilibriumDB from aesara.link.c.op import COp from aesara.tensor.math import add, dot, log from aesara.tensor.type import TensorType, dvector, fmatrix, fvector, scalar, vector @@ -237,7 +237,7 @@ def test_badthunkoutput(): def test_badoptimization(): - @local_optimizer([add]) + @node_rewriter([add]) def insert_broken_add(fgraph, node): if node.op == add: return [off_by_half(*node.inputs)] @@ -263,7 +263,7 @@ def insert_broken_add(fgraph, node): def test_badoptimization_opt_err(): # This variant of test_badoptimization() replace the working code # with a new apply node that will raise an error. - @local_optimizer([add]) + @node_rewriter([add]) def insert_bigger_b_add(fgraph, node): if node.op == add: inputs = list(node.inputs) @@ -272,7 +272,7 @@ def insert_bigger_b_add(fgraph, node): return [node.op(*inputs)] return False - @local_optimizer([add]) + @node_rewriter([add]) def insert_bad_dtype(fgraph, node): if node.op == add: inputs = list(node.inputs) @@ -326,7 +326,7 @@ def test_stochasticoptimization(): last_time_replaced = [False] - @local_optimizer([add]) + @node_rewriter([add]) def insert_broken_add_sometimes(fgraph, node): if node.op == add: last_time_replaced[0] = not last_time_replaced[0] diff --git a/tests/compile/test_mode.py b/tests/compile/test_mode.py index 0147232a8e..0c19dc3edc 100644 --- a/tests/compile/test_mode.py +++ b/tests/compile/test_mode.py @@ -1,18 +1,18 @@ from aesara.compile.function import function from aesara.compile.mode import AddFeatureOptimizer, Mode from aesara.graph.features import NoOutputFromInplace -from aesara.graph.optdb import OptimizationQuery, SequenceDB +from aesara.graph.rewriting.db import RewriteDatabaseQuery, SequenceDB from aesara.tensor.math import dot, tanh from aesara.tensor.type import matrix def test_Mode_basic(): db = SequenceDB() - mode = Mode(linker="py", optimizer=OptimizationQuery(include=None), db=db) + mode = Mode(linker="py", optimizer=RewriteDatabaseQuery(include=None), db=db) assert mode.optdb is db - assert str(mode).startswith("Mode(linker=py, optimizer=OptimizationQuery") + assert str(mode).startswith("Mode(linker=py, optimizer=RewriteDatabaseQuery") def test_NoOutputFromInplace(): diff --git a/tests/compile/test_shared.py b/tests/compile/test_shared.py index 23ab32bfdd..49058a7fee 100644 --- a/tests/compile/test_shared.py +++ b/tests/compile/test_shared.py @@ -166,7 +166,7 @@ def f(var, val): with pytest.raises(TypeError): f(b, 8) - b = shared(np.float(7.234), strict=True) + b = shared(float(7.234), strict=True) assert b.type == dscalar with pytest.raises(TypeError): f(b, 8) @@ -214,8 +214,8 @@ def f(var, val): with pytest.raises(TypeError): f(b, 8) - # np.float([7.234]) don't work - # b = shared(np.float([7.234]), strict=True) + # float([7.234]) don't work + # b = shared(float([7.234]), strict=True) # assert b.type == dvector # with pytest.raises(TypeError): # f(b, 8) @@ -273,7 +273,7 @@ def f(var, val): f(b, 8) assert b.get_value() == 8 - b = shared(np.float(7.234), allow_downcast=True) + b = shared(float(7.234), allow_downcast=True) assert b.type == dscalar f(b, 8) assert b.get_value() == 8 @@ -321,8 +321,8 @@ def f(var, val): f(b, [8]) assert b.get_value() == 8 - # np.float([7.234]) don't work - # b = shared(np.float([7.234])) + # float([7.234]) don't work + # b = shared(float([7.234])) # assert b.type == dvector # f(b,[8]) diff --git a/tests/graph/rewriting/__init__.py b/tests/graph/rewriting/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/graph/test_opt.py b/tests/graph/rewriting/test_basic.py similarity index 73% rename from tests/graph/test_opt.py rename to tests/graph/rewriting/test_basic.py index 84e07afc2a..e68d42b9cb 100644 --- a/tests/graph/test_opt.py +++ b/tests/graph/rewriting/test_basic.py @@ -1,3 +1,5 @@ +import sys + import pytest from aesara.configdefaults import config @@ -5,24 +7,24 @@ from aesara.graph.features import Feature from aesara.graph.fg import FunctionGraph from aesara.graph.op import Op -from aesara.graph.opt import ( - EquilibriumOptimizer, - LocalOptGroup, - LocalOptTracker, +from aesara.graph.rewriting.basic import ( + EquilibriumGraphRewriter, MergeOptimizer, - OpKeyOptimizer, - OpSub, - PatternSub, - TopoOptimizer, + OpKeyGraphRewriter, + OpToRewriterTracker, + PatternNodeRewriter, + SequentialNodeRewriter, + SubstitutionNodeRewriter, + WalkingGraphRewriter, in2out, - local_optimizer, logging, + node_rewriter, pre_constant_merge, - pre_greedy_local_optimizer, + pre_greedy_node_rewriter, ) from aesara.raise_op import assert_op -from aesara.tensor.basic_opt import constant_folding from aesara.tensor.math import Dot, add, dot +from aesara.tensor.rewriting.basic import constant_folding from aesara.tensor.subtensor import AdvancedSubtensor from aesara.tensor.type import matrix, values_eq_approx_always_true from aesara.tensor.type_other import MakeSlice, SliceConstant, slicetype @@ -50,40 +52,42 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None): raise AssertionError() -def PatternOptimizer(p1, p2, ign=False): - return OpKeyOptimizer(PatternSub(p1, p2), ignore_newtrees=ign) +def OpKeyPatternNodeRewriter(p1, p2, ign=False): + return OpKeyGraphRewriter(PatternNodeRewriter(p1, p2), ignore_newtrees=ign) -def TopoPatternOptimizer(p1, p2, ign=True): - return TopoOptimizer(PatternSub(p1, p2), ignore_newtrees=ign) +def WalkingPatternNodeRewriter(p1, p2, ign=True): + return WalkingGraphRewriter(PatternNodeRewriter(p1, p2), ignore_newtrees=ign) -class TestPatternOptimizer: +class TestPatternNodeRewriter: def test_replace_output(self): # replacing the whole graph x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") e = op1(op2(x, y), z) g = FunctionGraph([x, y, z], [e]) - PatternOptimizer((op1, (op2, "1", "2"), "3"), (op4, "3", "2")).optimize(g) + OpKeyPatternNodeRewriter((op1, (op2, "1", "2"), "3"), (op4, "3", "2")).rewrite( + g + ) assert str(g) == "FunctionGraph(Op4(z, y))" def test_nested_out_pattern(self): x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") e = op1(x, y) g = FunctionGraph([x, y, z], [e]) - PatternOptimizer( + OpKeyPatternNodeRewriter( (op1, "1", "2"), (op4, (op1, "1"), (op2, "2"), (op3, "1", "2")) - ).optimize(g) + ).rewrite(g) assert str(g) == "FunctionGraph(Op4(Op1(x), Op2(y), Op3(x, y)))" def test_unification_1(self): x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") e = op1(op2(x, x), z) # the arguments to op2 are the same g = FunctionGraph([x, y, z], [e]) - PatternOptimizer( + OpKeyPatternNodeRewriter( (op1, (op2, "1", "1"), "2"), # they are the same in the pattern (op4, "2", "1"), - ).optimize(g) + ).rewrite(g) # So the replacement should occur assert str(g) == "FunctionGraph(Op4(z, x))" @@ -91,10 +95,10 @@ def test_unification_2(self): x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") e = op1(op2(x, y), z) # the arguments to op2 are different g = FunctionGraph([x, y, z], [e]) - PatternOptimizer( + OpKeyPatternNodeRewriter( (op1, (op2, "1", "1"), "2"), # they are the same in the pattern (op4, "2", "1"), - ).optimize(g) + ).rewrite(g) # The replacement should NOT occur assert str(g) == "FunctionGraph(Op1(Op2(x, y), z))" @@ -103,7 +107,7 @@ def test_replace_subgraph(self): x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") e = op1(op2(x, y), z) g = FunctionGraph([x, y, z], [e]) - PatternOptimizer((op2, "1", "2"), (op1, "2", "1")).optimize(g) + OpKeyPatternNodeRewriter((op2, "1", "2"), (op1, "2", "1")).rewrite(g) assert str(g) == "FunctionGraph(Op1(Op1(y, x), z))" def test_no_recurse(self): @@ -113,7 +117,7 @@ def test_no_recurse(self): x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") e = op1(op2(x, y), z) g = FunctionGraph([x, y, z], [e]) - PatternOptimizer((op2, "1", "2"), (op2, "2", "1"), ign=True).optimize(g) + OpKeyPatternNodeRewriter((op2, "1", "2"), (op2, "2", "1"), ign=True).rewrite(g) assert str(g) == "FunctionGraph(Op1(Op2(y, x), z))" def test_multiple(self): @@ -121,40 +125,40 @@ def test_multiple(self): x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") e = op1(op2(x, y), op2(x, y), op2(y, z)) g = FunctionGraph([x, y, z], [e]) - PatternOptimizer((op2, "1", "2"), (op4, "1")).optimize(g) + OpKeyPatternNodeRewriter((op2, "1", "2"), (op4, "1")).rewrite(g) assert str(g) == "FunctionGraph(Op1(Op4(x), Op4(x), Op4(y)))" def test_nested_even(self): - # regardless of the order in which we optimize, this + # regardless of the order in which we rewrite, this # should work x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") e = op1(op1(op1(op1(x)))) g = FunctionGraph([x, y, z], [e]) - PatternOptimizer((op1, (op1, "1")), "1").optimize(g) + OpKeyPatternNodeRewriter((op1, (op1, "1")), "1").rewrite(g) assert str(g) == "FunctionGraph(x)" def test_nested_odd(self): x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") e = op1(op1(op1(op1(op1(x))))) g = FunctionGraph([x, y, z], [e]) - PatternOptimizer((op1, (op1, "1")), "1").optimize(g) + OpKeyPatternNodeRewriter((op1, (op1, "1")), "1").rewrite(g) assert str(g) == "FunctionGraph(Op1(x))" def test_expand(self): x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") e = op1(op1(op1(x))) g = FunctionGraph([x, y, z], [e]) - PatternOptimizer((op1, "1"), (op2, (op1, "1")), ign=True).optimize(g) + OpKeyPatternNodeRewriter((op1, "1"), (op2, (op1, "1")), ign=True).rewrite(g) assert str(g) == "FunctionGraph(Op2(Op1(Op2(Op1(Op2(Op1(x)))))))" def test_ambiguous(self): - # this test should always work with TopoOptimizer and the + # this test should always work with WalkingGraphRewriter and the # ignore_newtrees flag set to False. Behavior with ignore_newtrees - # = True or with other NavigatorOptimizers may differ. + # = True or with other NodeProcessingGraphRewriters may differ. x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") e = op1(op1(op1(op1(op1(x))))) g = FunctionGraph([x, y, z], [e]) - TopoPatternOptimizer((op1, (op1, "1")), (op1, "1"), ign=False).optimize(g) + WalkingPatternNodeRewriter((op1, (op1, "1")), (op1, "1"), ign=False).rewrite(g) assert str(g) == "FunctionGraph(Op1(x))" def test_constant(self): @@ -163,7 +167,7 @@ def test_constant(self): z = Constant(MyType(), 2, name="z") e = op1(op1(x, y), y) g = FunctionGraph([y], [e]) - PatternOptimizer((op1, z, "1"), (op2, "1", z)).optimize(g) + OpKeyPatternNodeRewriter((op1, z, "1"), (op2, "1", z)).rewrite(g) assert str(g) == "FunctionGraph(Op1(Op2(y, z), y))" def test_constraints(self): @@ -175,16 +179,16 @@ def constraint(r): # Only replacing if the input is an instance of Op2 return r.owner.op == op2 - PatternOptimizer( + OpKeyPatternNodeRewriter( (op1, {"pattern": "1", "constraint": constraint}), (op3, "1") - ).optimize(g) + ).rewrite(g) assert str(g) == "FunctionGraph(Op4(Op3(Op2(x, y)), Op1(Op1(x, y))))" def test_match_same(self): x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") e = op1(x, x) g = FunctionGraph([x, y, z], [e]) - PatternOptimizer((op1, "x", "y"), (op3, "x", "y")).optimize(g) + OpKeyPatternNodeRewriter((op1, "x", "y"), (op3, "x", "y")).rewrite(g) assert str(g) == "FunctionGraph(Op3(x, x))" @pytest.mark.xfail( @@ -199,9 +203,9 @@ def constraint(r): # Only replacing if the input is an instance of Op2 return r.owner.inputs[0] is not r.owner.inputs[1] - PatternOptimizer( + OpKeyPatternNodeRewriter( {"pattern": (op1, "x", "y"), "constraint": constraint}, (op3, "x", "y") - ).optimize(g) + ).rewrite(g) assert str(g) == "FunctionGraph(Op2(Op1(x, x), Op3(x, y)))" def test_allow_multiple_clients(self): @@ -210,7 +214,7 @@ def test_allow_multiple_clients(self): # `e0` has multiple clients (i.e. the `op4` and `op3` nodes) e = op3(op4(e0), e0) g = FunctionGraph([x, y, z], [e]) - PatternOptimizer((op4, (op1, "x", "y")), (op3, "x", "y")).optimize(g) + OpKeyPatternNodeRewriter((op4, (op1, "x", "y")), (op3, "x", "y")).rewrite(g) assert str(g) == "FunctionGraph(Op3(Op4(*1 -> Op1(x, y)), *1))" def test_eq(self): @@ -218,28 +222,30 @@ def test_eq(self): x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") e = op1(op_y(x, y), z) g = FunctionGraph([x, y, z], [e]) - PatternOptimizer((op1, (op_z, "1", "2"), "3"), (op4, "3", "2")).optimize(g) + OpKeyPatternNodeRewriter((op1, (op_z, "1", "2"), "3"), (op4, "3", "2")).rewrite( + g + ) str_g = str(g) assert str_g == "FunctionGraph(Op4(z, y))" -def OpSubOptimizer(op1, op2): - return OpKeyOptimizer(OpSub(op1, op2)) +def KeyedSubstitutionNodeRewriter(op1, op2): + return OpKeyGraphRewriter(SubstitutionNodeRewriter(op1, op2)) -class TestOpSubOptimizer: +class TestSubstitutionNodeRewriter: def test_straightforward(self): x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") e = op1(op1(op1(op1(op1(x))))) g = FunctionGraph([x, y, z], [e]) - OpSubOptimizer(op1, op2).optimize(g) + KeyedSubstitutionNodeRewriter(op1, op2).rewrite(g) assert str(g) == "FunctionGraph(Op2(Op2(Op2(Op2(Op2(x))))))" def test_straightforward_2(self): x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") e = op1(op2(x), op3(y), op4(z)) g = FunctionGraph([x, y, z], [e]) - OpSubOptimizer(op3, op4).optimize(g) + KeyedSubstitutionNodeRewriter(op3, op4).rewrite(g) assert str(g) == "FunctionGraph(Op1(Op2(x), Op4(y), Op4(z)))" @@ -261,7 +267,7 @@ def test_straightforward(self): x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") e = op1(op2(x, y), op2(x, y), op2(x, z)) g = FunctionGraph([x, y, z], [e], clone=False) - MergeOptimizer().optimize(g) + MergeOptimizer().rewrite(g) out_var = g.outputs[0] var_1, var_2, var_3 = out_var.owner.inputs assert var_1 is var_2 @@ -273,7 +279,7 @@ def test_constant_merging(self): z = Constant(MyType(), 2, name="z") e = op1(op2(x, y), op2(x, y), op2(x, z)) g = FunctionGraph([x, y, z], [e], clone=False) - MergeOptimizer().optimize(g) + MergeOptimizer().rewrite(g) out_var = g.outputs[0] var_1, var_2, var_3 = out_var.owner.inputs assert var_1 is var_2 @@ -283,7 +289,7 @@ def test_deep_merge(self): x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") e = op1(op3(op2(x, y), z), op4(op3(op2(x, y), z))) g = FunctionGraph([x, y, z], [e], clone=False) - MergeOptimizer().optimize(g) + MergeOptimizer().rewrite(g) out_var = g.outputs[0] var_1, var_2 = out_var.owner.inputs assert var_2.owner.inputs[0] is var_1 @@ -293,14 +299,14 @@ def test_no_merge(self): e = op1(op3(op2(x, y)), op3(op2(y, x))) g = FunctionGraph([x, y, z], [e]) g.attach_feature(AssertNoChanges()) - MergeOptimizer().optimize(g) + MergeOptimizer().rewrite(g) def test_merge_outputs(self): x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") e1 = op3(op2(x, y)) e2 = op3(op2(x, y)) g = FunctionGraph([x, y, z], [e1, e2], clone=False) - MergeOptimizer().optimize(g) + MergeOptimizer().rewrite(g) assert g.outputs[0] is g.outputs[1] def test_identical_constant_args(self): @@ -309,7 +315,7 @@ def test_identical_constant_args(self): z = Constant(MyType(), 2, name="z") e1 = op1(y, z) g = FunctionGraph([x, y, z], [e1], clone=False) - MergeOptimizer().optimize(g) + MergeOptimizer().rewrite(g) assert g.outputs[0].owner.op == op1 input_1 = g.outputs[0].owner.inputs[0] @@ -322,7 +328,7 @@ def test_one_assert_merge(self): x2 = matrix("x2") e = dot(x1, x2) + dot(assert_op(x1, (x1 > x2).all()), x2) g = FunctionGraph([x1, x2], [e], clone=False) - MergeOptimizer().optimize(g) + MergeOptimizer().rewrite(g) assert g.outputs[0].owner.op == add add_inputs = g.outputs[0].owner.inputs @@ -342,7 +348,7 @@ def test_both_assert_merge_identical(self): assert_op(x1, (x1 > x2).all()), x2 ) g = FunctionGraph([x1, x2], [e], clone=False) - MergeOptimizer().optimize(g) + MergeOptimizer().rewrite(g) assert g.outputs[0].owner.op == add add_inputs = g.outputs[0].owner.inputs @@ -365,7 +371,7 @@ def test_both_assert_merge_1(self): assert_op(x1, (x1 > x2).all()), x2 ) g = FunctionGraph([x1, x2, x3], [e], clone=False) - MergeOptimizer().optimize(g) + MergeOptimizer().rewrite(g) assert g.outputs[0].owner.op == add add_inputs = g.outputs[0].owner.inputs @@ -387,7 +393,7 @@ def test_both_assert_merge_2(self): x1, assert_op(x2, (x2 > x3).all()) ) g = FunctionGraph([x1, x2, x3], [e], clone=False) - MergeOptimizer().optimize(g) + MergeOptimizer().rewrite(g) assert g.outputs[0].owner.op == add add_inputs = g.outputs[0].owner.inputs @@ -411,7 +417,7 @@ def test_both_assert_merge_2_reverse(self): assert_op(x1, (x1 > x3).all()), x2 ) g = FunctionGraph([x1, x2, x3], [e], clone=False) - MergeOptimizer().optimize(g) + MergeOptimizer().rewrite(g) assert g.outputs[0].owner.op == add add_inputs = g.outputs[0].owner.inputs @@ -432,7 +438,7 @@ def test_merge_noinput(self): z = NoInputOp(param=1)() fg = FunctionGraph([], [x, y, z], clone=False) - MergeOptimizer().optimize(fg) + MergeOptimizer().rewrite(fg) assert fg.outputs[0] is fg.outputs[1] assert fg.outputs[0] is not fg.outputs[2] @@ -446,15 +452,15 @@ def test_1(self): e = op3(op4(x, y)) g = FunctionGraph([x, y, z], [e]) # print g - opt = EquilibriumOptimizer( + rewriter = EquilibriumGraphRewriter( [ - PatternSub((op1, "x", "y"), (op2, "x", "y")), - PatternSub((op4, "x", "y"), (op1, "x", "y")), - PatternSub((op3, (op2, "x", "y")), (op4, "x", "y")), + PatternNodeRewriter((op1, "x", "y"), (op2, "x", "y")), + PatternNodeRewriter((op4, "x", "y"), (op1, "x", "y")), + PatternNodeRewriter((op3, (op2, "x", "y")), (op4, "x", "y")), ], max_use_ratio=10, ) - opt.optimize(g) + rewriter.rewrite(g) # print g assert str(g) == "FunctionGraph(Op2(x, y))" @@ -463,17 +469,17 @@ def test_2(self): e = op1(op1(op3(x, y))) g = FunctionGraph([x, y, z], [e]) # print g - opt = EquilibriumOptimizer( + rewriter = EquilibriumGraphRewriter( [ - PatternSub((op1, (op2, "x", "y")), (op4, "x", "y")), - PatternSub((op3, "x", "y"), (op4, "x", "y")), - PatternSub((op4, "x", "y"), (op5, "x", "y")), - PatternSub((op5, "x", "y"), (op6, "x", "y")), - PatternSub((op6, "x", "y"), (op2, "x", "y")), + PatternNodeRewriter((op1, (op2, "x", "y")), (op4, "x", "y")), + PatternNodeRewriter((op3, "x", "y"), (op4, "x", "y")), + PatternNodeRewriter((op4, "x", "y"), (op5, "x", "y")), + PatternNodeRewriter((op5, "x", "y"), (op6, "x", "y")), + PatternNodeRewriter((op6, "x", "y"), (op2, "x", "y")), ], max_use_ratio=10, ) - opt.optimize(g) + rewriter.rewrite(g) assert str(g) == "FunctionGraph(Op2(x, y))" @config.change_flags(on_opt_error="ignore") @@ -483,20 +489,20 @@ def test_low_use_ratio(self): g = FunctionGraph([x, y, z], [e]) # print 'before', g # display pesky warnings along with stdout - # also silence logger for 'aesara.graph.opt' - _logger = logging.getLogger("aesara.graph.opt") + # also silence logger for 'aesara.graph.rewriting.basic' + _logger = logging.getLogger("aesara.graph.rewriting.basic") oldlevel = _logger.level _logger.setLevel(logging.CRITICAL) try: - opt = EquilibriumOptimizer( + rewriter = EquilibriumGraphRewriter( [ - PatternSub((op1, "x", "y"), (op2, "x", "y")), - PatternSub((op4, "x", "y"), (op1, "x", "y")), - PatternSub((op3, (op2, "x", "y")), (op4, "x", "y")), + PatternNodeRewriter((op1, "x", "y"), (op2, "x", "y")), + PatternNodeRewriter((op4, "x", "y"), (op1, "x", "y")), + PatternNodeRewriter((op3, (op2, "x", "y")), (op4, "x", "y")), ], max_use_ratio=1.0 / len(g.apply_nodes), - ) # each opt can only be applied once - opt.optimize(g) + ) + rewriter.rewrite(g) finally: _logger.setLevel(oldlevel) # print 'after', g @@ -547,7 +553,7 @@ def test_pre_constant_merge(): assert res == [adv] -def test_pre_greedy_local_optimizer(): +def test_pre_greedy_node_rewriter(): empty_fgraph = FunctionGraph([], []) @@ -564,7 +570,7 @@ def test_pre_greedy_local_optimizer(): # This should fold `o1`, because it has only `Constant` arguments, and # replace it with the `Constant` result - cst = pre_greedy_local_optimizer(empty_fgraph, [constant_folding], o2) + cst = pre_greedy_node_rewriter(empty_fgraph, [constant_folding], o2) assert cst.owner.inputs[0].owner is None assert cst.owner.inputs[1] is c2 @@ -577,14 +583,14 @@ def test_pre_greedy_local_optimizer(): fg = FunctionGraph([], [o1], clone=False) o2 = op1(o1, c2, x, o3, o1) - cst = pre_greedy_local_optimizer(fg, [constant_folding], o2) + cst = pre_greedy_node_rewriter(fg, [constant_folding], o2) assert cst.owner.inputs[0] is o1 assert cst.owner.inputs[4] is cst.owner.inputs[0] # What exactly is this supposed to test? ms = MakeSlice()(1) - cst = pre_greedy_local_optimizer(empty_fgraph, [constant_folding], ms) + cst = pre_greedy_node_rewriter(empty_fgraph, [constant_folding], ms) assert isinstance(cst, SliceConstant) @@ -595,14 +601,14 @@ def test_pre_greedy_local_optimizer(): @pytest.mark.parametrize("tracks", [True, False]) @pytest.mark.parametrize("out_pattern", [(op2, "x"), "x", 1.0]) def test_patternsub_values_eq_approx(out_pattern, tracks): - # PatternSub would fail when `values_eq_approx` and `get_nodes` were specified + # PatternNodeRewriter would fail when `values_eq_approx` and `get_nodes` were specified x = MyVariable("x") e = op1(x) fg = FunctionGraph([x], [e], clone=False) - opt = EquilibriumOptimizer( + rewriter = EquilibriumGraphRewriter( [ - PatternSub( + PatternNodeRewriter( (op1, "x"), out_pattern, tracks=[op1] if tracks else (), @@ -612,7 +618,7 @@ def test_patternsub_values_eq_approx(out_pattern, tracks): ], max_use_ratio=1, ) - opt.optimize(fg) + rewriter.rewrite(fg) output = fg.outputs[0] if isinstance(out_pattern, tuple): assert output.owner.op == op2 @@ -628,43 +634,43 @@ def test_patternsub_values_eq_approx(out_pattern, tracks): @pytest.mark.parametrize("out_pattern", [(op1, "x"), "x"]) def test_patternsub_invalid_dtype(out_pattern): - # PatternSub would wrongly return output of different dtype as the original node + # PatternNodeRewriter would wrongly return output of different dtype as the original node x = MyVariable("x") e = op_cast_type2(x) fg = FunctionGraph([x], [e]) - opt = EquilibriumOptimizer( + rewriter = EquilibriumGraphRewriter( [ - PatternSub( + PatternNodeRewriter( (op_cast_type2, "x"), out_pattern, ) ], max_use_ratio=1, ) - opt.optimize(fg) + rewriter.rewrite(fg) assert e.type.is_super(fg.outputs[0].type) def test_patternsub_different_output_lengths(): - # Test that PatternSub won't replace nodes with different numbers of outputs - ps = PatternSub( + # Test that PatternNodeRewriter won't replace nodes with different numbers of outputs + ps = PatternNodeRewriter( (op1, "x"), ("x"), name="ps", ) - opt = in2out(ps) + rewriter = in2out(ps) x = MyVariable("x") e1, e2 = op_multiple_outputs(x) o = op1(e1) fgraph = FunctionGraph(inputs=[x], outputs=[o]) - opt.optimize(fgraph) + rewriter.rewrite(fgraph) assert fgraph.outputs[0].owner.op == op1 -class TestLocalOptGroup: +class TestSequentialNodeRewriter: def test_optimizer_verbose(self, capsys): x = MyVariable("x") @@ -673,59 +679,59 @@ def test_optimizer_verbose(self, capsys): fgraph = FunctionGraph([x, y], [o1], clone=False) - @local_optimizer(None) - def local_opt_1(fgraph, node): + @node_rewriter(None) + def local_rewrite_1(fgraph, node): if node.inputs[0] == x: res = op2(y, *node.inputs[1:]) return [res] - @local_optimizer(None) - def local_opt_2(fgraph, node): + @node_rewriter(None) + def local_rewrite_2(fgraph, node): if node.inputs[0] == y: res = op2(x, *node.inputs[1:]) return [res] - opt_group = LocalOptGroup(local_opt_1, local_opt_2) + seq_rewriter = SequentialNodeRewriter(local_rewrite_1, local_rewrite_2) with config.change_flags(optimizer_verbose=True): - (new_res,) = opt_group.transform(fgraph, o1.owner) - _ = opt_group.transform(fgraph, new_res.owner) + (new_res,) = seq_rewriter.transform(fgraph, o1.owner) + _ = seq_rewriter.transform(fgraph, new_res.owner) capres = capsys.readouterr() assert capres.err == "" assert ( - "optimizer: rewrite local_opt_1 replaces node Op1(x, y) with [Op2.0]" + "rewriting: rewrite local_rewrite_1 replaces node Op1(x, y) with [Op2.0]" in capres.out ) assert ( - "optimizer: rewrite local_opt_2 replaces node Op2(y, y) with [Op2.0]" + "rewriting: rewrite local_rewrite_2 replaces node Op2(y, y) with [Op2.0]" in capres.out ) -def test_local_optimizer_str(): - @local_optimizer([op1, MyOp]) - def local_opt_1(fgraph, node): +def test_node_rewriter_str(): + @node_rewriter([op1, MyOp]) + def local_rewriter_1(fgraph, node): pass - assert str(local_opt_1) == "local_opt_1" - res = repr(local_opt_1) - assert res.startswith("FromFunctionLocalOptimizer(") + assert str(local_rewriter_1) == "local_rewriter_1" + res = repr(local_rewriter_1) + assert res.startswith("FromFunctionNodeRewriter(") assert "Op1" in res - assert "local_opt_1" in res + assert "local_rewriter_1" in res -def test_local_optimizer(): +def test_node_rewriter(): with pytest.raises(ValueError): - @local_optimizer([]) + @node_rewriter([]) def local_bad_1(fgraph, node): return node.outputs with pytest.raises(TypeError): - @local_optimizer([None]) + @node_rewriter([None]) def local_bad_2(fgraph, node): return node.outputs @@ -748,61 +754,67 @@ class MyNewOp2(MyOp): hits = [0] - @local_optimizer([op1, MyNewOp]) - def local_opt_1(fgraph, node, hits=hits): + @node_rewriter([op1, MyNewOp]) + def local_rewriter_1(fgraph, node, hits=hits): hits[0] += 1 return node.outputs # This is allowed by the `op1` in `tracks` - local_opt_1.transform(fgraph, fgraph.outputs[0].owner) + local_rewriter_1.transform(fgraph, fgraph.outputs[0].owner) assert hits[0] == 1 # This is allowed by the `MyOp` in `tracks` - local_opt_1.transform(fgraph, fgraph.outputs[1].owner) + local_rewriter_1.transform(fgraph, fgraph.outputs[1].owner) assert hits[0] == 2 # This is not allowed by `tracks` - local_opt_1.transform(fgraph, fgraph.outputs[2].owner) + local_rewriter_1.transform(fgraph, fgraph.outputs[2].owner) assert hits[0] == 2 -def test_TrackingLocalOptimizer(): - @local_optimizer(None) - def local_opt_1(fgraph, node): +def test_OpToRewriterTracker(): + @node_rewriter(None) + def local_rewriter_1(fgraph, node): pass - @local_optimizer([op1]) - def local_opt_2(fgraph, node): + @node_rewriter([op1]) + def local_rewriter_2(fgraph, node): pass - @local_optimizer([Op]) - def local_opt_3(fgraph, node): + @node_rewriter([Op]) + def local_rewriter_3(fgraph, node): pass - @local_optimizer([MyOp]) - def local_opt_4(fgraph, node): + @node_rewriter([MyOp]) + def local_rewriter_4(fgraph, node): pass - @local_optimizer([MyOp]) - def local_opt_5(fgraph, node): + @node_rewriter([MyOp]) + def local_rewriter_5(fgraph, node): pass - tracker = LocalOptTracker() - tracker.add_tracker(local_opt_1) - tracker.add_tracker(local_opt_2) - tracker.add_tracker(local_opt_3) - tracker.add_tracker(local_opt_4) - tracker.add_tracker(local_opt_5) + tracker = OpToRewriterTracker() + tracker.add_tracker(local_rewriter_1) + tracker.add_tracker(local_rewriter_2) + tracker.add_tracker(local_rewriter_3) + tracker.add_tracker(local_rewriter_4) + tracker.add_tracker(local_rewriter_5) - assert tracker.tracked_instances == {op1: [local_opt_2]} + assert tracker.tracked_instances == {op1: [local_rewriter_2]} assert tracker.tracked_types == { - Op: [local_opt_3], - MyOp: [local_opt_4, local_opt_5], + Op: [local_rewriter_3], + MyOp: [local_rewriter_4, local_rewriter_5], } - assert tracker.untracked_opts == [local_opt_1] + assert tracker.untracked_rewrites == [local_rewriter_1] res = tracker.get_trackers(op1) - assert res == [local_opt_4, local_opt_5, local_opt_3, local_opt_2, local_opt_1] + assert res == [ + local_rewriter_4, + local_rewriter_5, + local_rewriter_3, + local_rewriter_2, + local_rewriter_1, + ] class MyNewOp(Op): def perform(self, *args): @@ -811,12 +823,26 @@ def perform(self, *args): new_op = MyNewOp() res = tracker.get_trackers(new_op) - assert res == [local_opt_3, local_opt_1] + assert res == [local_rewriter_3, local_rewriter_1] assert list(tracker.get_rewriters()) == [ - local_opt_3, - local_opt_4, - local_opt_5, - local_opt_2, - local_opt_1, + local_rewriter_3, + local_rewriter_4, + local_rewriter_5, + local_rewriter_2, + local_rewriter_1, ] + + +def test_deprecations(): + """Make sure we can import deprecated classes from current and deprecated modules.""" + with pytest.deprecated_call(): + from aesara.graph.rewriting.basic import GlobalOptimizer + + with pytest.deprecated_call(): + from aesara.graph.opt import GlobalOptimizer, LocalOptimizer # noqa: F401 F811 + + del sys.modules["aesara.graph.opt"] + + with pytest.deprecated_call(): + from aesara.graph.opt import GraphRewriter # noqa: F401 diff --git a/tests/graph/rewriting/test_db.py b/tests/graph/rewriting/test_db.py new file mode 100644 index 0000000000..51239b99d6 --- /dev/null +++ b/tests/graph/rewriting/test_db.py @@ -0,0 +1,102 @@ +import sys + +import pytest + +from aesara.graph.rewriting.basic import GraphRewriter, SequentialGraphRewriter +from aesara.graph.rewriting.db import ( + EquilibriumDB, + LocalGroupDB, + ProxyDB, + RewriteDatabase, + SequenceDB, +) + + +class TestRewriter(GraphRewriter): + name = "blah" + + def apply(self, fgraph): + pass + + +class TestDB: + def test_register(self): + db = RewriteDatabase() + db.register("a", TestRewriter()) + + db.register("b", TestRewriter()) + + db.register("c", TestRewriter(), "z", "asdf") + + assert "a" in db + assert "b" in db + assert "c" in db + + with pytest.raises(ValueError, match=r"The tag.*"): + db.register("c", TestRewriter()) # name taken + + with pytest.raises(ValueError, match=r"The tag.*"): + db.register("z", TestRewriter()) # name collides with tag + + with pytest.raises(ValueError, match=r"The tag.*"): + db.register("u", TestRewriter(), "b") # name new but tag collides with name + + with pytest.raises(TypeError, match=r".* is not a valid.*"): + db.register("d", 1) + + def test_EquilibriumDB(self): + eq_db = EquilibriumDB() + + with pytest.raises(ValueError, match=r"`final_rewriter` and.*"): + eq_db.register("d", TestRewriter(), final_rewriter=True, cleanup=True) + + def test_SequenceDB(self): + seq_db = SequenceDB(failure_callback=None) + + res = seq_db.query("+a") + + assert isinstance(res, SequentialGraphRewriter) + assert res.data == [] + + seq_db.register("b", TestRewriter(), position=1) + + from io import StringIO + + out_file = StringIO() + seq_db.print_summary(stream=out_file) + + res = out_file.getvalue() + + assert str(id(seq_db)) in res + assert "names {'b'}" in res + + with pytest.raises(TypeError, match=r"`position` must be.*"): + seq_db.register("c", TestRewriter(), position=object()) + + def test_LocalGroupDB(self): + lg_db = LocalGroupDB() + + lg_db.register("a", TestRewriter(), 1) + + assert "a" in lg_db.__position__ + + with pytest.raises(TypeError, match=r"`position` must be.*"): + lg_db.register("b", TestRewriter(), position=object()) + + def test_ProxyDB(self): + with pytest.raises(TypeError, match=r"`db` must be.*"): + ProxyDB(object()) + + +def test_deprecations(): + """Make sure we can import deprecated classes from current and deprecated modules.""" + with pytest.deprecated_call(): + from aesara.graph.rewriting.db import OptimizationDatabase # noqa: F401 F811 + + with pytest.deprecated_call(): + from aesara.graph.optdb import OptimizationDatabase # noqa: F401 F811 + + del sys.modules["aesara.graph.optdb"] + + with pytest.deprecated_call(): + from aesara.graph.optdb import RewriteDatabase # noqa: F401 diff --git a/tests/graph/test_kanren.py b/tests/graph/rewriting/test_kanren.py similarity index 89% rename from tests/graph/test_kanren.py rename to tests/graph/rewriting/test_kanren.py index e911b9e936..75d8ec037f 100644 --- a/tests/graph/test_kanren.py +++ b/tests/graph/rewriting/test_kanren.py @@ -11,11 +11,11 @@ import aesara.tensor as at from aesara.graph.basic import Apply from aesara.graph.fg import FunctionGraph -from aesara.graph.kanren import KanrenRelationSub from aesara.graph.op import Op -from aesara.graph.opt import EquilibriumOptimizer -from aesara.graph.opt_utils import optimize_graph -from aesara.graph.unify import eval_if_etuple +from aesara.graph.rewriting.basic import EquilibriumGraphRewriter +from aesara.graph.rewriting.kanren import KanrenRelationSub +from aesara.graph.rewriting.unify import eval_if_etuple +from aesara.graph.rewriting.utils import rewrite_graph from aesara.tensor.math import Dot, _dot from tests.graph.utils import MyType, MyVariable @@ -151,11 +151,11 @@ def distributes(in_lv, out_lv): ), ) - distribute_opt = EquilibriumOptimizer( + distribute_opt = EquilibriumGraphRewriter( [KanrenRelationSub(distributes)], max_use_ratio=10 ) - fgraph_opt = optimize_graph(fgraph, custom_opt=distribute_opt) + fgraph_opt = rewrite_graph(fgraph, custom_rewrite=distribute_opt) (expr_opt,) = fgraph_opt.outputs assert expr_opt.owner.op == at.add @@ -165,3 +165,9 @@ def distributes(in_lv, out_lv): assert expr_opt.owner.inputs[1].owner.op == at.add assert isinstance(expr_opt.owner.inputs[1].owner.inputs[0].owner.op, Dot) assert isinstance(expr_opt.owner.inputs[1].owner.inputs[1].owner.op, Dot) + + +def test_deprecations(): + """Make sure we can import deprecated classes from current and deprecated modules.""" + with pytest.deprecated_call(): + from aesara.graph.kanren import KanrenRelationSub # noqa: F401 F811 diff --git a/tests/graph/test_unify.py b/tests/graph/rewriting/test_unify.py similarity index 96% rename from tests/graph/test_unify.py rename to tests/graph/rewriting/test_unify.py index 5ca712e937..6ce1284794 100644 --- a/tests/graph/test_unify.py +++ b/tests/graph/rewriting/test_unify.py @@ -11,7 +11,7 @@ import aesara.tensor as at from aesara.graph.basic import Apply, Constant, equal_computations from aesara.graph.op import Op -from aesara.graph.unify import ConstrainedVar, convert_strs_to_vars +from aesara.graph.rewriting.unify import ConstrainedVar, convert_strs_to_vars from aesara.tensor.type import TensorType from tests.graph.utils import MyType @@ -350,3 +350,9 @@ def constraint(x): res = convert_strs_to_vars((val,)) assert isinstance(res[0], Constant) assert np.array_equal(res[0].data, val) + + +def test_deprecations(): + """Make sure we can import deprecated classes from current and deprecated modules.""" + with pytest.deprecated_call(): + from aesara.graph.unify import eval_if_etuple # noqa: F401 F811 diff --git a/tests/graph/test_opt_utils.py b/tests/graph/rewriting/test_utils.py similarity index 80% rename from tests/graph/test_opt_utils.py rename to tests/graph/rewriting/test_utils.py index 09ad983006..08aaea250e 100644 --- a/tests/graph/test_opt_utils.py +++ b/tests/graph/rewriting/test_utils.py @@ -1,6 +1,10 @@ +import sys + +import pytest + from aesara.graph.fg import FunctionGraph -from aesara.graph.opt import optimizer -from aesara.graph.opt_utils import is_same_graph, optimize_graph +from aesara.graph.rewriting.basic import graph_rewriter +from aesara.graph.rewriting.utils import is_same_graph, rewrite_graph from aesara.tensor.math import neg from aesara.tensor.type import vectors @@ -139,20 +143,34 @@ def test_merge_only(self): ) -def test_optimize_graph(): +def test_rewrite_graph(): x, y = vectors("xy") - @optimizer - def custom_opt(fgraph): + @graph_rewriter + def custom_rewrite(fgraph): fgraph.replace(x, y, import_missing=True) - x_opt = optimize_graph(x, custom_opt=custom_opt) + x_rewritten = rewrite_graph(x, custom_rewrite=custom_rewrite) - assert x_opt is y + assert x_rewritten is y - x_opt = optimize_graph( - FunctionGraph(outputs=[x], clone=False), custom_opt=custom_opt + x_rewritten = rewrite_graph( + FunctionGraph(outputs=[x], clone=False), custom_rewrite=custom_rewrite ) - assert x_opt.outputs[0] is y + assert x_rewritten.outputs[0] is y + + +def test_deprecations(): + """Make sure we can import deprecated classes from current and deprecated modules.""" + with pytest.deprecated_call(): + from aesara.graph.rewriting.utils import optimize_graph # noqa: F401 F811 + + with pytest.deprecated_call(): + from aesara.graph.opt_utils import optimize_graph # noqa: F401 F811 + + del sys.modules["aesara.graph.opt_utils"] + + with pytest.deprecated_call(): + from aesara.graph.opt_utils import rewrite_graph # noqa: F401 diff --git a/tests/graph/test_destroyhandler.py b/tests/graph/test_destroyhandler.py index c2f42f14c5..3470284e66 100644 --- a/tests/graph/test_destroyhandler.py +++ b/tests/graph/test_destroyhandler.py @@ -8,24 +8,28 @@ from aesara.graph.features import ReplaceValidate from aesara.graph.fg import FunctionGraph from aesara.graph.op import Op -from aesara.graph.opt import ( - NavigatorOptimizer, - OpKeyOptimizer, - OpSub, - PatternSub, - TopoOptimizer, +from aesara.graph.rewriting.basic import ( + NodeProcessingGraphRewriter, + OpKeyGraphRewriter, + PatternNodeRewriter, + SubstitutionNodeRewriter, + WalkingGraphRewriter, ) from aesara.graph.type import Type from aesara.graph.utils import InconsistencyError from tests.unittest_tools import assertFailure_fast -def PatternOptimizer(p1, p2, ign=True): - return OpKeyOptimizer(PatternSub(p1, p2), ignore_newtrees=ign) +def OpKeyPatternNodeRewriter(p1, p2, ign=True): + return OpKeyGraphRewriter(PatternNodeRewriter(p1, p2), ignore_newtrees=ign) -def OpSubOptimizer(op1, op2, fail=NavigatorOptimizer.warn_ignore, ign=True): - return TopoOptimizer(OpSub(op1, op2), ignore_newtrees=ign, failure_callback=fail) +def TopoSubstitutionNodeRewriter( + op1, op2, fail=NodeProcessingGraphRewriter.warn_ignore, ign=True +): + return WalkingGraphRewriter( + SubstitutionNodeRewriter(op1, op2), ignore_newtrees=ign, failure_callback=fail + ) def as_variable(x): @@ -127,27 +131,20 @@ def create_fgraph(inputs, outputs, validate=True): class FailureWatch: - # when passed to OpSubOptimizer or PatternOptimizer, counts the - # number of failures def __init__(self): self.failures = 0 - def __call__(self, exc, nav, pairs, lopt, node): + def __call__(self, exc, nav, pairs, lrewrite, node): assert isinstance(exc, InconsistencyError) self.failures += 1 -################# -# Test protocol # -################# - - def test_misc(): x, y, z = inputs() e = transpose_view(transpose_view(transpose_view(transpose_view(x)))) g = create_fgraph([x, y, z], [e]) assert g.consistent() - PatternOptimizer((transpose_view, (transpose_view, "x")), "x").optimize(g) + OpKeyPatternNodeRewriter((transpose_view, (transpose_view, "x")), "x").rewrite(g) assert str(g) == "FunctionGraph(x)" new_e = add(x, y) g.replace_validate(x, new_e) @@ -157,11 +154,6 @@ def test_misc(): assert not g.consistent() -###################### -# Test protocol skip # -###################### - - @assertFailure_fast def test_aliased_inputs_replacement(): x, y, z = inputs() @@ -231,11 +223,6 @@ def test_destroyers_loop(): assert g.consistent() -######## -# Misc # -######## - - def test_aliased_inputs(): x, y, z = inputs() e = add_in_place(x, x) @@ -326,7 +313,7 @@ def test_long_destroyers_loop(): e = dot(dot(add_in_place(x, y), add_in_place(y, z)), add(z, x)) g = create_fgraph([x, y, z], [e]) assert g.consistent() - OpSubOptimizer(add, add_in_place).optimize(g) + TopoSubstitutionNodeRewriter(add, add_in_place).rewrite(g) assert g.consistent() # we don't want to see that! assert ( @@ -362,7 +349,7 @@ def test_multi_destroyers_through_views(): g = create_fgraph([x, y, z], [e]) assert g.consistent() fail = FailureWatch() - OpSubOptimizer(add, add_in_place, fail).optimize(g) + TopoSubstitutionNodeRewriter(add, add_in_place, fail).rewrite(g) assert g.consistent() assert fail.failures == 1 # should have succeeded once and failed once @@ -384,7 +371,7 @@ def test_usage_loop(): g = create_fgraph([x, y, z], [dot(add_in_place(x, z), x)], False) assert not g.consistent() # replace add_in_place with add - OpSubOptimizer(add_in_place, add).optimize(g) + TopoSubstitutionNodeRewriter(add_in_place, add).rewrite(g) assert g.consistent() @@ -405,7 +392,7 @@ def test_usage_loop_insert_views(): g = create_fgraph([x, y, z], [e]) assert g.consistent() fail = FailureWatch() - OpSubOptimizer(sigmoid, transpose_view, fail).optimize(g) + TopoSubstitutionNodeRewriter(sigmoid, transpose_view, fail).rewrite(g) assert g.consistent() # it must keep one sigmoid in the long sigmoid chain assert fail.failures == 1 @@ -450,24 +437,26 @@ def test_multiple_inplace(): # try to work in-place on x/0 and y/1 (this should fail) fail = FailureWatch() - OpSubOptimizer(multiple, multiple_in_place_0_1, fail).optimize(g) + TopoSubstitutionNodeRewriter(multiple, multiple_in_place_0_1, fail).rewrite(g) assert g.consistent() assert fail.failures == 1 # try to work in-place on x/0 (this should fail) fail = FailureWatch() - OpSubOptimizer(multiple, multiple_in_place_0, fail).optimize(g) + TopoSubstitutionNodeRewriter(multiple, multiple_in_place_0, fail).rewrite(g) assert g.consistent() assert fail.failures == 1 # try to work in-place on y/1 (this should succeed) fail = FailureWatch() - OpSubOptimizer(multiple, multiple_in_place_1, fail).optimize(g) + TopoSubstitutionNodeRewriter(multiple, multiple_in_place_1, fail).rewrite(g) assert g.consistent() assert fail.failures == 0 # try to work in-place on x/0 and y/1 (this should still fail) fail = FailureWatch() - OpSubOptimizer(multiple_in_place_1, multiple_in_place_0_1, fail).optimize(g) + TopoSubstitutionNodeRewriter( + multiple_in_place_1, multiple_in_place_0_1, fail + ).rewrite(g) assert g.consistent() assert fail.failures == 1 diff --git a/tests/graph/test_features.py b/tests/graph/test_features.py index a3aeb43779..bd5044fb29 100644 --- a/tests/graph/test_features.py +++ b/tests/graph/test_features.py @@ -103,7 +103,7 @@ def test_verbose(self, capsys): capres = capsys.readouterr() assert capres.err == "" assert ( - "optimizer: rewrite test-reason replaces Op1.0 of Op1(var2, var1) with var1 of None" + "rewriting: rewrite test-reason replaces Op1.0 of Op1(var2, var1) with var1 of None" in capres.out ) @@ -119,4 +119,4 @@ def validate(self, *args): ) capres = capsys.readouterr() - assert "optimizer: validate failed on node Op1.0" in capres.out + assert "rewriting: validate failed on node Op1.0" in capres.out diff --git a/tests/graph/test_fg.py b/tests/graph/test_fg.py index 7d668021ae..8e495ff44c 100644 --- a/tests/graph/test_fg.py +++ b/tests/graph/test_fg.py @@ -279,7 +279,7 @@ def test_replace_verbose(self, capsys): capres = capsys.readouterr() assert capres.err == "" assert ( - "optimizer: rewrite test-reason replaces Op1.0 of Op1(var2, var1) with var1 of None" + "rewriting: rewrite test-reason replaces Op1.0 of Op1(var2, var1) with var1 of None" in capres.out ) diff --git a/tests/graph/test_optdb.py b/tests/graph/test_optdb.py deleted file mode 100644 index b3ded02c16..0000000000 --- a/tests/graph/test_optdb.py +++ /dev/null @@ -1,86 +0,0 @@ -import pytest - -from aesara.graph import opt -from aesara.graph.optdb import ( - EquilibriumDB, - LocalGroupDB, - OptimizationDatabase, - ProxyDB, - SequenceDB, -) - - -class TestOpt(opt.GlobalOptimizer): - name = "blah" - - def apply(self, fgraph): - pass - - -class TestDB: - def test_register(self): - db = OptimizationDatabase() - db.register("a", TestOpt()) - - db.register("b", TestOpt()) - - db.register("c", TestOpt(), "z", "asdf") - - assert "a" in db - assert "b" in db - assert "c" in db - - with pytest.raises(ValueError, match=r"The tag.*"): - db.register("c", TestOpt()) # name taken - - with pytest.raises(ValueError, match=r"The tag.*"): - db.register("z", TestOpt()) # name collides with tag - - with pytest.raises(ValueError, match=r"The tag.*"): - db.register("u", TestOpt(), "b") # name new but tag collides with name - - with pytest.raises(TypeError, match=r".* is not a valid.*"): - db.register("d", 1) - - def test_EquilibriumDB(self): - eq_db = EquilibriumDB() - - with pytest.raises(ValueError, match=r"`final_opt` and.*"): - eq_db.register("d", TestOpt(), final_opt=True, cleanup=True) - - def test_SequenceDB(self): - seq_db = SequenceDB(failure_callback=None) - - res = seq_db.query("+a") - - assert isinstance(res, opt.SeqOptimizer) - assert res.data == [] - - seq_db.register("b", TestOpt(), position=1) - - from io import StringIO - - out_file = StringIO() - seq_db.print_summary(stream=out_file) - - res = out_file.getvalue() - - assert str(id(seq_db)) in res - assert "names {'b'}" in res - - with pytest.raises(TypeError, match=r"`position` must be.*"): - seq_db.register("c", TestOpt(), position=object()) - - def test_LocalGroupDB(self): - lg_db = LocalGroupDB() - - lg_db.register("a", TestOpt(), 1) - - assert "a" in lg_db.__position__ - - with pytest.raises(TypeError, match=r"`position` must be.*"): - lg_db.register("b", TestOpt(), position=object()) - - def test_ProxyDB(self): - with pytest.raises(TypeError, match=r"`db` must be.*"): - ProxyDB(object()) diff --git a/tests/link/c/test_basic.py b/tests/link/c/test_basic.py index 4f2e28af60..645972fc19 100644 --- a/tests/link/c/test_basic.py +++ b/tests/link/c/test_basic.py @@ -455,3 +455,43 @@ def test_shared_input_output(): gv0 = gv(0) assert np.all(fv0 == 5), fv0 assert np.all(gv0 == 5), gv0 + + +def test_cmodule_key_empty_props(): + """Make sure `CLinker.cmodule_key_` is correct when `COp.__props__` is empty.""" + + class MyAdd(COp): + __props__ = () + + def make_node(self, *inputs): + inputs = list(map(as_variable, inputs)) + outputs = [tdouble()] + return Apply(self, inputs, outputs) + + def __str__(self): + return self.name + + def perform(self, node, inputs, out_): + (out,) = out_ + out[0] = sum(*inputs) + + def c_code_cache_version(self): + return (1,) + + def c_code(self, node, name, inp, out, sub): + x, y = inp + (z,) = out + return f"{z} = {x} + {y};" + + x = tdouble("x") + y = tdouble("y") + + z = MyAdd()(x, y) + + fg = FunctionGraph(outputs=[z]) + + linker = CLinker() + linker.accept(fg) + key = linker.cmodule_key() + # None of the C version values should be empty + assert all(kv for kv in key[0]) diff --git a/tests/link/c/test_cmodule.py b/tests/link/c/test_cmodule.py index b1cf944c7a..6fef007537 100644 --- a/tests/link/c/test_cmodule.py +++ b/tests/link/c/test_cmodule.py @@ -18,9 +18,13 @@ from aesara.compile.function import function from aesara.compile.ops import DeepCopyOp from aesara.configdefaults import config -from aesara.link.c.cmodule import GCC_compiler, default_blas_ldflags +from aesara.graph.basic import Apply +from aesara.graph.fg import FunctionGraph +from aesara.link.c.basic import CLinker +from aesara.link.c.cmodule import GCC_compiler, ModuleCache, default_blas_ldflags from aesara.link.c.exceptions import CompileError -from aesara.tensor.type import dvectors +from aesara.link.c.op import COp +from aesara.tensor.type import dvectors, vector class MyOp(DeepCopyOp): @@ -43,20 +47,47 @@ def c_code(self, node, name, inames, onames, sub): return super(DeepCopyOp, self).c_code(node, name, inames, onames, sub) +class MyAdd(COp): + __props__ = () + + def make_node(self, *inputs): + outputs = [vector()] + return Apply(self, inputs, outputs) + + def perform(self, node, inputs, out_): + (out,) = out_ + out[0] = inputs[0][0] + 1 + + def c_code(self, node, name, inp, out, sub): + (x,) = inp + (z,) = out + return f"{z} = {x} + 1;" + + +class MyAddVersioned(MyAdd): + def c_code_cache_version(self): + return (1,) + + def test_compiler_error(): with pytest.raises(CompileError), tempfile.TemporaryDirectory() as dir_name: GCC_compiler.compile_str("module_name", "blah", location=dir_name) def test_inter_process_cache(): - # When an op with c_code, but no version. If we have 2 apply node - # in the graph with different inputs variable(so they don't get - # merged) but the inputs variable have the same type, do we reuse - # the same module? Even if they would generate different c_code? - # Currently this test show that we generate the c_code only once. - # - # This is to know if the c_code can add information specific to the - # node.inputs[*].owner like the name of the variable. + """ + TODO FIXME: This explanation is very poorly written. + + When a `COp` with `COp.c_code`, but no version. If we have two `Apply` + nodes in a graph with distinct inputs variable, but the input variables + have the same `Type`, do we reuse the same module? Even if they would + generate different `COp.c_code`? Currently this test show that we generate + the `COp.c_code` only once. + + This is to know if the `COp.c_code` can add information specific to the + ``node.inputs[*].owner`` like the name of the variable. + + """ x, y = dvectors("xy") f = function([x, y], [MyOp()(x), MyOp()(y)]) @@ -76,12 +107,58 @@ def test_inter_process_cache(): assert MyOp.nb_called == 1 +@pytest.mark.filterwarnings("error") +def test_cache_versioning(): + """Make sure `ModuleCache._add_to_cache` is working.""" + + my_add = MyAdd() + with pytest.warns(match=".*specifies no C code cache version.*"): + assert my_add.c_code_cache_version() == () + + my_add_ver = MyAddVersioned() + assert my_add_ver.c_code_cache_version() == (1,) + + assert len(MyOp.__props__) == 0 + assert len(MyAddVersioned.__props__) == 0 + + x = vector("x") + + z = my_add(x) + z_v = my_add_ver(x) + + with tempfile.TemporaryDirectory() as dir_name: + cache = ModuleCache(dir_name) + + lnk = CLinker().accept(FunctionGraph(outputs=[z])) + with pytest.warns(match=".*specifies no C code cache version.*"): + key = lnk.cmodule_key() + assert key[0] == () + + with pytest.warns(match=".*c_code_cache_version.*"): + cache.module_from_key(key, lnk) + + lnk_v = CLinker().accept(FunctionGraph(outputs=[z_v])) + key_v = lnk_v.cmodule_key() + assert len(key_v[0]) > 0 + + assert key_v not in cache.entry_from_key + + stats_before = cache.stats[2] + cache.module_from_key(key_v, lnk_v) + assert stats_before < cache.stats[2] + + def test_flag_detection(): - # Check that the code detecting blas flags does not raise any exception. - # It used to happen on python 3 because of improper string handling, - # but was not detected because that path is not usually taken, - # so we test it here directly. - GCC_compiler.try_flags(["-lblas"]) + """ + TODO FIXME: This is a very poor test. + + Check that the code detecting blas flags does not raise any exception. + It used to happen on Python 3 because of improper string handling, + but was not detected because that path is not usually taken, + so we test it here directly. + """ + res = GCC_compiler.try_flags(["-lblas"]) + assert isinstance(res, bool) @patch("aesara.link.c.cmodule.try_blas_flag", return_value=None) diff --git a/tests/link/c/test_op.py b/tests/link/c/test_op.py index 198c610f5b..ee448a8314 100644 --- a/tests/link/c/test_op.py +++ b/tests/link/c/test_op.py @@ -1,3 +1,9 @@ +import os +import subprocess +import sys +import tempfile +from pathlib import Path + import numpy as np import pytest @@ -9,6 +15,52 @@ from aesara.link.c.op import COp +test_dir = Path(__file__).parent.absolute() + +externalcop_test_code = f""" +from aesara import tensor as at +from aesara.graph.basic import Apply +from aesara.link.c.params_type import ParamsType +from aesara.link.c.op import ExternalCOp +from aesara.scalar import ScalarType +from aesara.link.c.type import Generic +from aesara.tensor.type import TensorType + +tensor_type_0d = TensorType("float64", tuple()) +scalar_type = ScalarType("float64") +generic_type = Generic() + + +class QuadraticCOpFunc(ExternalCOp): + __props__ = ("a", "b", "c") + params_type = ParamsType(a=tensor_type_0d, b=scalar_type, c=generic_type) + + def __init__(self, a, b, c): + super().__init__( + "{test_dir}/c_code/test_quadratic_function.c", "APPLY_SPECIFIC(compute_quadratic)" + ) + self.a = a + self.b = b + self.c = c + + def make_node(self, x): + x = at.as_tensor_variable(x) + return Apply(self, [x], [x.type()]) + + def perform(self, node, inputs, output_storage, coefficients): + x = inputs[0] + y = output_storage[0] + y[0] = coefficients.a * (x**2) + coefficients.b * x + coefficients.c + + +if __name__ == "__main__": + qcop = QuadraticCOpFunc(1, 2, 3) + + print(qcop.c_code_cache_version()) + print("__success__") +""" + + class StructOp(COp): __props__ = () @@ -141,3 +193,43 @@ def perform(self, *args, **kwargs): else: with pytest.raises((NotImplementedError, MethodNotDefined)): thunk() + + +def get_hash(modname, seed=None): + """From https://hg.python.org/cpython/file/5e8fa1b13516/Lib/test/test_hash.py#l145""" + env = os.environ.copy() + if seed is not None: + env["PYTHONHASHSEED"] = str(seed) + else: + env.pop("PYTHONHASHSEED", None) + cmd_line = [sys.executable, modname] + p = subprocess.Popen( + cmd_line, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + env=env, + ) + out, err = p.communicate() + return out, err + + +def test_ExternalCOp_c_code_cache_version(): + """Make sure the C cache versions produced by `ExternalCOp` don't depend on `hash` seeding.""" + + with tempfile.NamedTemporaryFile(dir=".", suffix=".py") as tmp: + tmp.write(externalcop_test_code.encode()) + tmp.seek(0) + # modname = os.path.splitext(tmp.name)[0] + modname = tmp.name + out_1, err = get_hash(modname, seed=428) + assert err is None + out_2, err = get_hash(modname, seed=3849) + assert err is None + + hash_1, msg, _ = out_1.decode().split("\n") + assert msg == "__success__" + hash_2, msg, _ = out_2.decode().split("\n") + assert msg == "__success__" + + assert hash_1 == hash_2 diff --git a/tests/link/c/test_type.py b/tests/link/c/test_type.py index 081afa86a9..aedf00c9a7 100644 --- a/tests/link/c/test_type.py +++ b/tests/link/c/test_type.py @@ -11,66 +11,68 @@ from aesara.tensor.type import TensorType, continuous_dtypes -@pytest.mark.skipif( - not aesara.config.cxx, reason="G++ not available, so we need to skip this test." -) -def test_cdata(): - class ProdOp(COp): - __props__ = () - - def make_node(self, i): - return Apply(self, [i], [CDataType("void *", "py_decref")()]) - - def c_support_code(self, **kwargs): - return """ - void py_decref(void *p) { - Py_XDECREF((PyObject *)p); - } - """ - - def c_code(self, node, name, inps, outs, sub): - return """ - Py_XDECREF(%(out)s); - %(out)s = (void *)%(inp)s; - Py_INCREF(%(inp)s); - """ % dict( - out=outs[0], inp=inps[0] - ) +class ProdOp(COp): + __props__ = () + + def make_node(self, i): + return Apply(self, [i], [CDataType("void *", "py_decref")()]) - def c_code_cache_version(self): - return (0,) + def c_support_code(self, **kwargs): + return """ +void py_decref(void *p) { +Py_XDECREF((PyObject *)p); +} +""" - def perform(self, *args, **kwargs): - raise NotImplementedError() + def c_code(self, node, name, inps, outs, sub): + return """ +Py_XDECREF(%(out)s); +%(out)s = (void *)%(inp)s; +Py_INCREF(%(inp)s); +""" % dict( + out=outs[0], inp=inps[0] + ) - class GetOp(COp): - __props__ = () + def c_code_cache_version(self): + return (0,) - def make_node(self, c): - return Apply(self, [c], [TensorType("float32", (False,))()]) + def perform(self, *args, **kwargs): + raise NotImplementedError() - def c_support_code(self, **kwargs): - return """ - void py_decref(void *p) { - Py_XDECREF((PyObject *)p); - } - """ - def c_code(self, node, name, inps, outs, sub): - return """ - Py_XDECREF(%(out)s); - %(out)s = (PyArrayObject *)%(inp)s; - Py_INCREF(%(out)s); - """ % dict( - out=outs[0], inp=inps[0] - ) +class GetOp(COp): + __props__ = () - def c_code_cache_version(self): - return (0,) + def make_node(self, c): + return Apply(self, [c], [TensorType("float32", (False,))()]) - def perform(self, *args, **kwargs): - raise NotImplementedError() + def c_support_code(self, **kwargs): + return """ +void py_decref(void *p) { +Py_XDECREF((PyObject *)p); +} +""" + def c_code(self, node, name, inps, outs, sub): + return """ +Py_XDECREF(%(out)s); +%(out)s = (PyArrayObject *)%(inp)s; +Py_INCREF(%(out)s); +""" % dict( + out=outs[0], inp=inps[0] + ) + + def c_code_cache_version(self): + return (0,) + + def perform(self, *args, **kwargs): + raise NotImplementedError() + + +@pytest.mark.skipif( + not aesara.config.cxx, reason="G++ not available, so we need to skip this test." +) +def test_cdata(): i = TensorType("float32", (False,))() c = ProdOp()(i) i2 = GetOp()(c) diff --git a/tests/link/jax/test_basic.py b/tests/link/jax/test_basic.py new file mode 100644 index 0000000000..56dfd18d35 --- /dev/null +++ b/tests/link/jax/test_basic.py @@ -0,0 +1,226 @@ +from functools import partial +from typing import Callable, Iterable, Optional + +import numpy as np +import pytest + +from aesara.compile.function import function +from aesara.compile.mode import Mode +from aesara.compile.sharedvalue import SharedVariable, shared +from aesara.configdefaults import config +from aesara.graph.basic import Apply +from aesara.graph.fg import FunctionGraph +from aesara.graph.op import Op, get_test_value +from aesara.graph.rewriting.db import RewriteDatabaseQuery +from aesara.ifelse import ifelse +from aesara.link.jax import JAXLinker +from aesara.raise_op import assert_op +from aesara.tensor.type import dscalar, scalar, vector + + +@pytest.fixture(scope="module", autouse=True) +def set_aesara_flags(): + with config.change_flags(cxx="", compute_test_value="ignore"): + yield + + +jax = pytest.importorskip("jax") + + +opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"]) +jax_mode = Mode(JAXLinker(), opts) +py_mode = Mode("py", opts) + + +def compare_jax_and_py( + fgraph: FunctionGraph, + test_inputs: Iterable, + assert_fn: Optional[Callable] = None, + must_be_device_array: bool = True, +): + """Function to compare python graph output and jax compiled output for testing equality + + In the tests below computational graphs are defined in Aesara. These graphs are then passed to + this function which then compiles the graphs in both jax and python, runs the calculation + in both and checks if the results are the same + + Parameters + ---------- + fgraph: FunctionGraph + Aesara function Graph object + test_inputs: iter + Numerical inputs for testing the function graph + assert_fn: func, opt + Assert function used to check for equality between python and jax. If not + provided uses np.testing.assert_allclose + must_be_device_array: Bool + Checks for instance of jax.interpreters.xla.DeviceArray. For testing purposes + if this device array is found it indicates if the result was computed by jax + + Returns + ------- + jax_res + + """ + if assert_fn is None: + assert_fn = partial(np.testing.assert_allclose, rtol=1e-4) + + fn_inputs = [i for i in fgraph.inputs if not isinstance(i, SharedVariable)] + aesara_jax_fn = function(fn_inputs, fgraph.outputs, mode=jax_mode) + jax_res = aesara_jax_fn(*test_inputs) + + if must_be_device_array: + if isinstance(jax_res, list): + assert all( + isinstance(res, jax.interpreters.xla.DeviceArray) for res in jax_res + ) + else: + assert isinstance(jax_res, jax.interpreters.xla.DeviceArray) + + aesara_py_fn = function(fn_inputs, fgraph.outputs, mode=py_mode) + py_res = aesara_py_fn(*test_inputs) + + if len(fgraph.outputs) > 1: + for j, p in zip(jax_res, py_res): + assert_fn(j, p) + else: + assert_fn(jax_res, py_res) + + return jax_res + + +def test_jax_FunctionGraph_names(): + import inspect + + from aesara.link.jax.dispatch import jax_funcify + + x = scalar("1x") + y = scalar("_") + z = scalar() + q = scalar("def") + + out_fg = FunctionGraph([x, y, z, q], [x, y, z, q], clone=False) + out_jx = jax_funcify(out_fg) + sig = inspect.signature(out_jx) + assert (x.auto_name, "_", z.auto_name, q.auto_name) == tuple(sig.parameters.keys()) + assert (1, 2, 3, 4) == out_jx(1, 2, 3, 4) + + +def test_jax_FunctionGraph_once(): + """Make sure that an output is only computed once when it's referenced multiple times.""" + from aesara.link.jax.dispatch import jax_funcify + + x = vector("x") + y = vector("y") + + class TestOp(Op): + def __init__(self): + self.called = 0 + + def make_node(self, *args): + return Apply(self, list(args), [x.type() for x in args]) + + def perform(self, inputs, outputs): + for i, inp in enumerate(inputs): + outputs[i][0] = inp[0] + + @jax_funcify.register(TestOp) + def jax_funcify_TestOp(op, **kwargs): + def func(*args, op=op): + op.called += 1 + return list(args) + + return func + + op1 = TestOp() + op2 = TestOp() + + q, r = op1(x, y) + outs = op2(q + r, q + r) + + out_fg = FunctionGraph([x, y], outs, clone=False) + assert len(out_fg.outputs) == 2 + + out_jx = jax_funcify(out_fg) + + x_val = np.r_[1, 2].astype(config.floatX) + y_val = np.r_[2, 3].astype(config.floatX) + + res = out_jx(x_val, y_val) + assert len(res) == 2 + assert op1.called == 1 + assert op2.called == 1 + + res = out_jx(x_val, y_val) + assert len(res) == 2 + assert op1.called == 2 + assert op2.called == 2 + + +def test_shared(): + a = shared(np.array([1, 2, 3], dtype=config.floatX)) + + aesara_jax_fn = function([], a, mode="JAX") + jax_res = aesara_jax_fn() + + assert isinstance(jax_res, jax.interpreters.xla.DeviceArray) + np.testing.assert_allclose(jax_res, a.get_value()) + + aesara_jax_fn = function([], a * 2, mode="JAX") + jax_res = aesara_jax_fn() + + assert isinstance(jax_res, jax.interpreters.xla.DeviceArray) + np.testing.assert_allclose(jax_res, a.get_value() * 2) + + # Changed the shared value and make sure that the JAX-compiled + # function also changes. + new_a_value = np.array([3, 4, 5], dtype=config.floatX) + a.set_value(new_a_value) + + jax_res = aesara_jax_fn() + assert isinstance(jax_res, jax.interpreters.xla.DeviceArray) + np.testing.assert_allclose(jax_res, new_a_value * 2) + + +def test_shared_updates(): + a = shared(0) + + aesara_jax_fn = function([], a, updates={a: a + 1}, mode="JAX") + res1, res2 = aesara_jax_fn(), aesara_jax_fn() + assert res1 == 0 + assert res2 == 1 + assert a.get_value() == 2 + + a.set_value(5) + res1, res2 = aesara_jax_fn(), aesara_jax_fn() + assert res1 == 5 + assert res2 == 6 + assert a.get_value() == 7 + + +def test_jax_ifelse(): + + true_vals = np.r_[1, 2, 3] + false_vals = np.r_[-1, -2, -3] + + x = ifelse(np.array(True), true_vals, false_vals) + x_fg = FunctionGraph([], [x]) + + compare_jax_and_py(x_fg, []) + + a = dscalar("a") + a.tag.test_value = np.array(0.2, dtype=config.floatX) + x = ifelse(a < 0.5, true_vals, false_vals) + x_fg = FunctionGraph([a], [x]) # I.e. False + + compare_jax_and_py(x_fg, [get_test_value(i) for i in x_fg.inputs]) + + +def test_jax_checkandraise(): + p = scalar() + p.tag.test_value = 0 + + res = assert_op(p, p < 1.0) + + with pytest.warns(UserWarning): + function((p,), res, mode=jax_mode) diff --git a/tests/link/jax/test_elemwise.py b/tests/link/jax/test_elemwise.py new file mode 100644 index 0000000000..eb9809ba07 --- /dev/null +++ b/tests/link/jax/test_elemwise.py @@ -0,0 +1,101 @@ +import numpy as np +import pytest + +from aesara.configdefaults import config +from aesara.graph.fg import FunctionGraph +from aesara.graph.op import get_test_value +from aesara.tensor import elemwise as at_elemwise +from aesara.tensor import nnet as at_nnet +from aesara.tensor.math import all as at_all +from aesara.tensor.math import prod +from aesara.tensor.math import sum as at_sum +from aesara.tensor.nnet.basic import SoftmaxGrad +from aesara.tensor.type import matrix, tensor, vector +from tests.link.jax.test_basic import compare_jax_and_py + + +def test_jax_Dimshuffle(): + a_at = matrix("a") + + x = a_at.T + x_fg = FunctionGraph([a_at], [x]) + compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)]) + + x = a_at.dimshuffle([0, 1, "x"]) + x_fg = FunctionGraph([a_at], [x]) + compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)]) + + a_at = tensor(dtype=config.floatX, shape=[False, True]) + x = a_at.dimshuffle((0,)) + x_fg = FunctionGraph([a_at], [x]) + compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)]) + + a_at = tensor(dtype=config.floatX, shape=[False, True]) + x = at_elemwise.DimShuffle([False, True], (0,))(a_at) + x_fg = FunctionGraph([a_at], [x]) + compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)]) + + +def test_jax_CAReduce(): + a_at = vector("a") + a_at.tag.test_value = np.r_[1, 2, 3].astype(config.floatX) + + x = at_sum(a_at, axis=None) + x_fg = FunctionGraph([a_at], [x]) + + compare_jax_and_py(x_fg, [np.r_[1, 2, 3].astype(config.floatX)]) + + a_at = matrix("a") + a_at.tag.test_value = np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX) + + x = at_sum(a_at, axis=0) + x_fg = FunctionGraph([a_at], [x]) + + compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)]) + + x = at_sum(a_at, axis=1) + x_fg = FunctionGraph([a_at], [x]) + + compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)]) + + a_at = matrix("a") + a_at.tag.test_value = np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX) + + x = prod(a_at, axis=0) + x_fg = FunctionGraph([a_at], [x]) + + compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)]) + + x = at_all(a_at) + x_fg = FunctionGraph([a_at], [x]) + + compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)]) + + +@pytest.mark.parametrize("axis", [None, 0, 1]) +def test_softmax(axis): + x = matrix("x") + x.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3) + out = at_nnet.softmax(x, axis=axis) + fgraph = FunctionGraph([x], [out]) + compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + + +@pytest.mark.parametrize("axis", [None, 0, 1]) +def test_logsoftmax(axis): + x = matrix("x") + x.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3) + out = at_nnet.logsoftmax(x, axis=axis) + fgraph = FunctionGraph([x], [out]) + compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + + +@pytest.mark.parametrize("axis", [None, 0, 1]) +def test_softmax_grad(axis): + dy = matrix("dy") + dy.tag.test_value = np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX) + sm = matrix("sm") + sm.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3) + out = SoftmaxGrad(axis=axis)(dy, sm) + fgraph = FunctionGraph([dy, sm], [out]) + compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) diff --git a/tests/link/jax/test_extra_ops.py b/tests/link/jax/test_extra_ops.py new file mode 100644 index 0000000000..8c9b70ef37 --- /dev/null +++ b/tests/link/jax/test_extra_ops.py @@ -0,0 +1,125 @@ +import numpy as np +import pytest +from packaging.version import parse as version_parse + +import aesara.tensor.basic as at +from aesara.configdefaults import config +from aesara.graph.fg import FunctionGraph +from aesara.graph.op import get_test_value +from aesara.tensor import extra_ops as at_extra_ops +from aesara.tensor.type import matrix, vector +from tests.link.jax.test_basic import compare_jax_and_py + + +jax = pytest.importorskip("jax") + + +def set_test_value(x, v): + x.tag.test_value = v + return x + + +def test_extra_ops(): + a = matrix("a") + a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2)) + + out = at_extra_ops.cumsum(a, axis=0) + fgraph = FunctionGraph([a], [out]) + compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + + out = at_extra_ops.cumprod(a, axis=1) + fgraph = FunctionGraph([a], [out]) + compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + + out = at_extra_ops.diff(a, n=2, axis=1) + fgraph = FunctionGraph([a], [out]) + compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + + out = at_extra_ops.repeat(a, (3, 3), axis=1) + fgraph = FunctionGraph([a], [out]) + compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + + c = at.as_tensor(5) + + out = at_extra_ops.fill_diagonal(a, c) + fgraph = FunctionGraph([a], [out]) + compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + + with pytest.raises(NotImplementedError): + out = at_extra_ops.fill_diagonal_offset(a, c, c) + fgraph = FunctionGraph([a], [out]) + compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + + with pytest.raises(NotImplementedError): + out = at_extra_ops.Unique(axis=1)(a) + fgraph = FunctionGraph([a], [out]) + compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + + indices = np.arange(np.product((3, 4))) + out = at_extra_ops.unravel_index(indices, (3, 4), order="C") + fgraph = FunctionGraph([], out) + compare_jax_and_py( + fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False + ) + + +@pytest.mark.parametrize( + "x, shape", + [ + ( + set_test_value( + vector("x"), np.random.random(size=(2,)).astype(config.floatX) + ), + [at.as_tensor(3, dtype=np.int64), at.as_tensor(2, dtype=np.int64)], + ), + ( + set_test_value( + vector("x"), np.random.random(size=(2,)).astype(config.floatX) + ), + [at.as_tensor(3, dtype=np.int8), at.as_tensor(2, dtype=np.int64)], + ), + ], +) +def test_BroadcastTo(x, shape): + out = at_extra_ops.broadcast_to(x, shape) + fgraph = FunctionGraph(outputs=[out]) + compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + + +@pytest.mark.xfail( + version_parse(jax.__version__) >= version_parse("0.2.12"), + reason="Omnistaging cannot be disabled", +) +def test_extra_ops_omni(): + a = matrix("a") + a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2)) + + # This function also cannot take symbolic input. + c = at.as_tensor(5) + out = at_extra_ops.bartlett(c) + fgraph = FunctionGraph([], [out]) + compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + + multi_index = np.unravel_index(np.arange(np.product((3, 4))), (3, 4)) + out = at_extra_ops.ravel_multi_index(multi_index, (3, 4)) + fgraph = FunctionGraph([], [out]) + compare_jax_and_py( + fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False + ) + + # The inputs are "concrete", yet it still has problems? + out = at_extra_ops.Unique()( + at.as_tensor(np.arange(6, dtype=config.floatX).reshape((3, 2))) + ) + fgraph = FunctionGraph([], [out]) + compare_jax_and_py(fgraph, []) + + +@pytest.mark.xfail(reason="jax.numpy.arange requires concrete inputs") +def test_unique_nonconcrete(): + a = matrix("a") + a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2)) + + out = at_extra_ops.Unique()(a) + fgraph = FunctionGraph([a], [out]) + compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) diff --git a/tests/link/jax/test_nlinalg.py b/tests/link/jax/test_nlinalg.py new file mode 100644 index 0000000000..8f29872980 --- /dev/null +++ b/tests/link/jax/test_nlinalg.py @@ -0,0 +1,132 @@ +import numpy as np +import pytest +from packaging.version import parse as version_parse + +from aesara.compile.function import function +from aesara.compile.mode import Mode +from aesara.configdefaults import config +from aesara.graph.fg import FunctionGraph +from aesara.graph.op import get_test_value +from aesara.graph.rewriting.db import RewriteDatabaseQuery +from aesara.link.jax import JAXLinker +from aesara.tensor import blas as at_blas +from aesara.tensor import nlinalg as at_nlinalg +from aesara.tensor.math import MaxAndArgmax +from aesara.tensor.math import max as at_max +from aesara.tensor.math import maximum +from aesara.tensor.type import dvector, matrix, scalar, tensor3, vector +from tests.link.jax.test_basic import compare_jax_and_py + + +jax = pytest.importorskip("jax") + + +def test_jax_BatchedDot(): + # tensor3 . tensor3 + a = tensor3("a") + a.tag.test_value = ( + np.linspace(-1, 1, 10 * 5 * 3).astype(config.floatX).reshape((10, 5, 3)) + ) + b = tensor3("b") + b.tag.test_value = ( + np.linspace(1, -1, 10 * 3 * 2).astype(config.floatX).reshape((10, 3, 2)) + ) + out = at_blas.BatchedDot()(a, b) + fgraph = FunctionGraph([a, b], [out]) + compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + + # A dimension mismatch should raise a TypeError for compatibility + inputs = [get_test_value(a)[:-1], get_test_value(b)] + opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"]) + jax_mode = Mode(JAXLinker(), opts) + aesara_jax_fn = function(fgraph.inputs, fgraph.outputs, mode=jax_mode) + with pytest.raises(TypeError): + aesara_jax_fn(*inputs) + + # matrix . matrix + a = matrix("a") + a.tag.test_value = np.linspace(-1, 1, 5 * 3).astype(config.floatX).reshape((5, 3)) + b = matrix("b") + b.tag.test_value = np.linspace(1, -1, 5 * 3).astype(config.floatX).reshape((5, 3)) + out = at_blas.BatchedDot()(a, b) + fgraph = FunctionGraph([a, b], [out]) + compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + + +def test_jax_basic_multiout(): + rng = np.random.default_rng(213234) + + M = rng.normal(size=(3, 3)) + X = M.dot(M.T) + + x = matrix("x") + + outs = at_nlinalg.eig(x) + out_fg = FunctionGraph([x], outs) + + def assert_fn(x, y): + np.testing.assert_allclose(x.astype(config.floatX), y, rtol=1e-3) + + compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn) + + outs = at_nlinalg.eigh(x) + out_fg = FunctionGraph([x], outs) + compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn) + + outs = at_nlinalg.qr(x, mode="full") + out_fg = FunctionGraph([x], outs) + compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn) + + outs = at_nlinalg.qr(x, mode="reduced") + out_fg = FunctionGraph([x], outs) + compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn) + + outs = at_nlinalg.svd(x) + out_fg = FunctionGraph([x], outs) + compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn) + + +@pytest.mark.xfail( + version_parse(jax.__version__) >= version_parse("0.2.12"), + reason="Omnistaging cannot be disabled", +) +def test_jax_basic_multiout_omni(): + # Test that a single output of a multi-output `Op` can be used as input to + # another `Op` + x = dvector() + mx, amx = MaxAndArgmax([0])(x) + out = mx * amx + out_fg = FunctionGraph([x], [out]) + compare_jax_and_py(out_fg, [np.r_[1, 2]]) + + +@pytest.mark.xfail( + version_parse(jax.__version__) >= version_parse("0.2.12"), + reason="Omnistaging cannot be disabled", +) +def test_tensor_basics(): + y = vector("y") + y.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX) + x = vector("x") + x.tag.test_value = np.r_[3.0, 4.0].astype(config.floatX) + A = matrix("A") + A.tag.test_value = np.empty((2, 2), dtype=config.floatX) + alpha = scalar("alpha") + alpha.tag.test_value = np.array(3.0, dtype=config.floatX) + beta = scalar("beta") + beta.tag.test_value = np.array(5.0, dtype=config.floatX) + + # This should be converted into a `Gemv` `Op` when the non-JAX compatible + # optimizations are turned on; however, when using JAX mode, it should + # leave the expression alone. + out = y.dot(alpha * A).dot(x) + beta * y + fgraph = FunctionGraph([y, x, A, alpha, beta], [out]) + compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + + out = maximum(y, x) + fgraph = FunctionGraph([y, x], [out]) + compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + + out = at_max(y) + fgraph = FunctionGraph([y], [out]) + compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) diff --git a/tests/link/jax/test_random.py b/tests/link/jax/test_random.py new file mode 100644 index 0000000000..d23c6a096a --- /dev/null +++ b/tests/link/jax/test_random.py @@ -0,0 +1,86 @@ +import numpy as np +import pytest +from packaging.version import parse as version_parse + +import aesara.tensor as at +from aesara.compile.function import function +from aesara.compile.sharedvalue import shared +from aesara.configdefaults import config +from aesara.graph.fg import FunctionGraph +from aesara.tensor.random.basic import RandomVariable +from aesara.tensor.random.utils import RandomStream +from tests.link.jax.test_basic import compare_jax_and_py, jax_mode + + +jax = pytest.importorskip("jax") + + +@pytest.mark.xfail( + version_parse(jax.__version__) >= version_parse("0.2.26"), + reason="JAX samplers require concrete/static shape values?", +) +@pytest.mark.parametrize( + "at_dist, dist_params, rng, size", + [ + ( + at.random.normal, + (), + shared(np.random.RandomState(123)), + 10000, + ), + ( + at.random.normal, + (), + shared(np.random.default_rng(123)), + 10000, + ), + ], +) +def test_random_stats(at_dist, dist_params, rng, size): + # The RNG states are not 1:1, so the best we can do is check some summary + # statistics of the samples + out = at.random.normal(*dist_params, rng=rng, size=size) + fgraph = FunctionGraph([out.owner.inputs[0]], [out], clone=False) + + def assert_fn(x, y): + (x,) = x + (y,) = y + assert x.dtype.kind == y.dtype.kind + + d = 2 if config.floatX == "float64" else 1 + np.testing.assert_array_almost_equal(np.abs(x.mean()), np.abs(y.mean()), d) + + compare_jax_and_py(fgraph, [], assert_fn=assert_fn) + + +def test_random_unimplemented(): + class NonExistentRV(RandomVariable): + name = "non-existent" + ndim_supp = 0 + ndims_params = [] + dtype = "floatX" + + def __call__(self, size=None, **kwargs): + return super().__call__(size=size, **kwargs) + + def rng_fn(cls, rng, size): + return 0 + + nonexistentrv = NonExistentRV() + rng = shared(np.random.RandomState(123)) + out = nonexistentrv(rng=rng) + fgraph = FunctionGraph([out.owner.inputs[0]], [out], clone=False) + + with pytest.raises(NotImplementedError): + compare_jax_and_py(fgraph, []) + + +def test_RandomStream(): + srng = RandomStream(seed=123) + out = srng.normal() - srng.normal() + + fn = function([], out, mode=jax_mode) + jax_res_1 = fn() + jax_res_2 = fn() + + assert np.array_equal(jax_res_1, jax_res_2) diff --git a/tests/link/jax/test_scalar.py b/tests/link/jax/test_scalar.py new file mode 100644 index 0000000000..d2b90ef12e --- /dev/null +++ b/tests/link/jax/test_scalar.py @@ -0,0 +1,190 @@ +import numpy as np +import pytest + +import aesara.scalar.basic as aes +import aesara.tensor as at +from aesara.configdefaults import config +from aesara.graph.fg import FunctionGraph +from aesara.graph.op import get_test_value +from aesara.scalar.basic import Composite +from aesara.tensor import nnet as at_nnet +from aesara.tensor.elemwise import Elemwise +from aesara.tensor.math import all as at_all +from aesara.tensor.math import ( + cosh, + erf, + erfc, + erfinv, + log, + log1mexp, + psi, + sigmoid, + softplus, +) +from aesara.tensor.type import matrix, scalar, vector +from tests.link.jax.test_basic import compare_jax_and_py + + +jax = pytest.importorskip("jax") + + +def test_second(): + a0 = scalar("a0") + b = scalar("b") + + out = aes.second(a0, b) + fgraph = FunctionGraph([a0, b], [out]) + compare_jax_and_py(fgraph, [10.0, 5.0]) + + a1 = vector("a1") + out = at.second(a1, b) + fgraph = FunctionGraph([a1, b], [out]) + compare_jax_and_py(fgraph, [np.zeros([5], dtype=config.floatX), 5.0]) + + +def test_identity(): + a = scalar("a") + a.tag.test_value = 10 + + out = aes.identity(a) + fgraph = FunctionGraph([a], [out]) + compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + + +@pytest.mark.parametrize( + "x, y, x_val, y_val", + [ + (scalar("x"), scalar("y"), np.array(10), np.array(20)), + (scalar("x"), vector("y"), np.array(10), np.arange(10, 20)), + ( + matrix("x"), + vector("y"), + np.arange(10 * 20).reshape((20, 10)), + np.arange(10, 20), + ), + ], +) +def test_jax_Composite(x, y, x_val, y_val): + x_s = aes.float64("x") + y_s = aes.float64("y") + + comp_op = Elemwise(Composite([x_s, y_s], [x_s + y_s * 2 + aes.exp(x_s - y_s)])) + + out = comp_op(x, y) + + out_fg = FunctionGraph([x, y], [out]) + + test_input_vals = [ + x_val.astype(config.floatX), + y_val.astype(config.floatX), + ] + _ = compare_jax_and_py(out_fg, test_input_vals) + + +def test_erf(): + x = scalar("x") + out = erf(x) + fg = FunctionGraph([x], [out]) + + compare_jax_and_py(fg, [1.0]) + + +def test_erfc(): + x = scalar("x") + out = erfc(x) + fg = FunctionGraph([x], [out]) + + compare_jax_and_py(fg, [1.0]) + + +def test_erfinv(): + x = scalar("x") + out = erfinv(x) + fg = FunctionGraph([x], [out]) + + compare_jax_and_py(fg, [1.0]) + + +def test_psi(): + x = scalar("x") + out = psi(x) + fg = FunctionGraph([x], [out]) + compare_jax_and_py(fg, [3.0]) + + +def test_log1mexp(): + x = vector("x") + out = log1mexp(x) + fg = FunctionGraph([x], [out]) + + compare_jax_and_py(fg, [[-1.0, -0.75, -0.5, -0.25]]) + + +def test_nnet(): + x = vector("x") + x.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX) + + out = sigmoid(x) + fgraph = FunctionGraph([x], [out]) + compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + + out = at_nnet.ultra_fast_sigmoid(x) + fgraph = FunctionGraph([x], [out]) + compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + + out = softplus(x) + fgraph = FunctionGraph([x], [out]) + compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + + +def test_jax_variadic_Scalar(): + mu = vector("mu", dtype=config.floatX) + mu.tag.test_value = np.r_[0.1, 1.1].astype(config.floatX) + tau = vector("tau", dtype=config.floatX) + tau.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX) + + res = -tau * mu + + fgraph = FunctionGraph([mu, tau], [res]) + + compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + + res = -tau * (tau - mu) ** 2 + + fgraph = FunctionGraph([mu, tau], [res]) + + compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + + +def test_jax_multioutput(): + x = vector("x") + x.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX) + y = vector("y") + y.tag.test_value = np.r_[3.0, 4.0].astype(config.floatX) + + w = cosh(x**2 + y / 3.0) + v = cosh(x / 3.0 + y**2) + + fgraph = FunctionGraph([x, y], [w, v]) + + compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + + +def test_jax_logp(): + mu = vector("mu") + mu.tag.test_value = np.r_[0.0, 0.0].astype(config.floatX) + tau = vector("tau") + tau.tag.test_value = np.r_[1.0, 1.0].astype(config.floatX) + sigma = vector("sigma") + sigma.tag.test_value = (1.0 / get_test_value(tau)).astype(config.floatX) + value = vector("value") + value.tag.test_value = np.r_[0.1, -10].astype(config.floatX) + + logp = (-tau * (value - mu) ** 2 + log(tau / np.pi / 2.0)) / 2.0 + conditions = [sigma > 0] + alltrue = at_all([at_all(1 * val) for val in conditions]) + normal_logp = at.switch(alltrue, logp, -np.inf) + + fgraph = FunctionGraph([mu, tau, sigma, value], [normal_logp]) + + compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) diff --git a/tests/link/jax/test_scan.py b/tests/link/jax/test_scan.py new file mode 100644 index 0000000000..158f8bd14d --- /dev/null +++ b/tests/link/jax/test_scan.py @@ -0,0 +1,146 @@ +import numpy as np +import pytest +from packaging.version import parse as version_parse + +import aesara.tensor as at +from aesara.configdefaults import config +from aesara.graph.fg import FunctionGraph +from aesara.scan.basic import scan +from aesara.tensor.math import gammaln, log +from aesara.tensor.type import ivector, lscalar, scalar +from tests.link.jax.test_basic import compare_jax_and_py + + +jax = pytest.importorskip("jax") + + +@pytest.mark.xfail( + version_parse(jax.__version__) >= version_parse("0.2.12"), + reason="Omnistaging cannot be disabled", +) +def test_jax_scan_multiple_output(): + """Test a scan implementation of a SEIR model. + + SEIR model definition: + S[t+1] = S[t] - B[t] + E[t+1] = E[t] +B[t] - C[t] + I[t+1] = I[t+1] + C[t] - D[t] + + B[t] ~ Binom(S[t], beta) + C[t] ~ Binom(E[t], gamma) + D[t] ~ Binom(I[t], delta) + """ + + def binomln(n, k): + return gammaln(n + 1) - gammaln(k + 1) - gammaln(n - k + 1) + + def binom_log_prob(n, p, value): + return binomln(n, value) + value * log(p) + (n - value) * log(1 - p) + + # sequences + at_C = ivector("C_t") + at_D = ivector("D_t") + # outputs_info (initial conditions) + st0 = lscalar("s_t0") + et0 = lscalar("e_t0") + it0 = lscalar("i_t0") + logp_c = scalar("logp_c") + logp_d = scalar("logp_d") + # non_sequences + beta = scalar("beta") + gamma = scalar("gamma") + delta = scalar("delta") + + # TODO: Use random streams when their JAX conversions are implemented. + # trng = aesara.tensor.random.RandomStream(1234) + + def seir_one_step(ct0, dt0, st0, et0, it0, logp_c, logp_d, beta, gamma, delta): + # bt0 = trng.binomial(n=st0, p=beta) + bt0 = st0 * beta + bt0 = bt0.astype(st0.dtype) + + logp_c1 = binom_log_prob(et0, gamma, ct0).astype(logp_c.dtype) + logp_d1 = binom_log_prob(it0, delta, dt0).astype(logp_d.dtype) + + st1 = st0 - bt0 + et1 = et0 + bt0 - ct0 + it1 = it0 + ct0 - dt0 + return st1, et1, it1, logp_c1, logp_d1 + + (st, et, it, logp_c_all, logp_d_all), _ = scan( + fn=seir_one_step, + sequences=[at_C, at_D], + outputs_info=[st0, et0, it0, logp_c, logp_d], + non_sequences=[beta, gamma, delta], + ) + st.name = "S_t" + et.name = "E_t" + it.name = "I_t" + logp_c_all.name = "C_t_logp" + logp_d_all.name = "D_t_logp" + + out_fg = FunctionGraph( + [at_C, at_D, st0, et0, it0, logp_c, logp_d, beta, gamma, delta], + [st, et, it, logp_c_all, logp_d_all], + ) + + s0, e0, i0 = 100, 50, 25 + logp_c0 = np.array(0.0, dtype=config.floatX) + logp_d0 = np.array(0.0, dtype=config.floatX) + beta_val, gamma_val, delta_val = [ + np.array(val, dtype=config.floatX) for val in [0.277792, 0.135330, 0.108753] + ] + C = np.array([3, 5, 8, 13, 21, 26, 10, 3], dtype=np.int32) + D = np.array([1, 2, 3, 7, 9, 11, 5, 1], dtype=np.int32) + + test_input_vals = [ + C, + D, + s0, + e0, + i0, + logp_c0, + logp_d0, + beta_val, + gamma_val, + delta_val, + ] + compare_jax_and_py(out_fg, test_input_vals) + + +@pytest.mark.xfail( + version_parse(jax.__version__) >= version_parse("0.2.12"), + reason="Omnistaging cannot be disabled", +) +def test_jax_scan_tap_output(): + + a_at = scalar("a") + + def input_step_fn(y_tm1, y_tm3, a): + y_tm1.name = "y_tm1" + y_tm3.name = "y_tm3" + res = (y_tm1 + y_tm3) * a + res.name = "y_t" + return res + + y_scan_at, _ = scan( + fn=input_step_fn, + outputs_info=[ + { + "initial": at.as_tensor_variable( + np.r_[-1.0, 1.3, 0.0].astype(config.floatX) + ), + "taps": [-1, -3], + }, + ], + non_sequences=[a_at], + n_steps=10, + name="y_scan", + ) + y_scan_at.name = "y" + y_scan_at.owner.inputs[0].name = "y_all" + + out_fg = FunctionGraph([a_at], [y_scan_at]) + + test_input_vals = [np.array(10.0).astype(config.floatX)] + compare_jax_and_py(out_fg, test_input_vals) diff --git a/tests/link/jax/test_shape.py b/tests/link/jax/test_shape.py new file mode 100644 index 0000000000..6b1bd442fa --- /dev/null +++ b/tests/link/jax/test_shape.py @@ -0,0 +1,90 @@ +import jax +import numpy as np +import pytest +from packaging.version import parse as version_parse + +import aesara.tensor as at +from aesara.compile.ops import DeepCopyOp, ViewOp +from aesara.configdefaults import config +from aesara.graph.fg import FunctionGraph +from aesara.tensor.shape import Shape, Shape_i, SpecifyShape, Unbroadcast, reshape +from aesara.tensor.type import iscalar, vector +from tests.link.jax.test_basic import compare_jax_and_py + + +def test_jax_shape_ops(): + x_np = np.zeros((20, 3)) + x = Shape()(at.as_tensor_variable(x_np)) + x_fg = FunctionGraph([], [x]) + + compare_jax_and_py(x_fg, [], must_be_device_array=False) + + x = Shape_i(1)(at.as_tensor_variable(x_np)) + x_fg = FunctionGraph([], [x]) + + compare_jax_and_py(x_fg, [], must_be_device_array=False) + + +@pytest.mark.xfail( + version_parse(jax.__version__) >= version_parse("0.2.12"), + reason="Omnistaging cannot be disabled", +) +def test_jax_specify_shape(): + x_np = np.zeros((20, 3)) + x = SpecifyShape()(at.as_tensor_variable(x_np), (20, 3)) + x_fg = FunctionGraph([], [x]) + + compare_jax_and_py(x_fg, []) + + with config.change_flags(compute_test_value="off"): + + x = SpecifyShape()(at.as_tensor_variable(x_np), *(2, 3)) + x_fg = FunctionGraph([], [x]) + + with pytest.raises(AssertionError): + compare_jax_and_py(x_fg, []) + + +def test_jax_Reshape(): + a = vector("a") + x = reshape(a, (2, 2)) + x_fg = FunctionGraph([a], [x]) + compare_jax_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)]) + + # Test breaking "omnistaging" changes in JAX. + # See https://github.com/tensorflow/probability/commit/782d0c64eb774b9aac54a1c8488e4f1f96fbbc68 + x = reshape(a, (a.shape[0] // 2, a.shape[0] // 2)) + x_fg = FunctionGraph([a], [x]) + with pytest.raises( + TypeError, + match="Shapes must be 1D sequences of concrete values of integer type", + ): + compare_jax_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)]) + + b = iscalar("b") + x = reshape(a, (b, b)) + x_fg = FunctionGraph([a, b], [x]) + with pytest.raises( + TypeError, + match="Shapes must be 1D sequences of concrete values of integer type", + ): + compare_jax_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX), 2]) + + +def test_jax_compile_ops(): + + x = DeepCopyOp()(at.as_tensor_variable(1.1)) + x_fg = FunctionGraph([], [x]) + + compare_jax_and_py(x_fg, []) + + x_np = np.zeros((20, 1, 1)) + x = Unbroadcast(0, 2)(at.as_tensor_variable(x_np)) + x_fg = FunctionGraph([], [x]) + + compare_jax_and_py(x_fg, []) + + x = ViewOp()(at.as_tensor_variable(x_np)) + x_fg = FunctionGraph([], [x]) + + compare_jax_and_py(x_fg, []) diff --git a/tests/link/jax/test_slinalg.py b/tests/link/jax/test_slinalg.py new file mode 100644 index 0000000000..ce63129145 --- /dev/null +++ b/tests/link/jax/test_slinalg.py @@ -0,0 +1,131 @@ +import numpy as np +import pytest + +import aesara.tensor as at +from aesara.configdefaults import config +from aesara.graph.fg import FunctionGraph +from aesara.tensor import nlinalg as at_nlinalg +from aesara.tensor import slinalg as at_slinalg +from aesara.tensor import subtensor as at_subtensor +from aesara.tensor.math import clip, cosh +from aesara.tensor.type import matrix, vector +from tests.link.jax.test_basic import compare_jax_and_py + + +def test_jax_basic(): + rng = np.random.default_rng(28494) + + x = matrix("x") + y = matrix("y") + b = vector("b") + + # `ScalarOp` + z = cosh(x**2 + y / 3.0) + + # `[Inc]Subtensor` + out = at_subtensor.set_subtensor(z[0], -10.0) + out = at_subtensor.inc_subtensor(out[0, 1], 2.0) + out = out[:5, :3] + + out_fg = FunctionGraph([x, y], [out]) + + test_input_vals = [ + np.tile(np.arange(10), (10, 1)).astype(config.floatX), + np.tile(np.arange(10, 20), (10, 1)).astype(config.floatX), + ] + (jax_res,) = compare_jax_and_py(out_fg, test_input_vals) + + # Confirm that the `Subtensor` slice operations are correct + assert jax_res.shape == (5, 3) + + # Confirm that the `IncSubtensor` operations are correct + assert jax_res[0, 0] == -10.0 + assert jax_res[0, 1] == -8.0 + + out = clip(x, y, 5) + out_fg = FunctionGraph([x, y], [out]) + compare_jax_and_py(out_fg, test_input_vals) + + out = at.diagonal(x, 0) + out_fg = FunctionGraph([x], [out]) + compare_jax_and_py( + out_fg, [np.arange(10 * 10).reshape((10, 10)).astype(config.floatX)] + ) + + out = at_slinalg.cholesky(x) + out_fg = FunctionGraph([x], [out]) + compare_jax_and_py( + out_fg, + [ + (np.eye(10) + rng.standard_normal(size=(10, 10)) * 0.01).astype( + config.floatX + ) + ], + ) + + # not sure why this isn't working yet with lower=False + out = at_slinalg.Cholesky(lower=False)(x) + out_fg = FunctionGraph([x], [out]) + compare_jax_and_py( + out_fg, + [ + (np.eye(10) + rng.standard_normal(size=(10, 10)) * 0.01).astype( + config.floatX + ) + ], + ) + + out = at_slinalg.solve(x, b) + out_fg = FunctionGraph([x, b], [out]) + compare_jax_and_py( + out_fg, + [ + np.eye(10).astype(config.floatX), + np.arange(10).astype(config.floatX), + ], + ) + + out = at.diag(b) + out_fg = FunctionGraph([b], [out]) + compare_jax_and_py(out_fg, [np.arange(10).astype(config.floatX)]) + + out = at_nlinalg.det(x) + out_fg = FunctionGraph([x], [out]) + compare_jax_and_py( + out_fg, [np.arange(10 * 10).reshape((10, 10)).astype(config.floatX)] + ) + + out = at_nlinalg.matrix_inverse(x) + out_fg = FunctionGraph([x], [out]) + compare_jax_and_py( + out_fg, + [ + (np.eye(10) + rng.standard_normal(size=(10, 10)) * 0.01).astype( + config.floatX + ) + ], + ) + + +@pytest.mark.parametrize("check_finite", [False, True]) +@pytest.mark.parametrize("lower", [False, True]) +@pytest.mark.parametrize("trans", [0, 1, 2]) +def test_jax_SolveTriangular(trans, lower, check_finite): + x = matrix("x") + b = vector("b") + + out = at_slinalg.solve_triangular( + x, + b, + trans=trans, + lower=lower, + check_finite=check_finite, + ) + out_fg = FunctionGraph([x, b], [out]) + compare_jax_and_py( + out_fg, + [ + np.eye(10).astype(config.floatX), + np.arange(10).astype(config.floatX), + ], + ) diff --git a/tests/link/jax/test_tensor_basic.py b/tests/link/jax/test_tensor_basic.py new file mode 100644 index 0000000000..ef9738e0cd --- /dev/null +++ b/tests/link/jax/test_tensor_basic.py @@ -0,0 +1,110 @@ +import numpy as np +import pytest + +import aesara.tensor.basic as at +from aesara.configdefaults import config +from aesara.graph.fg import FunctionGraph +from aesara.graph.op import get_test_value +from aesara.tensor.type import matrix, scalar, vector +from tests.link.jax.test_basic import compare_jax_and_py + + +def test_jax_Alloc(): + x = at.alloc(0.0, 2, 3) + x_fg = FunctionGraph([], [x]) + + (jax_res,) = compare_jax_and_py(x_fg, []) + + assert jax_res.shape == (2, 3) + + x = at.alloc(1.1, 2, 3) + x_fg = FunctionGraph([], [x]) + + compare_jax_and_py(x_fg, []) + + x = at.AllocEmpty("float32")(2, 3) + x_fg = FunctionGraph([], [x]) + + def compare_shape_dtype(x, y): + (x,) = x + (y,) = y + return x.shape == y.shape and x.dtype == y.dtype + + compare_jax_and_py(x_fg, [], assert_fn=compare_shape_dtype) + + a = scalar("a") + x = at.alloc(a, 20) + x_fg = FunctionGraph([a], [x]) + + compare_jax_and_py(x_fg, [10.0]) + + a = vector("a") + x = at.alloc(a, 20, 10) + x_fg = FunctionGraph([a], [x]) + + compare_jax_and_py(x_fg, [np.ones(10, dtype=config.floatX)]) + + +def test_jax_MakeVector(): + x = at.make_vector(1, 2, 3) + x_fg = FunctionGraph([], [x]) + + compare_jax_and_py(x_fg, []) + + +@pytest.mark.xfail(reason="jax.numpy.arange requires concrete inputs") +def test_arange_nonconcrete(): + + a = scalar("a") + a.tag.test_value = 10 + + out = at.arange(a) + fgraph = FunctionGraph([a], [out]) + compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) + + +def test_jax_Join(): + a = matrix("a") + b = matrix("b") + + x = at.join(0, a, b) + x_fg = FunctionGraph([a, b], [x]) + compare_jax_and_py( + x_fg, + [ + np.c_[[1.0, 2.0, 3.0]].astype(config.floatX), + np.c_[[4.0, 5.0, 6.0]].astype(config.floatX), + ], + ) + compare_jax_and_py( + x_fg, + [ + np.c_[[1.0, 2.0, 3.0]].astype(config.floatX), + np.c_[[4.0, 5.0]].astype(config.floatX), + ], + ) + + x = at.join(1, a, b) + x_fg = FunctionGraph([a, b], [x]) + compare_jax_and_py( + x_fg, + [ + np.c_[[1.0, 2.0, 3.0]].astype(config.floatX), + np.c_[[4.0, 5.0, 6.0]].astype(config.floatX), + ], + ) + compare_jax_and_py( + x_fg, + [ + np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX), + np.c_[[5.0, 6.0]].astype(config.floatX), + ], + ) + + +def test_jax_eye(): + """Tests jaxification of the Eye operator""" + out = at.eye(3) + out_fg = FunctionGraph([], [out]) + + compare_jax_and_py(out_fg, []) diff --git a/tests/link/numba/__init__.py b/tests/link/numba/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py new file mode 100644 index 0000000000..d3d6a1d870 --- /dev/null +++ b/tests/link/numba/test_basic.py @@ -0,0 +1,983 @@ +import contextlib +import inspect +from unittest import mock + +import numba +import numpy as np +import pytest + +import aesara.scalar as aes +import aesara.scalar.math as aesm +import aesara.tensor as at +import aesara.tensor.math as aem +from aesara import config, shared +from aesara.compile.function import function +from aesara.compile.mode import Mode +from aesara.compile.ops import ViewOp +from aesara.compile.sharedvalue import SharedVariable +from aesara.graph.basic import Apply, Constant +from aesara.graph.fg import FunctionGraph +from aesara.graph.op import Op, get_test_value +from aesara.graph.rewriting.db import RewriteDatabaseQuery +from aesara.graph.type import Type +from aesara.ifelse import ifelse +from aesara.link.numba.dispatch import basic as numba_basic +from aesara.link.numba.dispatch import numba_typify +from aesara.link.numba.linker import NumbaLinker +from aesara.raise_op import assert_op +from aesara.tensor import blas +from aesara.tensor import subtensor as at_subtensor +from aesara.tensor.elemwise import Elemwise +from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape + + +class MyType(Type): + def filter(self, data): + return data + + def __eq__(self, other): + return isinstance(other, MyType) + + def __hash__(self): + return hash(MyType) + + +class MyOp(Op): + def perform(self, *args): + pass + + +class MySingleOut(Op): + def make_node(self, a, b): + return Apply(self, [a, b], [a.type()]) + + def perform(self, node, inputs, outputs): + res = (inputs[0] + inputs[1]).astype(inputs[0][0].dtype) + outputs[0][0] = res + + +class MyMultiOut(Op): + nin = 2 + nout = 2 + + @staticmethod + def impl(a, b): + res1 = 2 * a + res2 = 2 * b + return [res1, res2] + + def make_node(self, a, b): + return Apply(self, [a, b], [a.type(), b.type()]) + + def perform(self, node, inputs, outputs): + res1, res2 = self.impl(inputs[0], inputs[1]) + outputs[0][0] = res1 + outputs[1][0] = res2 + + +my_multi_out = Elemwise(MyMultiOut()) +my_multi_out.ufunc = MyMultiOut.impl +my_multi_out.ufunc.nin = 2 +my_multi_out.ufunc.nout = 2 + +opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"]) +numba_mode = Mode(NumbaLinker(), opts) +py_mode = Mode("py", opts) + +rng = np.random.default_rng(42849) + + +def set_test_value(x, v): + x.tag.test_value = v + return x + + +def compare_shape_dtype(x, y): + (x,) = x + (y,) = y + return x.shape == y.shape and x.dtype == y.dtype + + +def eval_python_only(fn_inputs, fgraph, inputs): + """Evaluate the Numba implementation in pure Python for coverage purposes.""" + + def py_tuple_setitem(t, i, v): + ll = list(t) + ll[i] = v + return tuple(ll) + + def py_to_scalar(x): + if isinstance(x, np.ndarray): + return x.item() + else: + return x + + def njit_noop(*args, **kwargs): + if len(args) == 1 and callable(args[0]): + return args[0] + else: + return lambda x: x + + def vectorize_noop(*args, **kwargs): + def wrap(fn): + # `numba.vectorize` allows an `out` positional argument. We need + # to account for that + sig = inspect.signature(fn) + nparams = len(sig.parameters) + + def inner_vec(*args): + if len(args) > nparams: + # An `out` argument has been specified for an in-place + # operation + out = args[-1] + out[...] = np.vectorize(fn)(*args[:nparams]) + return out + else: + return np.vectorize(fn)(*args) + + return inner_vec + + if len(args) == 1 and callable(args[0]): + return wrap(args[0], **kwargs) + else: + return wrap + + mocks = [ + mock.patch("numba.njit", njit_noop), + mock.patch("numba.vectorize", vectorize_noop), + mock.patch("aesara.link.numba.dispatch.basic.tuple_setitem", py_tuple_setitem), + mock.patch("aesara.link.numba.dispatch.basic.numba_njit", njit_noop), + mock.patch("aesara.link.numba.dispatch.basic.numba_vectorize", vectorize_noop), + mock.patch("aesara.link.numba.dispatch.basic.direct_cast", lambda x, dtype: x), + mock.patch("aesara.link.numba.dispatch.basic.to_scalar", py_to_scalar), + mock.patch( + "aesara.link.numba.dispatch.basic.numba.np.numpy_support.from_dtype", + lambda dtype: dtype, + ), + mock.patch("numba.np.unsafe.ndarray.to_fixed_tuple", lambda x, n: tuple(x)), + ] + + with contextlib.ExitStack() as stack: + for ctx in mocks: + stack.enter_context(ctx) + + aesara_numba_fn = function( + fn_inputs, + fgraph.outputs, + mode=numba_mode, + accept_inplace=True, + ) + _ = aesara_numba_fn(*inputs) + + +def compare_numba_and_py(fgraph, inputs, assert_fn=None): + """Function to compare python graph output and Numba compiled output for testing equality + + In the tests below computational graphs are defined in Aesara. These graphs are then passed to + this function which then compiles the graphs in both Numba and python, runs the calculation + in both and checks if the results are the same + + Parameters + ---------- + fgraph: FunctionGraph + Aesara function Graph object + inputs: iter + Inputs for function graph + assert_fn: func, opt + Assert function used to check for equality between python and Numba. If not + provided uses np.testing.assert_allclose + + """ + if assert_fn is None: + + def assert_fn(x, y): + return np.testing.assert_allclose(x, y, rtol=1e-4) and compare_shape_dtype( + x, y + ) + + fn_inputs = [i for i in fgraph.inputs if not isinstance(i, SharedVariable)] + + aesara_py_fn = function( + fn_inputs, fgraph.outputs, mode=py_mode, accept_inplace=True + ) + py_res = aesara_py_fn(*inputs) + + aesara_numba_fn = function( + fn_inputs, + fgraph.outputs, + mode=numba_mode, + accept_inplace=True, + ) + numba_res = aesara_numba_fn(*inputs) + + # Get some coverage + eval_python_only(fn_inputs, fgraph, inputs) + + if len(fgraph.outputs) > 1: + for j, p in zip(numba_res, py_res): + assert_fn(j, p) + else: + assert_fn(numba_res, py_res) + + return numba_res + + +@pytest.mark.parametrize( + "v, expected, force_scalar, not_implemented", + [ + (MyType(), None, False, True), + (aes.float32, numba.types.float32, False, False), + (at.fscalar, numba.types.Array(numba.types.float32, 0, "A"), False, False), + (at.fscalar, numba.types.float32, True, False), + (at.lvector, numba.types.int64[:], False, False), + (at.dmatrix, numba.types.float64[:, :], False, False), + (at.dmatrix, numba.types.float64, True, False), + ], +) +def test_get_numba_type(v, expected, force_scalar, not_implemented): + cm = ( + contextlib.suppress() + if not not_implemented + else pytest.raises(NotImplementedError) + ) + with cm: + res = numba_basic.get_numba_type(v, force_scalar=force_scalar) + assert res == expected + + +@pytest.mark.parametrize( + "v, expected, force_scalar", + [ + (Apply(MyOp(), [], []), numba.types.void(), False), + (Apply(MyOp(), [], []), numba.types.void(), True), + ( + Apply(MyOp(), [at.lvector()], []), + numba.types.void(numba.types.int64[:]), + False, + ), + (Apply(MyOp(), [at.lvector()], []), numba.types.void(numba.types.int64), True), + ( + Apply(MyOp(), [at.dmatrix(), aes.float32()], [at.dmatrix()]), + numba.types.float64[:, :](numba.types.float64[:, :], numba.types.float32), + False, + ), + ( + Apply(MyOp(), [at.dmatrix(), aes.float32()], [at.dmatrix()]), + numba.types.float64(numba.types.float64, numba.types.float32), + True, + ), + ( + Apply(MyOp(), [at.dmatrix(), aes.float32()], [at.dmatrix(), aes.int32()]), + numba.types.Tuple([numba.types.float64[:, :], numba.types.int32])( + numba.types.float64[:, :], numba.types.float32 + ), + False, + ), + ( + Apply(MyOp(), [at.dmatrix(), aes.float32()], [at.dmatrix(), aes.int32()]), + numba.types.Tuple([numba.types.float64, numba.types.int32])( + numba.types.float64, numba.types.float32 + ), + True, + ), + ], +) +def test_create_numba_signature(v, expected, force_scalar): + res = numba_basic.create_numba_signature(v, force_scalar=force_scalar) + assert res == expected + + +@pytest.mark.parametrize( + "input, wrapper_fn, check_fn", + [ + ( + np.random.RandomState(1), + numba_typify, + lambda x, y: np.all(x.get_state()[1] == y.get_state()[1]), + ) + ], +) +def test_box_unbox(input, wrapper_fn, check_fn): + input = wrapper_fn(input) + + pass_through = numba.njit(lambda x: x) + res = pass_through(input) + + assert isinstance(res, type(input)) + assert check_fn(res, input) + + +@pytest.mark.parametrize( + "x, indices", + [ + (at.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), (1,)), + ( + at.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), + (slice(None)), + ), + (at.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), (1, 2, 0)), + ( + at.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), + (slice(1, 2), 1, slice(None)), + ), + ], +) +def test_Subtensor(x, indices): + """Test NumPy's basic indexing.""" + out_at = x[indices] + assert isinstance(out_at.owner.op, at_subtensor.Subtensor) + out_fg = FunctionGraph([], [out_at]) + compare_numba_and_py(out_fg, []) + + +@pytest.mark.parametrize( + "x, indices", + [ + (at.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2],)), + ], +) +def test_AdvancedSubtensor1(x, indices): + """Test NumPy's advanced indexing in one dimension.""" + out_at = at_subtensor.advanced_subtensor1(x, *indices) + assert isinstance(out_at.owner.op, at_subtensor.AdvancedSubtensor1) + out_fg = FunctionGraph([], [out_at]) + compare_numba_and_py(out_fg, []) + + +@pytest.mark.parametrize( + "x, indices", + [ + (at.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2], [2, 3])), + ( + at.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), + ([1, 2], slice(None), [3, 4]), + ), + ], +) +def test_AdvancedSubtensor(x, indices): + """Test NumPy's advanced indexing in more than one dimension.""" + out_at = x[indices] + assert isinstance(out_at.owner.op, at_subtensor.AdvancedSubtensor) + out_fg = FunctionGraph([], [out_at]) + compare_numba_and_py(out_fg, []) + + +@pytest.mark.parametrize( + "x, y, indices", + [ + ( + at.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), + at.as_tensor(np.array(10)), + (1,), + ), + ( + at.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), + at.as_tensor(rng.poisson(size=(4, 5))), + (slice(None)), + ), + ( + at.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), + at.as_tensor(np.array(10)), + (1, 2, 0), + ), + ( + at.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), + at.as_tensor(rng.poisson(size=(1, 5))), + (slice(1, 2), 1, slice(None)), + ), + ], +) +def test_IncSubtensor(x, y, indices): + out_at = at.set_subtensor(x[indices], y) + assert isinstance(out_at.owner.op, at_subtensor.IncSubtensor) + out_fg = FunctionGraph([], [out_at]) + compare_numba_and_py(out_fg, []) + + out_at = at.inc_subtensor(x[indices], y) + assert isinstance(out_at.owner.op, at_subtensor.IncSubtensor) + out_fg = FunctionGraph([], [out_at]) + compare_numba_and_py(out_fg, []) + + x_at = x.type() + out_at = at.set_subtensor(x_at[indices], y, inplace=True) + assert isinstance(out_at.owner.op, at_subtensor.IncSubtensor) + out_fg = FunctionGraph([x_at], [out_at]) + compare_numba_and_py(out_fg, [x.data]) + + +@pytest.mark.parametrize( + "x, y, indices", + [ + ( + at.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), + at.as_tensor(rng.poisson(size=(2, 4, 5))), + ([1, 2],), + ), + ( + at.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), + at.as_tensor(rng.poisson(size=(2, 4, 5))), + ([1, 1],), + ), + ], +) +def test_AdvancedIncSubtensor1(x, y, indices): + out_at = at_subtensor.advanced_set_subtensor1(x, y, *indices) + assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor1) + out_fg = FunctionGraph([], [out_at]) + compare_numba_and_py(out_fg, []) + + out_at = at_subtensor.advanced_inc_subtensor1(x, y, *indices) + assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor1) + out_fg = FunctionGraph([], [out_at]) + compare_numba_and_py(out_fg, []) + + x_at = x.type() + out_at = at_subtensor.AdvancedIncSubtensor1(inplace=True)(x_at, y, *indices) + assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor1) + out_fg = FunctionGraph([x_at], [out_at]) + compare_numba_and_py(out_fg, [x.data]) + + +@pytest.mark.parametrize( + "x, y, indices", + [ + ( + at.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), + at.as_tensor(rng.poisson(size=(2, 5))), + ([1, 2], [2, 3]), + ), + ( + at.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), + at.as_tensor(rng.poisson(size=(2, 4))), + ([1, 2], slice(None), [3, 4]), + ), + pytest.param( + at.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), + at.as_tensor(rng.poisson(size=(2, 5))), + ([1, 1], [2, 2]), + marks=pytest.mark.xfail( + reason="Duplicate index handling hasn't been implemented, yet." + ), + ), + ], +) +def test_AdvancedIncSubtensor(x, y, indices): + out_at = at.set_subtensor(x[indices], y) + assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor) + out_fg = FunctionGraph([], [out_at]) + compare_numba_and_py(out_fg, []) + + out_at = at.inc_subtensor(x[indices], y) + assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor) + out_fg = FunctionGraph([], [out_at]) + compare_numba_and_py(out_fg, []) + + x_at = x.type() + out_at = at.set_subtensor(x_at[indices], y) + # Inplace isn't really implemented for `AdvancedIncSubtensor`, so we just + # hack it on here + out_at.owner.op.inplace = True + assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor) + out_fg = FunctionGraph([x_at], [out_at]) + compare_numba_and_py(out_fg, [x.data]) + + +@pytest.mark.parametrize( + "x, i", + [ + (np.zeros((20, 3)), 1), + ], +) +def test_Shape(x, i): + g = Shape()(at.as_tensor_variable(x)) + g_fg = FunctionGraph([], [g]) + + compare_numba_and_py(g_fg, []) + + g = Shape_i(i)(at.as_tensor_variable(x)) + g_fg = FunctionGraph([], [g]) + + compare_numba_and_py(g_fg, []) + + +@pytest.mark.parametrize( + "v, shape, ndim", + [ + (set_test_value(at.vector(), np.array([4], dtype=config.floatX)), (), 0), + (set_test_value(at.vector(), np.arange(4, dtype=config.floatX)), (2, 2), 2), + ( + set_test_value(at.vector(), np.arange(4, dtype=config.floatX)), + set_test_value(at.lvector(), np.array([2, 2], dtype="int64")), + 2, + ), + ], +) +def test_Reshape(v, shape, ndim): + g = Reshape(ndim)(v, shape) + g_fg = FunctionGraph(outputs=[g]) + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, (SharedVariable, Constant)) + ], + ) + + +def test_Reshape_scalar(): + v = at.vector() + v.tag.test_value = np.array([1.0], dtype=config.floatX) + g = Reshape(1)(v[0], (1,)) + g_fg = FunctionGraph(outputs=[g]) + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, (SharedVariable, Constant)) + ], + ) + + +@pytest.mark.parametrize( + "v, shape, fails", + [ + ( + set_test_value(at.matrix(), np.array([[1.0]], dtype=config.floatX)), + (1, 1), + False, + ), + ( + set_test_value(at.matrix(), np.array([[1.0, 2.0]], dtype=config.floatX)), + (1, 1), + True, + ), + ( + set_test_value(at.matrix(), np.array([[1.0, 2.0]], dtype=config.floatX)), + (1, None), + False, + ), + ], +) +def test_SpecifyShape(v, shape, fails): + g = SpecifyShape()(v, *shape) + g_fg = FunctionGraph(outputs=[g]) + cm = contextlib.suppress() if not fails else pytest.raises(AssertionError) + with cm: + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, (SharedVariable, Constant)) + ], + ) + + +@pytest.mark.parametrize( + "v", + [ + set_test_value(at.vector(), np.arange(4, dtype=config.floatX)), + ], +) +def test_ViewOp(v): + g = ViewOp()(v) + g_fg = FunctionGraph(outputs=[g]) + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, (SharedVariable, Constant)) + ], + ) + + +@pytest.mark.parametrize( + "inputs, op, exc", + [ + ( + [ + set_test_value( + at.matrix(), rng.random(size=(2, 3)).astype(config.floatX) + ), + set_test_value(at.lmatrix(), rng.poisson(size=(2, 3))), + ], + MySingleOut, + UserWarning, + ), + ( + [ + set_test_value( + at.matrix(), rng.random(size=(2, 3)).astype(config.floatX) + ), + set_test_value(at.lmatrix(), rng.poisson(size=(2, 3))), + ], + MyMultiOut, + UserWarning, + ), + ], +) +def test_perform(inputs, op, exc): + + g = op()(*inputs) + + if isinstance(g, list): + g_fg = FunctionGraph(outputs=g) + else: + g_fg = FunctionGraph(outputs=[g]) + + cm = contextlib.suppress() if exc is None else pytest.warns(exc) + with cm: + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, (SharedVariable, Constant)) + ], + ) + + +def test_perform_params(): + """This tests for `Op.perform` implementations that require the `params` arguments.""" + + x = at.vector() + x.tag.test_value = np.array([1.0, 2.0], dtype=config.floatX) + + out = assert_op(x, np.array(True)) + + if not isinstance(out, (list, tuple)): + out = [out] + + out_fg = FunctionGraph([x], out) + + with pytest.warns(UserWarning, match=".*object mode.*"): + compare_numba_and_py(out_fg, [get_test_value(i) for i in out_fg.inputs]) + + +def test_perform_type_convert(): + """This tests the use of `Type.filter` in `objmode`. + + The `Op.perform` takes a single input that it returns as-is, but it gets a + native scalar and it's supposed to return an `np.ndarray`. + """ + + x = at.vector() + x.tag.test_value = np.array([1.0, 2.0], dtype=config.floatX) + + out = assert_op(x.sum(), np.array(True)) + + if not isinstance(out, (list, tuple)): + out = [out] + + out_fg = FunctionGraph([x], out) + + with pytest.warns(UserWarning, match=".*object mode.*"): + compare_numba_and_py(out_fg, [get_test_value(i) for i in out_fg.inputs]) + + +@pytest.mark.parametrize( + "x, y, exc", + [ + ( + set_test_value(at.matrix(), rng.random(size=(3, 2)).astype(config.floatX)), + set_test_value(at.vector(), rng.random(size=(2,)).astype(config.floatX)), + None, + ), + ( + set_test_value( + at.matrix(dtype="float64"), rng.random(size=(3, 2)).astype("float64") + ), + set_test_value( + at.vector(dtype="float32"), rng.random(size=(2,)).astype("float32") + ), + None, + ), + ( + set_test_value(at.lmatrix(), rng.poisson(size=(3, 2))), + set_test_value(at.fvector(), rng.random(size=(2,)).astype("float32")), + None, + ), + ( + set_test_value(at.lvector(), rng.random(size=(2,)).astype(np.int64)), + set_test_value(at.lvector(), rng.random(size=(2,)).astype(np.int64)), + None, + ), + ], +) +def test_Dot(x, y, exc): + g = aem.Dot()(x, y) + g_fg = FunctionGraph(outputs=[g]) + + cm = contextlib.suppress() if exc is None else pytest.warns(exc) + with cm: + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, (SharedVariable, Constant)) + ], + ) + + +@pytest.mark.parametrize( + "x, exc", + [ + ( + set_test_value(aes.float64(), np.array(0.0, dtype="float64")), + None, + ), + ( + set_test_value(aes.float64(), np.array(-32.0, dtype="float64")), + None, + ), + ( + set_test_value(aes.float64(), np.array(-40.0, dtype="float64")), + None, + ), + ( + set_test_value(aes.float64(), np.array(32.0, dtype="float64")), + None, + ), + ( + set_test_value(aes.float64(), np.array(40.0, dtype="float64")), + None, + ), + ( + set_test_value(aes.int64(), np.array(32, dtype="int64")), + None, + ), + ], +) +def test_Softplus(x, exc): + g = aesm.Softplus(aes.upgrade_to_float)(x) + g_fg = FunctionGraph(outputs=[g]) + + cm = contextlib.suppress() if exc is None else pytest.warns(exc) + with cm: + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, (SharedVariable, Constant)) + ], + ) + + +@pytest.mark.parametrize( + "x, y, exc", + [ + ( + set_test_value( + at.dmatrix(), + rng.random(size=(3, 3)).astype("float64"), + ), + set_test_value( + at.dmatrix(), + rng.random(size=(3, 3)).astype("float64"), + ), + None, + ), + ( + set_test_value( + at.dmatrix(), + rng.random(size=(3, 3)).astype("float64"), + ), + set_test_value( + at.lmatrix(), + rng.poisson(size=(3, 3)).astype("int64"), + ), + None, + ), + ], +) +def test_BatchedDot(x, y, exc): + g = blas.BatchedDot()(x, y) + + if isinstance(g, list): + g_fg = FunctionGraph(outputs=g) + else: + g_fg = FunctionGraph(outputs=[g]) + + cm = contextlib.suppress() if exc is None else pytest.warns(exc) + with cm: + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, (SharedVariable, Constant)) + ], + ) + + +def test_shared(): + a = shared(np.array([1, 2, 3], dtype=config.floatX)) + + aesara_numba_fn = function([], a, mode="NUMBA") + numba_res = aesara_numba_fn() + + np.testing.assert_allclose(numba_res, a.get_value()) + + aesara_numba_fn = function([], a * 2, mode="NUMBA") + numba_res = aesara_numba_fn() + + np.testing.assert_allclose(numba_res, a.get_value() * 2) + + # Changed the shared value and make sure that the Numba-compiled function + # also changes. + new_a_value = np.array([3, 4, 5], dtype=config.floatX) + a.set_value(new_a_value) + + numba_res = aesara_numba_fn() + np.testing.assert_allclose(numba_res, new_a_value * 2) + + +def test_shared_updates(): + a = shared(0) + + aesara_numba_fn = function([], a, updates={a: a + 1}, mode="NUMBA") + res1, res2 = aesara_numba_fn(), aesara_numba_fn() + assert res1 == 0 + assert res2 == 1 + assert a.get_value() == 2 + + a.set_value(5) + res1, res2 = aesara_numba_fn(), aesara_numba_fn() + assert res1 == 5 + assert res2 == 6 + assert a.get_value() == 7 + + +# We were seeing some weird results in CI where the following two almost +# sign-swapped results were being return from Numba and Python, respectively. +# The issue might be related to https://github.com/numba/numba/issues/4519. +# Regardless, I was not able to reproduce anything like it locally after +# extensive testing. +x = np.array( + [ + [-0.60407637, -0.71177603, -0.35842241], + [-0.07735968, 0.50000561, -0.86256007], + [-0.7931628, 0.49332471, 0.35710434], + ], + dtype=np.float64, +) + +y = np.array( + [ + [0.60407637, 0.71177603, -0.35842241], + [0.07735968, -0.50000561, -0.86256007], + [0.7931628, -0.49332471, 0.35710434], + ], + dtype=np.float64, +) + + +@pytest.mark.parametrize( + "inputs, cond_fn, true_vals, false_vals", + [ + ([], lambda: np.array(True), np.r_[1, 2, 3], np.r_[-1, -2, -3]), + ( + [set_test_value(at.dscalar(), np.array(0.2, dtype=np.float64))], + lambda x: x < 0.5, + np.r_[1, 2, 3], + np.r_[-1, -2, -3], + ), + ( + [ + set_test_value(at.dscalar(), np.array(0.3, dtype=np.float64)), + set_test_value(at.dscalar(), np.array(0.5, dtype=np.float64)), + ], + lambda x, y: x > y, + x, + y, + ), + ( + [ + set_test_value(at.dvector(), np.array([0.3, 0.1], dtype=np.float64)), + set_test_value(at.dvector(), np.array([0.5, 0.9], dtype=np.float64)), + ], + lambda x, y: at.all(x > y), + x, + y, + ), + ( + [ + set_test_value(at.dvector(), np.array([0.3, 0.1], dtype=np.float64)), + set_test_value(at.dvector(), np.array([0.5, 0.9], dtype=np.float64)), + ], + lambda x, y: at.all(x > y), + [x, 2 * x], + [y, 3 * y], + ), + ( + [ + set_test_value(at.dvector(), np.array([0.5, 0.9], dtype=np.float64)), + set_test_value(at.dvector(), np.array([0.3, 0.1], dtype=np.float64)), + ], + lambda x, y: at.all(x > y), + [x, 2 * x], + [y, 3 * y], + ), + ], +) +def test_IfElse(inputs, cond_fn, true_vals, false_vals): + + out = ifelse(cond_fn(*inputs), true_vals, false_vals) + + if not isinstance(out, list): + out = [out] + + out_fg = FunctionGraph(inputs, out) + + compare_numba_and_py(out_fg, [get_test_value(i) for i in out_fg.inputs]) + + +@pytest.mark.xfail(reason="https://github.com/numba/numba/issues/7409") +def test_config_options_parallel(): + x = at.dvector() + + with config.change_flags(numba__vectorize_target="parallel"): + aesara_numba_fn = function([x], x * 2, mode=numba_mode) + numba_mul_fn = aesara_numba_fn.vm.jit_fn.py_func.__globals__["mul"] + assert numba_mul_fn.targetoptions["parallel"] is True + + +def test_config_options_fastmath(): + x = at.dvector() + + with config.change_flags(numba__fastmath=True): + aesara_numba_fn = function([x], x * 2, mode=numba_mode) + numba_mul_fn = aesara_numba_fn.vm.jit_fn.py_func.__globals__["mul"] + assert numba_mul_fn.targetoptions["fastmath"] is True + + +def test_config_options_cached(): + x = at.dvector() + + with config.change_flags(numba__cache=True): + aesara_numba_fn = function([x], x * 2, mode=numba_mode) + numba_mul_fn = aesara_numba_fn.vm.jit_fn.py_func.__globals__["mul"] + assert not isinstance( + numba_mul_fn._dispatcher.cache, numba.core.caching.NullCache + ) + + with config.change_flags(numba__cache=False): + aesara_numba_fn = function([x], x * 2, mode=numba_mode) + numba_mul_fn = aesara_numba_fn.vm.jit_fn.py_func.__globals__["mul"] + assert isinstance(numba_mul_fn._dispatcher.cache, numba.core.caching.NullCache) + + +def test_scalar_return_value_conversion(): + r"""Make sure that we convert \"native\" scalars to `ndarray`\s in the graph outputs.""" + x = at.scalar(name="x") + x_fn = function( + [x], + 2 * x, + mode=numba_mode, + ) + assert isinstance(x_fn(1.0), np.ndarray) diff --git a/tests/link/numba/test_elemwise.py b/tests/link/numba/test_elemwise.py new file mode 100644 index 0000000000..c7b30c8e6a --- /dev/null +++ b/tests/link/numba/test_elemwise.py @@ -0,0 +1,509 @@ +import contextlib + +import numpy as np +import pytest + +import aesara.tensor as at +import aesara.tensor.inplace as ati +import aesara.tensor.math as aem +import aesara.tensor.nnet.basic as nnetb +from aesara import config +from aesara.compile.ops import deep_copy_op +from aesara.compile.sharedvalue import SharedVariable +from aesara.graph.basic import Constant +from aesara.graph.fg import FunctionGraph +from aesara.tensor import elemwise as at_elemwise +from aesara.tensor.math import All, Any, Max, Mean, Min, Prod, ProdWithoutZeros, Sum +from tests.link.numba.test_basic import ( + compare_numba_and_py, + my_multi_out, + set_test_value, +) + + +rng = np.random.default_rng(42849) + + +@pytest.mark.parametrize( + "inputs, input_vals, output_fn, exc", + [ + ( + [at.vector()], + [rng.uniform(size=100).astype(config.floatX)], + lambda x: at.gammaln(x), + None, + ), + ( + [at.vector()], + [rng.standard_normal(100).astype(config.floatX)], + lambda x: at.sigmoid(x), + None, + ), + ( + [at.vector()], + [rng.standard_normal(100).astype(config.floatX)], + lambda x: at.log1mexp(x), + None, + ), + ( + [at.vector()], + [rng.standard_normal(100).astype(config.floatX)], + lambda x: at.erf(x), + None, + ), + ( + [at.vector()], + [rng.standard_normal(100).astype(config.floatX)], + lambda x: at.erfc(x), + None, + ), + ( + [at.vector() for i in range(4)], + [rng.standard_normal(100).astype(config.floatX) for i in range(4)], + lambda x, y, x1, y1: (x + y) * (x1 + y1) * y, + None, + ), + ( + [at.matrix(), at.scalar()], + [rng.normal(size=(2, 2)).astype(config.floatX), 0.0], + lambda a, b: at.switch(a, b, a), + None, + ), + ( + [at.scalar(), at.scalar()], + [ + np.array(1.0, dtype=config.floatX), + np.array(1.0, dtype=config.floatX), + ], + lambda x, y: ati.add_inplace(deep_copy_op(x), deep_copy_op(y)), + None, + ), + ( + [at.vector(), at.vector()], + [ + rng.standard_normal(100).astype(config.floatX), + rng.standard_normal(100).astype(config.floatX), + ], + lambda x, y: ati.add_inplace(deep_copy_op(x), deep_copy_op(y)), + None, + ), + ( + [at.vector(), at.vector()], + [ + rng.standard_normal(100).astype(config.floatX), + rng.standard_normal(100).astype(config.floatX), + ], + lambda x, y: my_multi_out(x, y), + NotImplementedError, + ), + ], +) +def test_Elemwise(inputs, input_vals, output_fn, exc): + + outputs = output_fn(*inputs) + + out_fg = FunctionGraph( + outputs=[outputs] if not isinstance(outputs, list) else outputs + ) + + cm = contextlib.suppress() if exc is None else pytest.raises(exc) + with cm: + compare_numba_and_py(out_fg, input_vals) + + +@pytest.mark.parametrize( + "v, new_order", + [ + # `{'drop': [], 'shuffle': [], 'augment': [0, 1]}` + ( + set_test_value( + at.lscalar(name="a"), + np.array(1, dtype=np.int64), + ), + ("x", "x"), + ), + # I.e. `a_at.T` + # `{'drop': [], 'shuffle': [1, 0], 'augment': []}` + ( + set_test_value( + at.matrix("a"), np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX) + ), + (1, 0), + ), + # `{'drop': [], 'shuffle': [0, 1], 'augment': [2]}` + ( + set_test_value( + at.matrix("a"), np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX) + ), + (1, 0, "x"), + ), + # `{'drop': [1], 'shuffle': [2, 0], 'augment': [0, 2, 4]}` + ( + set_test_value( + at.tensor(config.floatX, [False, True, False], name="a"), + np.array([[[1.0, 2.0]], [[3.0, 4.0]]], dtype=config.floatX), + ), + ("x", 2, "x", 0, "x"), + ), + # I.e. `a_at.dimshuffle((0,))` + # `{'drop': [1], 'shuffle': [0], 'augment': []}` + ( + set_test_value( + at.tensor(config.floatX, [False, True], name="a"), + np.array([[1.0], [2.0], [3.0], [4.0]], dtype=config.floatX), + ), + (0,), + ), + ( + set_test_value( + at.tensor(config.floatX, [False, True], name="a"), + np.array([[1.0], [2.0], [3.0], [4.0]], dtype=config.floatX), + ), + (0,), + ), + ( + set_test_value( + at.tensor(config.floatX, [True, True, True], name="a"), + np.array([[[1.0]]], dtype=config.floatX), + ), + (), + ), + ], +) +def test_Dimshuffle(v, new_order): + g = at_elemwise.DimShuffle(v.broadcastable, new_order)(v) + g_fg = FunctionGraph(outputs=[g]) + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, (SharedVariable, Constant)) + ], + ) + + +@pytest.mark.parametrize( + "careduce_fn, axis, v", + [ + ( + lambda x, axis=None, dtype=None, acc_dtype=None: Sum( + axis=axis, dtype=dtype, acc_dtype=acc_dtype + )(x), + 0, + set_test_value(at.vector(), np.arange(3, dtype=config.floatX)), + ), + ( + lambda x, axis=None, dtype=None, acc_dtype=None: All(axis)(x), + 0, + set_test_value(at.vector(), np.arange(3, dtype=config.floatX)), + ), + ( + lambda x, axis=None, dtype=None, acc_dtype=None: Any(axis)(x), + 0, + set_test_value(at.vector(), np.arange(3, dtype=config.floatX)), + ), + ( + lambda x, axis=None, dtype=None, acc_dtype=None: Mean(axis)(x), + 0, + set_test_value(at.vector(), np.arange(3, dtype=config.floatX)), + ), + ( + lambda x, axis=None, dtype=None, acc_dtype=None: Mean(axis)(x), + 0, + set_test_value( + at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2)) + ), + ), + ( + lambda x, axis=None, dtype=None, acc_dtype=None: Sum( + axis=axis, dtype=dtype, acc_dtype=acc_dtype + )(x), + 0, + set_test_value( + at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2)) + ), + ), + ( + lambda x, axis=None, dtype=None, acc_dtype=None: Sum( + axis=axis, dtype=dtype, acc_dtype=acc_dtype + )(x), + (0, 1), + set_test_value( + at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2)) + ), + ), + ( + lambda x, axis=None, dtype=None, acc_dtype=None: Sum( + axis=axis, dtype=dtype, acc_dtype=acc_dtype + )(x), + (1, 0), + set_test_value( + at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2)) + ), + ), + ( + lambda x, axis=None, dtype=None, acc_dtype=None: Sum( + axis=axis, dtype=dtype, acc_dtype=acc_dtype + )(x), + None, + set_test_value( + at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2)) + ), + ), + ( + lambda x, axis=None, dtype=None, acc_dtype=None: Sum( + axis=axis, dtype=dtype, acc_dtype=acc_dtype + )(x), + 1, + set_test_value( + at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2)) + ), + ), + ( + lambda x, axis=None, dtype=None, acc_dtype=None: Prod( + axis=axis, dtype=dtype, acc_dtype=acc_dtype + )(x), + 0, + set_test_value(at.vector(), np.arange(3, dtype=config.floatX)), + ), + ( + lambda x, axis=None, dtype=None, acc_dtype=None: ProdWithoutZeros( + axis=axis, dtype=dtype, acc_dtype=acc_dtype + )(x), + 0, + set_test_value(at.vector(), np.arange(3, dtype=config.floatX)), + ), + ( + lambda x, axis=None, dtype=None, acc_dtype=None: Prod( + axis=axis, dtype=dtype, acc_dtype=acc_dtype + )(x), + 0, + set_test_value( + at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2)) + ), + ), + ( + lambda x, axis=None, dtype=None, acc_dtype=None: Prod( + axis=axis, dtype=dtype, acc_dtype=acc_dtype + )(x), + 1, + set_test_value( + at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2)) + ), + ), + ( + lambda x, axis=None, dtype=None, acc_dtype=None: Max(axis)(x), + None, + set_test_value( + at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2)) + ), + ), + ( + lambda x, axis=None, dtype=None, acc_dtype=None: Max(axis)(x), + None, + set_test_value( + at.lmatrix(), np.arange(3 * 2, dtype=np.int64).reshape((3, 2)) + ), + ), + ( + lambda x, axis=None, dtype=None, acc_dtype=None: Min(axis)(x), + None, + set_test_value( + at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2)) + ), + ), + ( + lambda x, axis=None, dtype=None, acc_dtype=None: Min(axis)(x), + None, + set_test_value( + at.lmatrix(), np.arange(3 * 2, dtype=np.int64).reshape((3, 2)) + ), + ), + ], +) +def test_CAReduce(careduce_fn, axis, v): + g = careduce_fn(v, axis=axis) + g_fg = FunctionGraph(outputs=[g]) + + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, (SharedVariable, Constant)) + ], + ) + + +def test_scalar_Elemwise_Clip(): + a = at.scalar("a") + b = at.scalar("b") + + z = at.switch(1, a, b) + c = at.clip(z, 1, 3) + c_fg = FunctionGraph(outputs=[c]) + + compare_numba_and_py(c_fg, [1, 1]) + + +@pytest.mark.parametrize( + "dy, sm, axis, exc", + [ + ( + set_test_value( + at.matrix(), np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX) + ), + set_test_value(at.matrix(), rng.random(size=(2, 3)).astype(config.floatX)), + None, + None, + ), + ( + set_test_value( + at.matrix(), np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX) + ), + set_test_value(at.matrix(), rng.random(size=(2, 3)).astype(config.floatX)), + 0, + None, + ), + ( + set_test_value( + at.matrix(), np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX) + ), + set_test_value(at.matrix(), rng.random(size=(2, 3)).astype(config.floatX)), + 1, + None, + ), + ], +) +def test_SoftmaxGrad(dy, sm, axis, exc): + g = nnetb.SoftmaxGrad(axis=axis)(dy, sm) + g_fg = FunctionGraph(outputs=[g]) + + cm = contextlib.suppress() if exc is None else pytest.warns(exc) + with cm: + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, (SharedVariable, Constant)) + ], + ) + + +@pytest.mark.parametrize( + "x, axis, exc", + [ + ( + set_test_value(at.vector(), rng.random(size=(2,)).astype(config.floatX)), + None, + None, + ), + ( + set_test_value(at.matrix(), rng.random(size=(2, 3)).astype(config.floatX)), + None, + None, + ), + ( + set_test_value(at.matrix(), rng.random(size=(2, 3)).astype(config.floatX)), + 0, + None, + ), + ], +) +def test_Softmax(x, axis, exc): + g = nnetb.Softmax(axis=axis)(x) + g_fg = FunctionGraph(outputs=[g]) + + cm = contextlib.suppress() if exc is None else pytest.warns(exc) + with cm: + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, (SharedVariable, Constant)) + ], + ) + + +@pytest.mark.parametrize( + "x, axis, exc", + [ + ( + set_test_value(at.vector(), rng.random(size=(2,)).astype(config.floatX)), + None, + None, + ), + ( + set_test_value(at.matrix(), rng.random(size=(2, 3)).astype(config.floatX)), + 0, + None, + ), + ( + set_test_value(at.matrix(), rng.random(size=(2, 3)).astype(config.floatX)), + 1, + None, + ), + ], +) +def test_LogSoftmax(x, axis, exc): + g = nnetb.LogSoftmax(axis=axis)(x) + g_fg = FunctionGraph(outputs=[g]) + + cm = contextlib.suppress() if exc is None else pytest.warns(exc) + with cm: + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, (SharedVariable, Constant)) + ], + ) + + +@pytest.mark.parametrize( + "x, axes, exc", + [ + ( + set_test_value(at.dscalar(), np.array(0.0, dtype="float64")), + [], + None, + ), + ( + set_test_value(at.dvector(), rng.random(size=(3,)).astype("float64")), + [0], + None, + ), + ( + set_test_value(at.dmatrix(), rng.random(size=(3, 2)).astype("float64")), + [0], + None, + ), + ( + set_test_value(at.dmatrix(), rng.random(size=(3, 2)).astype("float64")), + [0, 1], + None, + ), + ], +) +def test_MaxAndArgmax(x, axes, exc): + g = aem.MaxAndArgmax(axes)(x) + + if isinstance(g, list): + g_fg = FunctionGraph(outputs=g) + else: + g_fg = FunctionGraph(outputs=[g]) + + cm = contextlib.suppress() if exc is None else pytest.warns(exc) + with cm: + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, (SharedVariable, Constant)) + ], + ) diff --git a/tests/link/numba/test_extra_ops.py b/tests/link/numba/test_extra_ops.py new file mode 100644 index 0000000000..584b0831c3 --- /dev/null +++ b/tests/link/numba/test_extra_ops.py @@ -0,0 +1,480 @@ +import contextlib + +import numpy as np +import pytest + +import aesara.tensor as at +from aesara import config +from aesara.compile.sharedvalue import SharedVariable +from aesara.graph.basic import Constant +from aesara.graph.fg import FunctionGraph +from aesara.tensor import extra_ops +from tests.link.numba.test_basic import compare_numba_and_py, set_test_value + + +rng = np.random.default_rng(42849) + + +@pytest.mark.parametrize( + "val", + [ + set_test_value(at.lscalar(), np.array(6, dtype="int64")), + ], +) +def test_Bartlett(val): + g = extra_ops.bartlett(val) + g_fg = FunctionGraph(outputs=[g]) + + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, (SharedVariable, Constant)) + ], + ) + + +@pytest.mark.parametrize( + "x, shape", + [ + ( + set_test_value(at.vector(), rng.random(size=(2,)).astype(config.floatX)), + [set_test_value(at.lscalar(), np.array(v)) for v in [3, 2]], + ), + ( + set_test_value(at.vector(), rng.random(size=(2,)).astype(config.floatX)), + [at.as_tensor(3, dtype=np.int64), at.as_tensor(2, dtype=np.int64)], + ), + ( + set_test_value(at.vector(), rng.random(size=(2,)).astype(config.floatX)), + at.as_tensor([set_test_value(at.lscalar(), np.array(v)) for v in [3, 2]]), + ), + ( + set_test_value(at.vector(), rng.random(size=(2,)).astype(config.floatX)), + [at.as_tensor(3, dtype=np.int8), at.as_tensor(2, dtype=np.int64)], + ), + ], +) +def test_BroadcastTo(x, shape): + g = extra_ops.BroadcastTo()(x, shape) + g_fg = FunctionGraph(outputs=[g]) + + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, (SharedVariable, Constant)) + ], + ) + + +@pytest.mark.parametrize( + "val, axis, mode", + [ + ( + set_test_value( + at.matrix(), np.arange(3, dtype=config.floatX).reshape((3, 1)) + ), + 1, + "add", + ), + ( + set_test_value( + at.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2)) + ), + 0, + "add", + ), + ( + set_test_value( + at.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2)) + ), + 1, + "add", + ), + ( + set_test_value( + at.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2)) + ), + 0, + "mul", + ), + ( + set_test_value( + at.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2)) + ), + 1, + "mul", + ), + ], +) +def test_CumOp(val, axis, mode): + g = extra_ops.CumOp(axis=axis, mode=mode)(val) + g_fg = FunctionGraph(outputs=[g]) + + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, (SharedVariable, Constant)) + ], + ) + + +@pytest.mark.parametrize( + "a, val", + [ + ( + set_test_value(at.lmatrix(), np.zeros((10, 2), dtype="int64")), + set_test_value(at.lscalar(), np.array(1, dtype="int64")), + ) + ], +) +def test_FillDiagonal(a, val): + g = extra_ops.FillDiagonal()(a, val) + g_fg = FunctionGraph(outputs=[g]) + + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, (SharedVariable, Constant)) + ], + ) + + +@pytest.mark.parametrize( + "a, val, offset", + [ + ( + set_test_value(at.lmatrix(), np.zeros((10, 2), dtype="int64")), + set_test_value(at.lscalar(), np.array(1, dtype="int64")), + set_test_value(at.lscalar(), np.array(-1, dtype="int64")), + ), + ( + set_test_value(at.lmatrix(), np.zeros((10, 2), dtype="int64")), + set_test_value(at.lscalar(), np.array(1, dtype="int64")), + set_test_value(at.lscalar(), np.array(0, dtype="int64")), + ), + ( + set_test_value(at.lmatrix(), np.zeros((10, 3), dtype="int64")), + set_test_value(at.lscalar(), np.array(1, dtype="int64")), + set_test_value(at.lscalar(), np.array(1, dtype="int64")), + ), + ], +) +def test_FillDiagonalOffset(a, val, offset): + g = extra_ops.FillDiagonalOffset()(a, val, offset) + g_fg = FunctionGraph(outputs=[g]) + + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, (SharedVariable, Constant)) + ], + ) + + +@pytest.mark.parametrize( + "arr, shape, mode, order, exc", + [ + ( + tuple(set_test_value(at.lscalar(), v) for v in np.array([0])), + set_test_value(at.lvector(), np.array([2])), + "raise", + "C", + None, + ), + ( + tuple(set_test_value(at.lscalar(), v) for v in np.array([0, 0, 3])), + set_test_value(at.lvector(), np.array([2, 3, 4])), + "raise", + "C", + None, + ), + ( + tuple( + set_test_value(at.lvector(), v) + for v in np.array([[0, 1], [2, 0], [1, 3]]) + ), + set_test_value(at.lvector(), np.array([2, 3, 4])), + "raise", + "C", + None, + ), + ( + tuple( + set_test_value(at.lvector(), v) + for v in np.array([[0, 1], [2, 0], [1, 3]]) + ), + set_test_value(at.lvector(), np.array([2, 3, 4])), + "raise", + "F", + NotImplementedError, + ), + ( + tuple( + set_test_value(at.lvector(), v) + for v in np.array([[0, 1, 2], [2, 0, 3], [1, 3, 5]]) + ), + set_test_value(at.lvector(), np.array([2, 3, 4])), + "raise", + "C", + ValueError, + ), + ( + tuple( + set_test_value(at.lvector(), v) + for v in np.array([[0, 1, 2], [2, 0, 3], [1, 3, 5]]) + ), + set_test_value(at.lvector(), np.array([2, 3, 4])), + "wrap", + "C", + None, + ), + ( + tuple( + set_test_value(at.lvector(), v) + for v in np.array([[0, 1, 2], [2, 0, 3], [1, 3, 5]]) + ), + set_test_value(at.lvector(), np.array([2, 3, 4])), + "clip", + "C", + None, + ), + ], +) +def test_RavelMultiIndex(arr, shape, mode, order, exc): + g = extra_ops.RavelMultiIndex(mode, order)(*(arr + (shape,))) + g_fg = FunctionGraph(outputs=[g]) + + cm = contextlib.suppress() if exc is None else pytest.raises(exc) + with cm: + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, (SharedVariable, Constant)) + ], + ) + + +@pytest.mark.parametrize( + "x, repeats, axis, exc", + [ + ( + set_test_value(at.lscalar(), np.array(1, dtype="int64")), + set_test_value(at.lscalar(), np.array(0, dtype="int64")), + None, + None, + ), + ( + set_test_value(at.lmatrix(), np.zeros((2, 2), dtype="int64")), + set_test_value(at.lscalar(), np.array(1, dtype="int64")), + None, + None, + ), + ( + set_test_value(at.lvector(), np.arange(2, dtype="int64")), + set_test_value(at.lvector(), np.array([1, 1], dtype="int64")), + None, + None, + ), + ( + set_test_value(at.lmatrix(), np.zeros((2, 2), dtype="int64")), + set_test_value(at.lscalar(), np.array(1, dtype="int64")), + 0, + UserWarning, + ), + ], +) +def test_Repeat(x, repeats, axis, exc): + g = extra_ops.Repeat(axis)(x, repeats) + g_fg = FunctionGraph(outputs=[g]) + + cm = contextlib.suppress() if exc is None else pytest.warns(exc) + with cm: + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, (SharedVariable, Constant)) + ], + ) + + +@pytest.mark.parametrize( + "x, axis, return_index, return_inverse, return_counts, exc", + [ + ( + set_test_value(at.lscalar(), np.array(1, dtype="int64")), + None, + False, + False, + False, + None, + ), + ( + set_test_value(at.lvector(), np.array([1, 1, 2], dtype="int64")), + None, + False, + False, + False, + None, + ), + ( + set_test_value(at.lmatrix(), np.array([[1, 1], [2, 2]], dtype="int64")), + None, + False, + False, + False, + None, + ), + ( + set_test_value( + at.lmatrix(), np.array([[1, 1], [1, 1], [2, 2]], dtype="int64") + ), + 0, + False, + False, + False, + UserWarning, + ), + ( + set_test_value( + at.lmatrix(), np.array([[1, 1], [1, 1], [2, 2]], dtype="int64") + ), + 0, + True, + True, + True, + UserWarning, + ), + ], +) +def test_Unique(x, axis, return_index, return_inverse, return_counts, exc): + g = extra_ops.Unique(return_index, return_inverse, return_counts, axis)(x) + + if isinstance(g, list): + g_fg = FunctionGraph(outputs=g) + else: + g_fg = FunctionGraph(outputs=[g]) + + cm = contextlib.suppress() if exc is None else pytest.warns(exc) + with cm: + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, (SharedVariable, Constant)) + ], + ) + + +@pytest.mark.parametrize( + "arr, shape, order, exc", + [ + ( + set_test_value(at.lvector(), np.array([9, 15, 1], dtype="int64")), + at.as_tensor([2, 3, 4]), + "C", + None, + ), + ( + set_test_value(at.lvector(), np.array([1, 0], dtype="int64")), + at.as_tensor([2]), + "C", + None, + ), + ( + set_test_value(at.lvector(), np.array([9, 15, 1], dtype="int64")), + at.as_tensor([2, 3, 4]), + "F", + NotImplementedError, + ), + ], +) +def test_UnravelIndex(arr, shape, order, exc): + g = extra_ops.UnravelIndex(order)(arr, shape) + + if isinstance(g, list): + g_fg = FunctionGraph(outputs=g) + else: + g_fg = FunctionGraph(outputs=[g]) + + cm = contextlib.suppress() if exc is None else pytest.raises(exc) + with cm: + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, (SharedVariable, Constant)) + ], + ) + + +@pytest.mark.parametrize( + "a, v, side, sorter, exc", + [ + ( + set_test_value(at.vector(), np.array([1.0, 2.0, 3.0], dtype=config.floatX)), + set_test_value(at.matrix(), rng.random((3, 2)).astype(config.floatX)), + "left", + None, + None, + ), + pytest.param( + set_test_value( + at.vector(), + np.array([0.29769574, 0.71649186, 0.20475563]).astype(config.floatX), + ), + set_test_value( + at.matrix(), + np.array( + [ + [0.18847123, 0.39659508], + [0.56220006, 0.57428752], + [0.86720994, 0.44522637], + ] + ).astype(config.floatX), + ), + "left", + None, + None, + marks=pytest.mark.xfail( + reason="This won't work until https://github.com/numba/numba/pull/7005 is merged" + ), + ), + ( + set_test_value(at.vector(), np.array([1.0, 2.0, 3.0], dtype=config.floatX)), + set_test_value(at.matrix(), rng.random((3, 2)).astype(config.floatX)), + "right", + set_test_value(at.lvector(), np.array([0, 2, 1])), + UserWarning, + ), + ], +) +def test_Searchsorted(a, v, side, sorter, exc): + g = extra_ops.SearchsortedOp(side)(a, v, sorter) + g_fg = FunctionGraph(outputs=[g]) + + cm = contextlib.suppress() if exc is None else pytest.warns(exc) + with cm: + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, (SharedVariable, Constant)) + ], + ) diff --git a/tests/link/numba/test_nlinalg.py b/tests/link/numba/test_nlinalg.py new file mode 100644 index 0000000000..da427676b5 --- /dev/null +++ b/tests/link/numba/test_nlinalg.py @@ -0,0 +1,501 @@ +import contextlib + +import numpy as np +import pytest + +import aesara.tensor as at +from aesara.compile.sharedvalue import SharedVariable +from aesara.graph.basic import Constant +from aesara.graph.fg import FunctionGraph +from aesara.tensor import nlinalg, slinalg +from tests.link.numba.test_basic import compare_numba_and_py, set_test_value + + +rng = np.random.default_rng(42849) + + +@pytest.mark.parametrize( + "x, lower, exc", + [ + ( + set_test_value( + at.dmatrix(), + (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), + ), + True, + None, + ), + ( + set_test_value( + at.lmatrix(), + (lambda x: x.T.dot(x))( + rng.integers(1, 10, size=(3, 3)).astype("int64") + ), + ), + True, + None, + ), + ( + set_test_value( + at.dmatrix(), + (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), + ), + False, + UserWarning, + ), + ], +) +def test_Cholesky(x, lower, exc): + g = slinalg.Cholesky(lower)(x) + + if isinstance(g, list): + g_fg = FunctionGraph(outputs=g) + else: + g_fg = FunctionGraph(outputs=[g]) + + cm = contextlib.suppress() if exc is None else pytest.warns(exc) + with cm: + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, (SharedVariable, Constant)) + ], + ) + + +@pytest.mark.parametrize( + "A, x, lower, exc", + [ + ( + set_test_value( + at.dmatrix(), + (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), + ), + set_test_value(at.dvector(), rng.random(size=(3,)).astype("float64")), + "gen", + None, + ), + ( + set_test_value( + at.lmatrix(), + (lambda x: x.T.dot(x))( + rng.integers(1, 10, size=(3, 3)).astype("int64") + ), + ), + set_test_value(at.dvector(), rng.random(size=(3,)).astype("float64")), + "gen", + None, + ), + ], +) +def test_Solve(A, x, lower, exc): + g = slinalg.Solve(lower)(A, x) + + if isinstance(g, list): + g_fg = FunctionGraph(outputs=g) + else: + g_fg = FunctionGraph(outputs=[g]) + + cm = contextlib.suppress() if exc is None else pytest.warns(exc) + with cm: + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, (SharedVariable, Constant)) + ], + ) + + +@pytest.mark.parametrize( + "A, x, lower, exc", + [ + ( + set_test_value( + at.dmatrix(), + (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), + ), + set_test_value(at.dvector(), rng.random(size=(3,)).astype("float64")), + "sym", + UserWarning, + ), + ], +) +def test_SolveTriangular(A, x, lower, exc): + g = slinalg.SolveTriangular(lower)(A, x) + + if isinstance(g, list): + g_fg = FunctionGraph(outputs=g) + else: + g_fg = FunctionGraph(outputs=[g]) + + cm = contextlib.suppress() if exc is None else pytest.warns(exc) + with cm: + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, (SharedVariable, Constant)) + ], + ) + + +@pytest.mark.parametrize( + "x, exc", + [ + ( + set_test_value( + at.dmatrix(), + (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), + ), + None, + ), + ( + set_test_value( + at.lmatrix(), + (lambda x: x.T.dot(x))(rng.poisson(size=(3, 3)).astype("int64")), + ), + None, + ), + ], +) +def test_Det(x, exc): + g = nlinalg.Det()(x) + g_fg = FunctionGraph(outputs=[g]) + + cm = contextlib.suppress() if exc is None else pytest.warns(exc) + with cm: + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, (SharedVariable, Constant)) + ], + ) + + +# We were seeing some weird results in CI where the following two almost +# sign-swapped results were being return from Numba and Python, respectively. +# The issue might be related to https://github.com/numba/numba/issues/4519. +# Regardless, I was not able to reproduce anything like it locally after +# extensive testing. +x = np.array( + [ + [-0.60407637, -0.71177603, -0.35842241], + [-0.07735968, 0.50000561, -0.86256007], + [-0.7931628, 0.49332471, 0.35710434], + ], + dtype=np.float64, +) + +y = np.array( + [ + [0.60407637, 0.71177603, -0.35842241], + [0.07735968, -0.50000561, -0.86256007], + [0.7931628, -0.49332471, 0.35710434], + ], + dtype=np.float64, +) + + +@pytest.mark.parametrize( + "x, exc", + [ + ( + set_test_value( + at.dmatrix(), + (lambda x: x.T.dot(x))(x), + ), + None, + ), + ( + set_test_value( + at.dmatrix(), + (lambda x: x.T.dot(x))(y), + ), + None, + ), + ( + set_test_value( + at.lmatrix(), + (lambda x: x.T.dot(x))( + rng.integers(1, 10, size=(3, 3)).astype("int64") + ), + ), + None, + ), + ], +) +def test_Eig(x, exc): + g = nlinalg.Eig()(x) + + if isinstance(g, list): + g_fg = FunctionGraph(outputs=g) + else: + g_fg = FunctionGraph(outputs=[g]) + + cm = contextlib.suppress() if exc is None else pytest.warns(exc) + with cm: + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, (SharedVariable, Constant)) + ], + ) + + +@pytest.mark.parametrize( + "x, uplo, exc", + [ + ( + set_test_value( + at.dmatrix(), + (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), + ), + "L", + None, + ), + ( + set_test_value( + at.lmatrix(), + (lambda x: x.T.dot(x))( + rng.integers(1, 10, size=(3, 3)).astype("int64") + ), + ), + "U", + UserWarning, + ), + ], +) +def test_Eigh(x, uplo, exc): + g = nlinalg.Eigh(uplo)(x) + + if isinstance(g, list): + g_fg = FunctionGraph(outputs=g) + else: + g_fg = FunctionGraph(outputs=[g]) + + cm = contextlib.suppress() if exc is None else pytest.warns(exc) + with cm: + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, (SharedVariable, Constant)) + ], + ) + + +@pytest.mark.parametrize( + "op, x, exc, op_args", + [ + ( + nlinalg.MatrixInverse, + set_test_value( + at.dmatrix(), + (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), + ), + None, + (), + ), + ( + nlinalg.MatrixInverse, + set_test_value( + at.lmatrix(), + (lambda x: x.T.dot(x))( + rng.integers(1, 10, size=(3, 3)).astype("int64") + ), + ), + None, + (), + ), + ( + nlinalg.Inv, + set_test_value( + at.dmatrix(), + (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), + ), + None, + (), + ), + ( + nlinalg.Inv, + set_test_value( + at.lmatrix(), + (lambda x: x.T.dot(x))( + rng.integers(1, 10, size=(3, 3)).astype("int64") + ), + ), + None, + (), + ), + ( + nlinalg.MatrixPinv, + set_test_value( + at.dmatrix(), + (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), + ), + None, + (True,), + ), + ( + nlinalg.MatrixPinv, + set_test_value( + at.lmatrix(), + (lambda x: x.T.dot(x))( + rng.integers(1, 10, size=(3, 3)).astype("int64") + ), + ), + None, + (False,), + ), + ], +) +def test_matrix_inverses(op, x, exc, op_args): + g = op(*op_args)(x) + g_fg = FunctionGraph(outputs=[g]) + + cm = contextlib.suppress() if exc is None else pytest.warns(exc) + with cm: + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, (SharedVariable, Constant)) + ], + ) + + +@pytest.mark.parametrize( + "x, mode, exc", + [ + ( + set_test_value( + at.dmatrix(), + (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), + ), + "reduced", + None, + ), + ( + set_test_value( + at.dmatrix(), + (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), + ), + "r", + None, + ), + ( + set_test_value( + at.lmatrix(), + (lambda x: x.T.dot(x))( + rng.integers(1, 10, size=(3, 3)).astype("int64") + ), + ), + "reduced", + None, + ), + ( + set_test_value( + at.lmatrix(), + (lambda x: x.T.dot(x))( + rng.integers(1, 10, size=(3, 3)).astype("int64") + ), + ), + "complete", + UserWarning, + ), + ], +) +def test_QRFull(x, mode, exc): + g = nlinalg.QRFull(mode)(x) + + if isinstance(g, list): + g_fg = FunctionGraph(outputs=g) + else: + g_fg = FunctionGraph(outputs=[g]) + + cm = contextlib.suppress() if exc is None else pytest.warns(exc) + with cm: + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, (SharedVariable, Constant)) + ], + ) + + +@pytest.mark.parametrize( + "x, full_matrices, compute_uv, exc", + [ + ( + set_test_value( + at.dmatrix(), + (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), + ), + True, + True, + None, + ), + ( + set_test_value( + at.dmatrix(), + (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), + ), + False, + True, + None, + ), + ( + set_test_value( + at.lmatrix(), + (lambda x: x.T.dot(x))( + rng.integers(1, 10, size=(3, 3)).astype("int64") + ), + ), + True, + True, + None, + ), + ( + set_test_value( + at.lmatrix(), + (lambda x: x.T.dot(x))( + rng.integers(1, 10, size=(3, 3)).astype("int64") + ), + ), + True, + False, + UserWarning, + ), + ], +) +def test_SVD(x, full_matrices, compute_uv, exc): + g = nlinalg.SVD(full_matrices, compute_uv)(x) + + if isinstance(g, list): + g_fg = FunctionGraph(outputs=g) + else: + g_fg = FunctionGraph(outputs=[g]) + + cm = contextlib.suppress() if exc is None else pytest.warns(exc) + with cm: + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, (SharedVariable, Constant)) + ], + ) diff --git a/tests/link/test_numba_performance.py b/tests/link/numba/test_performance.py similarity index 93% rename from tests/link/test_numba_performance.py rename to tests/link/numba/test_performance.py index 952ae4943f..2a7af04c73 100644 --- a/tests/link/test_numba_performance.py +++ b/tests/link/numba/test_performance.py @@ -7,12 +7,12 @@ from aesara import config from aesara.compile.function import function from aesara.compile.mode import Mode -from aesara.graph.optdb import OptimizationQuery +from aesara.graph.rewriting.db import RewriteDatabaseQuery from aesara.link.numba.linker import NumbaLinker from aesara.tensor.math import Max -opts = OptimizationQuery(include=[None], exclude=["cxx_only", "BlasOpt"]) +opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"]) numba_mode = Mode(NumbaLinker(), opts) py_mode = Mode("py", opts) diff --git a/tests/link/numba/test_random.py b/tests/link/numba/test_random.py new file mode 100644 index 0000000000..25d77a5a66 --- /dev/null +++ b/tests/link/numba/test_random.py @@ -0,0 +1,593 @@ +import contextlib + +import numpy as np +import pytest +import scipy.stats as stats + +import aesara.tensor as at +import aesara.tensor.random.basic as aer +from aesara import shared +from aesara.compile.function import function +from aesara.compile.sharedvalue import SharedVariable +from aesara.graph.basic import Constant +from aesara.graph.fg import FunctionGraph +from tests.link.numba.test_basic import ( + compare_numba_and_py, + eval_python_only, + numba_mode, + set_test_value, +) + + +rng = np.random.default_rng(42849) + + +@pytest.mark.parametrize( + "rv_op, dist_args, size", + [ + ( + aer.normal, + [ + set_test_value( + at.dvector(), + np.array([1.0, 2.0], dtype=np.float64), + ), + set_test_value( + at.dscalar(), + np.array(1.0, dtype=np.float64), + ), + ], + at.as_tensor([3, 2]), + ), + ( + aer.uniform, + [ + set_test_value( + at.dvector(), + np.array([1.0, 2.0], dtype=np.float64), + ), + set_test_value( + at.dscalar(), + np.array(1.0, dtype=np.float64), + ), + ], + at.as_tensor([3, 2]), + ), + ( + aer.triangular, + [ + set_test_value( + at.dscalar(), + np.array(-5.0, dtype=np.float64), + ), + set_test_value( + at.dscalar(), + np.array(1.0, dtype=np.float64), + ), + set_test_value( + at.dscalar(), + np.array(5.0, dtype=np.float64), + ), + ], + at.as_tensor([3, 2]), + ), + ( + aer.lognormal, + [ + set_test_value( + at.dvector(), + np.array([1.0, 2.0], dtype=np.float64), + ), + set_test_value( + at.dscalar(), + np.array(1.0, dtype=np.float64), + ), + ], + at.as_tensor([3, 2]), + ), + pytest.param( + aer.pareto, + [ + set_test_value( + at.dvector(), + np.array([1.0, 2.0], dtype=np.float64), + ), + ], + at.as_tensor([3, 2]), + marks=pytest.mark.xfail(reason="Not implemented"), + ), + ( + aer.exponential, + [ + set_test_value( + at.dvector(), + np.array([1.0, 2.0], dtype=np.float64), + ), + ], + at.as_tensor([3, 2]), + ), + ( + aer.weibull, + [ + set_test_value( + at.dvector(), + np.array([1.0, 2.0], dtype=np.float64), + ), + ], + at.as_tensor([3, 2]), + ), + ( + aer.logistic, + [ + set_test_value( + at.dvector(), + np.array([1.0, 2.0], dtype=np.float64), + ), + set_test_value( + at.dscalar(), + np.array(1.0, dtype=np.float64), + ), + ], + at.as_tensor([3, 2]), + ), + ( + aer.geometric, + [ + set_test_value( + at.dvector(), + np.array([0.3, 0.4], dtype=np.float64), + ), + ], + at.as_tensor([3, 2]), + ), + ( + aer.hypergeometric, + [ + set_test_value( + at.lscalar(), + np.array(7, dtype=np.int64), + ), + set_test_value( + at.lscalar(), + np.array(8, dtype=np.int64), + ), + set_test_value( + at.lscalar(), + np.array(15, dtype=np.int64), + ), + ], + at.as_tensor([3, 2]), + ), + ( + aer.wald, + [ + set_test_value( + at.dvector(), + np.array([1.0, 2.0], dtype=np.float64), + ), + set_test_value( + at.dscalar(), + np.array(1.0, dtype=np.float64), + ), + ], + at.as_tensor([3, 2]), + ), + ( + aer.laplace, + [ + set_test_value( + at.dvector(), + np.array([1.0, 2.0], dtype=np.float64), + ), + set_test_value( + at.dscalar(), + np.array(1.0, dtype=np.float64), + ), + ], + at.as_tensor([3, 2]), + ), + ( + aer.binomial, + [ + set_test_value( + at.lvector(), + np.array([1, 2], dtype=np.int64), + ), + set_test_value( + at.dscalar(), + np.array(0.9, dtype=np.float64), + ), + ], + at.as_tensor([3, 2]), + ), + ( + aer.normal, + [ + set_test_value( + at.lvector(), + np.array([1, 2], dtype=np.int64), + ), + set_test_value( + at.dscalar(), + np.array(1.0, dtype=np.float64), + ), + ], + at.as_tensor(tuple(set_test_value(at.lscalar(), v) for v in [3, 2])), + ), + ( + aer.poisson, + [ + set_test_value( + at.dvector(), + np.array([1.0, 2.0], dtype=np.float64), + ), + ], + None, + ), + ( + aer.halfnormal, + [ + set_test_value( + at.lvector(), + np.array([1, 2], dtype=np.int64), + ), + set_test_value( + at.dscalar(), + np.array(1.0, dtype=np.float64), + ), + ], + None, + ), + ( + aer.bernoulli, + [ + set_test_value( + at.dvector(), + np.array([0.1, 0.9], dtype=np.float64), + ), + ], + None, + ), + ( + aer.randint, + [ + set_test_value( + at.lscalar(), + np.array(0, dtype=np.int64), + ), + set_test_value( + at.lscalar(), + np.array(5, dtype=np.int64), + ), + ], + at.as_tensor([3, 2]), + ), + pytest.param( + aer.multivariate_normal, + [ + set_test_value( + at.dmatrix(), + np.array([[1, 2], [3, 4]], dtype=np.float64), + ), + set_test_value( + at.tensor("float64", [True, False, False]), + np.eye(2)[None, ...], + ), + ], + at.as_tensor(tuple(set_test_value(at.lscalar(), v) for v in [4, 3, 2])), + marks=pytest.mark.xfail(reason="Not implemented"), + ), + ], + ids=str, +) +def test_aligned_RandomVariable(rv_op, dist_args, size): + """Tests for Numba samplers that are one-to-one with Aesara's/NumPy's samplers.""" + rng = shared(np.random.RandomState(29402)) + g = rv_op(*dist_args, size=size, rng=rng) + g_fg = FunctionGraph(outputs=[g]) + + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, (SharedVariable, Constant)) + ], + ) + + +@pytest.mark.parametrize( + "rv_op, dist_args, base_size, cdf_name, params_conv", + [ + ( + aer.beta, + [ + set_test_value( + at.dvector(), + np.array([1.0, 2.0], dtype=np.float64), + ), + set_test_value( + at.dscalar(), + np.array(1.0, dtype=np.float64), + ), + ], + (2,), + "beta", + lambda *args: args, + ), + ( + aer.gamma, + [ + set_test_value( + at.dvector(), + np.array([1.0, 2.0], dtype=np.float64), + ), + set_test_value( + at.dscalar(), + np.array(1.0, dtype=np.float64), + ), + ], + (2,), + "gamma", + lambda a, b: (a, 0.0, b), + ), + ( + aer.cauchy, + [ + set_test_value( + at.dvector(), + np.array([1.0, 2.0], dtype=np.float64), + ), + set_test_value( + at.dscalar(), + np.array(1.0, dtype=np.float64), + ), + ], + (2,), + "cauchy", + lambda *args: args, + ), + ( + aer.chisquare, + [ + set_test_value( + at.dvector(), + np.array([1.0, 2.0], dtype=np.float64), + ) + ], + (2,), + "chi2", + lambda *args: args, + ), + ( + aer.gumbel, + [ + set_test_value( + at.dvector(), + np.array([1.0, 2.0], dtype=np.float64), + ), + set_test_value( + at.dscalar(), + np.array(1.0, dtype=np.float64), + ), + ], + (2,), + "gumbel_r", + lambda *args: args, + ), + ( + aer.negative_binomial, + [ + set_test_value( + at.lvector(), + np.array([100, 200], dtype=np.int64), + ), + set_test_value( + at.dscalar(), + np.array(0.09, dtype=np.float64), + ), + ], + (2,), + "nbinom", + lambda *args: args, + ), + pytest.param( + aer.vonmises, + [ + set_test_value( + at.dvector(), + np.array([-0.5, 0.5], dtype=np.float64), + ), + set_test_value( + at.dscalar(), + np.array(1.0, dtype=np.float64), + ), + ], + (2,), + "vonmises_line", + lambda mu, kappa: (kappa, mu), + marks=pytest.mark.xfail( + reason=( + "Numba's parameterization of `vonmises` does not match NumPy's." + "See https://github.com/numba/numba/issues/7886" + ) + ), + ), + ], +) +def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_conv): + """Tests for Numba samplers that are not one-to-one with Aesara's/NumPy's samplers.""" + rng = shared(np.random.RandomState(29402)) + g = rv_op(*dist_args, size=(2000,) + base_size, rng=rng) + g_fn = function(dist_args, g, mode=numba_mode) + samples = g_fn( + *[ + i.tag.test_value + for i in g_fn.maker.fgraph.inputs + if not isinstance(i, (SharedVariable, Constant)) + ] + ) + + bcast_dist_args = np.broadcast_arrays(*[i.tag.test_value for i in dist_args]) + + for idx in np.ndindex(*base_size): + cdf_params = params_conv(*tuple(arg[idx] for arg in bcast_dist_args)) + test_res = stats.cramervonmises( + samples[(Ellipsis,) + idx], cdf_name, args=cdf_params + ) + assert test_res.pvalue > 0.1 + + +@pytest.mark.parametrize( + "dist_args, size, cm", + [ + pytest.param( + [ + set_test_value( + at.dvector(), + np.array([100000, 1, 1], dtype=np.float64), + ), + ], + None, + contextlib.suppress(), + ), + pytest.param( + [ + set_test_value( + at.dmatrix(), + np.array( + [[100000, 1, 1], [1, 100000, 1], [1, 1, 100000]], + dtype=np.float64, + ), + ), + ], + (10, 3), + contextlib.suppress(), + ), + pytest.param( + [ + set_test_value( + at.dmatrix(), + np.array( + [[100000, 1, 1]], + dtype=np.float64, + ), + ), + ], + (5, 4, 3), + contextlib.suppress(), + ), + pytest.param( + [ + set_test_value( + at.dmatrix(), + np.array( + [[100000, 1, 1], [1, 100000, 1], [1, 1, 100000]], + dtype=np.float64, + ), + ), + ], + (10, 4), + pytest.raises( + ValueError, match="objects cannot be broadcast to a single shape" + ), + ), + ], +) +def test_CategoricalRV(dist_args, size, cm): + rng = shared(np.random.RandomState(29402)) + g = aer.categorical(*dist_args, size=size, rng=rng) + g_fg = FunctionGraph(outputs=[g]) + + with cm: + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, (SharedVariable, Constant)) + ], + ) + + +@pytest.mark.parametrize( + "a, size, cm", + [ + pytest.param( + set_test_value( + at.dvector(), + np.array([100000, 1, 1], dtype=np.float64), + ), + None, + contextlib.suppress(), + ), + pytest.param( + set_test_value( + at.dmatrix(), + np.array( + [[100000, 1, 1], [1, 100000, 1], [1, 1, 100000]], + dtype=np.float64, + ), + ), + (10, 3), + contextlib.suppress(), + ), + pytest.param( + set_test_value( + at.dmatrix(), + np.array( + [[100000, 1, 1], [1, 100000, 1], [1, 1, 100000]], + dtype=np.float64, + ), + ), + (10, 4), + pytest.raises(ValueError, match="Parameters shape.*"), + ), + ], +) +def test_DirichletRV(a, size, cm): + rng = shared(np.random.RandomState(29402)) + g = aer.dirichlet(a, size=size, rng=rng) + g_fn = function([a], g, mode=numba_mode) + + with cm: + a_val = a.tag.test_value + + # For coverage purposes only... + eval_python_only([a], FunctionGraph(outputs=[g], clone=False), [a_val]) + + all_samples = [] + for i in range(1000): + samples = g_fn(a_val) + all_samples.append(samples) + + exp_res = a_val / a_val.sum(-1) + res = np.mean(all_samples, axis=tuple(range(0, a_val.ndim - 1))) + assert np.allclose(res, exp_res, atol=1e-4) + + +def test_RandomState_updates(): + rng = shared(np.random.RandomState(1)) + rng_new = shared(np.random.RandomState(2)) + + x = at.random.normal(size=10, rng=rng) + res = function([], x, updates={rng: rng_new}, mode=numba_mode)() + + ref = np.random.RandomState(2).normal(size=10) + assert np.allclose(res, ref) + + +def test_random_Generator(): + rng = shared(np.random.default_rng(29402)) + g = aer.normal(rng=rng) + g_fg = FunctionGraph(outputs=[g]) + + with pytest.raises(TypeError): + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, (SharedVariable, Constant)) + ], + ) diff --git a/tests/link/numba/test_scalar.py b/tests/link/numba/test_scalar.py new file mode 100644 index 0000000000..ac107d46f4 --- /dev/null +++ b/tests/link/numba/test_scalar.py @@ -0,0 +1,142 @@ +import numpy as np +import pytest + +import aesara.scalar as aes +import aesara.scalar.basic as aesb +import aesara.tensor as at +from aesara import config +from aesara.compile.sharedvalue import SharedVariable +from aesara.graph.basic import Constant +from aesara.graph.fg import FunctionGraph +from aesara.scalar.basic import Composite +from aesara.tensor.elemwise import Elemwise +from tests.link.numba.test_basic import compare_numba_and_py, set_test_value + + +rng = np.random.default_rng(42849) + + +@pytest.mark.parametrize( + "x, y", + [ + ( + set_test_value(at.lvector(), np.arange(4, dtype="int64")), + set_test_value(at.dvector(), np.arange(4, dtype="float64")), + ), + ( + set_test_value(at.dmatrix(), np.arange(4, dtype="float64").reshape((2, 2))), + set_test_value(at.lscalar(), np.array(4, dtype="int64")), + ), + ], +) +def test_Second(x, y): + # We use the `Elemwise`-wrapped version of `Second` + g = at.second(x, y) + g_fg = FunctionGraph(outputs=[g]) + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, (SharedVariable, Constant)) + ], + ) + + +@pytest.mark.parametrize( + "v, min, max", + [ + (set_test_value(at.scalar(), np.array(10, dtype=config.floatX)), 3.0, 7.0), + (set_test_value(at.scalar(), np.array(1, dtype=config.floatX)), 3.0, 7.0), + (set_test_value(at.scalar(), np.array(10, dtype=config.floatX)), 7.0, 3.0), + ], +) +def test_Clip(v, min, max): + g = aes.clip(v, min, max) + g_fg = FunctionGraph(outputs=[g]) + + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, (SharedVariable, Constant)) + ], + ) + + +@pytest.mark.parametrize( + "inputs, input_values, scalar_fn", + [ + ( + [at.scalar("x"), at.scalar("y"), at.scalar("z")], + [ + np.array(10, dtype=config.floatX), + np.array(20, dtype=config.floatX), + np.array(30, dtype=config.floatX), + ], + lambda x, y, z: aes.add(x, y, z), + ), + ( + [at.scalar("x"), at.scalar("y"), at.scalar("z")], + [ + np.array(10, dtype=config.floatX), + np.array(20, dtype=config.floatX), + np.array(30, dtype=config.floatX), + ], + lambda x, y, z: aes.mul(x, y, z), + ), + ( + [at.scalar("x"), at.scalar("y")], + [ + np.array(10, dtype=config.floatX), + np.array(20, dtype=config.floatX), + ], + lambda x, y: x + y * 2 + aes.exp(x - y), + ), + ], +) +def test_Composite(inputs, input_values, scalar_fn): + composite_inputs = [aes.float64(i.name) for i in inputs] + comp_op = Elemwise(Composite(composite_inputs, [scalar_fn(*composite_inputs)])) + out_fg = FunctionGraph(inputs, [comp_op(*inputs)]) + compare_numba_and_py(out_fg, input_values) + + +@pytest.mark.parametrize( + "v, dtype", + [ + (set_test_value(at.fscalar(), np.array(1.0, dtype="float32")), aesb.float64), + (set_test_value(at.dscalar(), np.array(1.0, dtype="float64")), aesb.float32), + ], +) +def test_Cast(v, dtype): + g = aesb.Cast(dtype)(v) + g_fg = FunctionGraph(outputs=[g]) + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, (SharedVariable, Constant)) + ], + ) + + +@pytest.mark.parametrize( + "v, dtype", + [ + (set_test_value(at.iscalar(), np.array(10, dtype="int32")), aesb.float64), + ], +) +def test_reciprocal(v, dtype): + g = aesb.reciprocal(v) + g_fg = FunctionGraph(outputs=[g]) + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, (SharedVariable, Constant)) + ], + ) diff --git a/tests/link/numba/test_scan.py b/tests/link/numba/test_scan.py new file mode 100644 index 0000000000..6ef2257caa --- /dev/null +++ b/tests/link/numba/test_scan.py @@ -0,0 +1,200 @@ +import numpy as np + +import aesara.tensor as at +from aesara import config +from aesara.graph.fg import FunctionGraph +from aesara.scan.basic import scan +from aesara.scan.utils import until +from tests.link.numba.test_basic import compare_numba_and_py + + +rng = np.random.default_rng(42849) + + +def test_scan_multiple_output(): + """Test a scan implementation of a SEIR model. + + SEIR model definition: + S[t+1] = S[t] - B[t] + E[t+1] = E[t] +B[t] - C[t] + I[t+1] = I[t+1] + C[t] - D[t] + + B[t] ~ Binom(S[t], beta) + C[t] ~ Binom(E[t], gamma) + D[t] ~ Binom(I[t], delta) + """ + + def binomln(n, k): + return at.exp(n + 1) - at.exp(k + 1) - at.exp(n - k + 1) + + def binom_log_prob(n, p, value): + return binomln(n, value) + value * at.exp(p) + (n - value) * at.exp(1 - p) + + # sequences + at_C = at.ivector("C_t") + at_D = at.ivector("D_t") + # outputs_info (initial conditions) + st0 = at.lscalar("s_t0") + et0 = at.lscalar("e_t0") + it0 = at.lscalar("i_t0") + logp_c = at.scalar("logp_c") + logp_d = at.scalar("logp_d") + # non_sequences + beta = at.scalar("beta") + gamma = at.scalar("gamma") + delta = at.scalar("delta") + + def seir_one_step(ct0, dt0, st0, et0, it0, logp_c, logp_d, beta, gamma, delta): + bt0 = st0 * beta + bt0 = bt0.astype(st0.dtype) + + logp_c1 = binom_log_prob(et0, gamma, ct0).astype(logp_c.dtype) + logp_d1 = binom_log_prob(it0, delta, dt0).astype(logp_d.dtype) + + st1 = st0 - bt0 + et1 = et0 + bt0 - ct0 + it1 = it0 + ct0 - dt0 + return st1, et1, it1, logp_c1, logp_d1 + + (st, et, it, logp_c_all, logp_d_all), _ = scan( + fn=seir_one_step, + sequences=[at_C, at_D], + outputs_info=[st0, et0, it0, logp_c, logp_d], + non_sequences=[beta, gamma, delta], + ) + st.name = "S_t" + et.name = "E_t" + it.name = "I_t" + logp_c_all.name = "C_t_logp" + logp_d_all.name = "D_t_logp" + + out_fg = FunctionGraph( + [at_C, at_D, st0, et0, it0, logp_c, logp_d, beta, gamma, delta], + [st, et, it, logp_c_all, logp_d_all], + ) + + s0, e0, i0 = 100, 50, 25 + logp_c0 = np.array(0.0, dtype=config.floatX) + logp_d0 = np.array(0.0, dtype=config.floatX) + beta_val, gamma_val, delta_val = [ + np.array(val, dtype=config.floatX) for val in [0.277792, 0.135330, 0.108753] + ] + C = np.array([3, 5, 8, 13, 21, 26, 10, 3], dtype=np.int32) + D = np.array([1, 2, 3, 7, 9, 11, 5, 1], dtype=np.int32) + + test_input_vals = [ + C, + D, + s0, + e0, + i0, + logp_c0, + logp_d0, + beta_val, + gamma_val, + delta_val, + ] + compare_numba_and_py(out_fg, test_input_vals) + + +@config.change_flags(compute_test_value="raise") +def test_scan_tap_output(): + + a_at = at.scalar("a") + a_at.tag.test_value = 10.0 + + b_at = at.arange(11).astype(config.floatX) + b_at.name = "b" + + c_at = at.arange(20, 31, dtype=config.floatX) + c_at.name = "c" + + def input_step_fn(b, b2, c, x_tm1, y_tm1, y_tm3, a): + x_tm1.name = "x_tm1" + y_tm1.name = "y_tm1" + y_tm3.name = "y_tm3" + y_t = (y_tm1 + y_tm3) * a + b + b2 + z_t = y_t * c + x_t = x_tm1 + 1 + x_t.name = "x_t" + y_t.name = "y_t" + return x_t, y_t, at.fill((10,), z_t) + + scan_res, _ = scan( + fn=input_step_fn, + sequences=[ + { + "input": b_at, + "taps": [-1, -2], + }, + { + "input": c_at, + "taps": [-2], + }, + ], + outputs_info=[ + { + "initial": at.as_tensor_variable(0.0, dtype=config.floatX), + "taps": [-1], + }, + { + "initial": at.as_tensor_variable( + np.r_[-1.0, 1.3, 0.0].astype(config.floatX) + ), + "taps": [-1, -3], + }, + None, + ], + non_sequences=[a_at], + n_steps=5, + name="yz_scan", + strict=True, + ) + + out_fg = FunctionGraph([a_at, b_at, c_at], scan_res) + + test_input_vals = [ + np.array(10.0).astype(config.floatX), + np.arange(11, dtype=config.floatX), + np.arange(20, 31, dtype=config.floatX), + ] + compare_numba_and_py(out_fg, test_input_vals) + + +def test_scan_while(): + def power_of_2(previous_power, max_value): + return previous_power * 2, until(previous_power * 2 > max_value) + + max_value = at.scalar() + values, _ = scan( + power_of_2, + outputs_info=at.constant(1.0), + non_sequences=max_value, + n_steps=1024, + ) + + out_fg = FunctionGraph([max_value], [values]) + + test_input_vals = [ + np.array(45).astype(config.floatX), + ] + compare_numba_and_py(out_fg, test_input_vals) + + +def test_scan_multiple_none_output(): + A = at.dvector("A") + + def power_step(prior_result, x): + return prior_result * x, prior_result * x * x, prior_result * x * x * x + + result, _ = scan( + power_step, + non_sequences=[A], + outputs_info=[at.ones_like(A), None, None], + n_steps=3, + ) + + out_fg = FunctionGraph([A], result) + test_input_vals = (np.array([1.0, 2.0]),) + + compare_numba_and_py(out_fg, test_input_vals) diff --git a/tests/link/numba/test_tensor_basic.py b/tests/link/numba/test_tensor_basic.py new file mode 100644 index 0000000000..81664e756d --- /dev/null +++ b/tests/link/numba/test_tensor_basic.py @@ -0,0 +1,394 @@ +import numpy as np +import pytest + +import aesara.scalar as aes +import aesara.tensor as at +import aesara.tensor.basic as atb +from aesara import config +from aesara.compile.sharedvalue import SharedVariable +from aesara.graph.basic import Constant +from aesara.graph.fg import FunctionGraph +from aesara.tensor.shape import Unbroadcast +from tests.link.numba.test_basic import ( + compare_numba_and_py, + compare_shape_dtype, + set_test_value, +) + + +rng = np.random.default_rng(42849) + + +@pytest.mark.parametrize( + "v, shape", + [ + (0.0, (2, 3)), + (1.1, (2, 3)), + (set_test_value(at.scalar("a"), np.array(10.0, dtype=config.floatX)), (20,)), + (set_test_value(at.vector("a"), np.ones(10, dtype=config.floatX)), (20, 10)), + ], +) +def test_Alloc(v, shape): + g = at.alloc(v, *shape) + g_fg = FunctionGraph(outputs=[g]) + + (numba_res,) = compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, (SharedVariable, Constant)) + ], + ) + + assert numba_res.shape == shape + + +def test_AllocEmpty(): + + x = at.empty((2, 3), dtype="float32") + x_fg = FunctionGraph([], [x]) + + # We cannot compare the values in the arrays, only the shapes and dtypes + compare_numba_and_py(x_fg, [], assert_fn=compare_shape_dtype) + + +@pytest.mark.parametrize( + "v, offset", + [ + (set_test_value(at.vector(), np.arange(10, dtype=config.floatX)), 0), + (set_test_value(at.vector(), np.arange(10, dtype=config.floatX)), 1), + (set_test_value(at.vector(), np.arange(10, dtype=config.floatX)), -1), + ], +) +def test_AllocDiag(v, offset): + g = atb.AllocDiag(offset=offset)(v) + g_fg = FunctionGraph(outputs=[g]) + + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, (SharedVariable, Constant)) + ], + ) + + +@pytest.mark.parametrize( + "v", [set_test_value(aes.float64(), np.array(1.0, dtype="float64"))] +) +def test_TensorFromScalar(v): + g = atb.TensorFromScalar()(v) + g_fg = FunctionGraph(outputs=[g]) + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, (SharedVariable, Constant)) + ], + ) + + +@pytest.mark.parametrize( + "v", + [ + set_test_value(at.scalar(), np.array(1.0, dtype=config.floatX)), + ], +) +def test_ScalarFromTensor(v): + g = atb.ScalarFromTensor()(v) + g_fg = FunctionGraph(outputs=[g]) + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, (SharedVariable, Constant)) + ], + ) + + +def test_Unbroadcast(): + v = set_test_value(at.row(), np.array([[1.0, 2.0]], dtype=config.floatX)) + g = Unbroadcast(0)(v) + g_fg = FunctionGraph(outputs=[g]) + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, (SharedVariable, Constant)) + ], + ) + + +@pytest.mark.parametrize( + "vals, dtype", + [ + ( + ( + set_test_value(at.scalar(), np.array(1, dtype=config.floatX)), + set_test_value(at.scalar(), np.array(2, dtype=config.floatX)), + set_test_value(at.scalar(), np.array(3, dtype=config.floatX)), + ), + config.floatX, + ), + ( + ( + set_test_value(at.dscalar(), np.array(1, dtype=np.float64)), + set_test_value(at.lscalar(), np.array(3, dtype=np.int32)), + ), + "float64", + ), + ( + (set_test_value(at.iscalar(), np.array(1, dtype=np.int32)),), + "float64", + ), + ( + (set_test_value(at.scalar(dtype=bool), True),), + bool, + ), + ], +) +def test_MakeVector(vals, dtype): + g = atb.MakeVector(dtype)(*vals) + g_fg = FunctionGraph(outputs=[g]) + + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, (SharedVariable, Constant)) + ], + ) + + +@pytest.mark.parametrize( + "start, stop, step, dtype", + [ + ( + set_test_value(at.lscalar(), np.array(1)), + set_test_value(at.lscalar(), np.array(10)), + set_test_value(at.lscalar(), np.array(3)), + config.floatX, + ), + ], +) +def test_ARange(start, stop, step, dtype): + g = atb.ARange(dtype)(start, stop, step) + g_fg = FunctionGraph(outputs=[g]) + + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, (SharedVariable, Constant)) + ], + ) + + +@pytest.mark.parametrize( + "vals, axis", + [ + ( + ( + set_test_value( + at.matrix(), rng.normal(size=(1, 2)).astype(config.floatX) + ), + set_test_value( + at.matrix(), rng.normal(size=(1, 2)).astype(config.floatX) + ), + ), + 0, + ), + ( + ( + set_test_value( + at.matrix(), rng.normal(size=(2, 1)).astype(config.floatX) + ), + set_test_value( + at.matrix(), rng.normal(size=(3, 1)).astype(config.floatX) + ), + ), + 0, + ), + ( + ( + set_test_value( + at.matrix(), rng.normal(size=(1, 2)).astype(config.floatX) + ), + set_test_value( + at.matrix(), rng.normal(size=(1, 2)).astype(config.floatX) + ), + ), + 1, + ), + ( + ( + set_test_value( + at.matrix(), rng.normal(size=(2, 2)).astype(config.floatX) + ), + set_test_value( + at.matrix(), rng.normal(size=(2, 1)).astype(config.floatX) + ), + ), + 1, + ), + ], +) +def test_Join(vals, axis): + g = at.join(axis, *vals) + g_fg = FunctionGraph(outputs=[g]) + + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, (SharedVariable, Constant)) + ], + ) + + +def test_Join_view(): + vals = ( + set_test_value(at.matrix(), rng.normal(size=(2, 2)).astype(config.floatX)), + set_test_value(at.matrix(), rng.normal(size=(2, 2)).astype(config.floatX)), + ) + g = atb.Join(view=1)(1, *vals) + g_fg = FunctionGraph(outputs=[g]) + + with pytest.raises(NotImplementedError): + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, (SharedVariable, Constant)) + ], + ) + + +@pytest.mark.parametrize( + "n_splits, axis, values, sizes", + [ + ( + 0, + 0, + set_test_value(at.vector(), rng.normal(size=20).astype(config.floatX)), + set_test_value(at.vector(dtype="int64"), []), + ), + ( + 5, + 0, + set_test_value(at.vector(), rng.normal(size=5).astype(config.floatX)), + set_test_value( + at.vector(dtype="int64"), rng.multinomial(5, np.ones(5) / 5) + ), + ), + ( + 5, + 0, + set_test_value(at.vector(), rng.normal(size=10).astype(config.floatX)), + set_test_value( + at.vector(dtype="int64"), rng.multinomial(10, np.ones(5) / 5) + ), + ), + ( + 5, + -1, + set_test_value(at.matrix(), rng.normal(size=(11, 7)).astype(config.floatX)), + set_test_value( + at.vector(dtype="int64"), rng.multinomial(7, np.ones(5) / 5) + ), + ), + ( + 5, + -2, + set_test_value(at.matrix(), rng.normal(size=(11, 7)).astype(config.floatX)), + set_test_value( + at.vector(dtype="int64"), rng.multinomial(11, np.ones(5) / 5) + ), + ), + ], +) +def test_Split(n_splits, axis, values, sizes): + g = at.split(values, sizes, n_splits, axis=axis) + assert len(g) == n_splits + if n_splits == 0: + return + g_fg = FunctionGraph(outputs=[g] if n_splits == 1 else g) + + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, (SharedVariable, Constant)) + ], + ) + + +@pytest.mark.parametrize( + "val, offset", + [ + ( + set_test_value( + at.matrix(), np.arange(10 * 10, dtype=config.floatX).reshape((10, 10)) + ), + 0, + ), + ( + set_test_value(at.vector(), np.arange(10, dtype=config.floatX)), + 0, + ), + ], +) +def test_ExtractDiag(val, offset): + g = at.diag(val, offset) + g_fg = FunctionGraph(outputs=[g]) + + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, (SharedVariable, Constant)) + ], + ) + + +@pytest.mark.parametrize( + "n, m, k, dtype", + [ + (set_test_value(at.lscalar(), np.array(1, dtype=np.int64)), None, 0, None), + ( + set_test_value(at.lscalar(), np.array(1, dtype=np.int64)), + set_test_value(at.lscalar(), np.array(2, dtype=np.int64)), + 0, + "float32", + ), + ( + set_test_value(at.lscalar(), np.array(1, dtype=np.int64)), + set_test_value(at.lscalar(), np.array(2, dtype=np.int64)), + 1, + "int64", + ), + ], +) +def test_Eye(n, m, k, dtype): + g = at.eye(n, m, k, dtype=dtype) + g_fg = FunctionGraph(outputs=[g]) + + compare_numba_and_py( + g_fg, + [ + i.tag.test_value + for i in g_fg.inputs + if not isinstance(i, (SharedVariable, Constant)) + ], + ) diff --git a/tests/link/test_jax.py b/tests/link/test_jax.py deleted file mode 100644 index b456b4a790..0000000000 --- a/tests/link/test_jax.py +++ /dev/null @@ -1,1394 +0,0 @@ -from functools import partial -from typing import Optional - -import numpy as np -import pytest -from jax._src.errors import NonConcreteBooleanIndexError -from packaging.version import parse as version_parse - -import aesara.scalar.basic as aes -from aesara.compile.function import function -from aesara.compile.mode import Mode -from aesara.compile.ops import DeepCopyOp, ViewOp -from aesara.compile.sharedvalue import SharedVariable, shared -from aesara.configdefaults import config -from aesara.graph.basic import Apply -from aesara.graph.fg import FunctionGraph -from aesara.graph.op import Op, get_test_value -from aesara.graph.optdb import OptimizationQuery -from aesara.ifelse import ifelse -from aesara.link.jax import JAXLinker -from aesara.raise_op import assert_op -from aesara.scalar.basic import Composite -from aesara.scan.basic import scan -from aesara.tensor import basic as at -from aesara.tensor import blas as at_blas -from aesara.tensor import elemwise as at_elemwise -from aesara.tensor import extra_ops as at_extra_ops -from aesara.tensor import nlinalg as at_nlinalg -from aesara.tensor import nnet as at_nnet -from aesara.tensor import slinalg as at_slinalg -from aesara.tensor import subtensor as at_subtensor -from aesara.tensor.elemwise import Elemwise -from aesara.tensor.math import MaxAndArgmax -from aesara.tensor.math import all as at_all -from aesara.tensor.math import clip, cosh, erf, erfc, erfinv, gammaln, log, log1mexp -from aesara.tensor.math import max as at_max -from aesara.tensor.math import maximum, prod, psi, sigmoid, softplus -from aesara.tensor.math import sum as at_sum -from aesara.tensor.nnet.basic import SoftmaxGrad -from aesara.tensor.random.basic import RandomVariable, normal -from aesara.tensor.random.utils import RandomStream -from aesara.tensor.shape import Shape, Shape_i, SpecifyShape, Unbroadcast, reshape -from aesara.tensor.type import ( - dscalar, - dvector, - iscalar, - ivector, - lscalar, - matrix, - scalar, - tensor, - tensor3, - vector, -) - - -jax = pytest.importorskip("jax") - -opts = OptimizationQuery(include=[None], exclude=["cxx_only", "BlasOpt"]) -jax_mode = Mode(JAXLinker(), opts) -py_mode = Mode("py", opts) - - -@pytest.fixture(scope="module", autouse=True) -def set_aesara_flags(): - with config.change_flags(cxx="", compute_test_value="ignore"): - yield - - -def compare_jax_and_py( - fgraph: FunctionGraph, - test_inputs: iter, - assert_fn: Optional[callable] = None, - must_be_device_array: bool = True, -): - """Function to compare python graph output and jax compiled output for testing equality - - In the tests below computational graphs are defined in Aesara. These graphs are then passed to - this function which then compiles the graphs in both jax and python, runs the calculation - in both and checks if the results are the same - - Parameters - ---------- - fgraph: FunctionGraph - Aesara function Graph object - test_inputs: iter - Numerical inputs for testing the function graph - assert_fn: func, opt - Assert function used to check for equality between python and jax. If not - provided uses np.testing.assert_allclose - must_be_device_array: Bool - Checks for instance of jax.interpreters.xla.DeviceArray. For testing purposes - if this device array is found it indicates if the result was computed by jax - - Returns - ------- - jax_res - - """ - if assert_fn is None: - assert_fn = partial(np.testing.assert_allclose, rtol=1e-4) - - fn_inputs = [i for i in fgraph.inputs if not isinstance(i, SharedVariable)] - aesara_jax_fn = function(fn_inputs, fgraph.outputs, mode=jax_mode) - jax_res = aesara_jax_fn(*test_inputs) - - if must_be_device_array: - if isinstance(jax_res, list): - assert all( - isinstance(res, jax.interpreters.xla.DeviceArray) for res in jax_res - ) - else: - assert isinstance(jax_res, jax.interpreters.xla.DeviceArray) - - aesara_py_fn = function(fn_inputs, fgraph.outputs, mode=py_mode) - py_res = aesara_py_fn(*test_inputs) - - if len(fgraph.outputs) > 1: - for j, p in zip(jax_res, py_res): - assert_fn(j, p) - else: - assert_fn(jax_res, py_res) - - return jax_res - - -def test_jax_Alloc(): - x = at.alloc(0.0, 2, 3) - x_fg = FunctionGraph([], [x]) - - (jax_res,) = compare_jax_and_py(x_fg, []) - - assert jax_res.shape == (2, 3) - - x = at.alloc(1.1, 2, 3) - x_fg = FunctionGraph([], [x]) - - compare_jax_and_py(x_fg, []) - - x = at.AllocEmpty("float32")(2, 3) - x_fg = FunctionGraph([], [x]) - - def compare_shape_dtype(x, y): - (x,) = x - (y,) = y - return x.shape == y.shape and x.dtype == y.dtype - - compare_jax_and_py(x_fg, [], assert_fn=compare_shape_dtype) - - a = scalar("a") - x = at.alloc(a, 20) - x_fg = FunctionGraph([a], [x]) - - compare_jax_and_py(x_fg, [10.0]) - - a = vector("a") - x = at.alloc(a, 20, 10) - x_fg = FunctionGraph([a], [x]) - - compare_jax_and_py(x_fg, [np.ones(10, dtype=config.floatX)]) - - -def test_jax_shape_ops(): - x_np = np.zeros((20, 3)) - x = Shape()(at.as_tensor_variable(x_np)) - x_fg = FunctionGraph([], [x]) - - compare_jax_and_py(x_fg, [], must_be_device_array=False) - - x = Shape_i(1)(at.as_tensor_variable(x_np)) - x_fg = FunctionGraph([], [x]) - - compare_jax_and_py(x_fg, [], must_be_device_array=False) - - -@pytest.mark.xfail( - version_parse(jax.__version__) >= version_parse("0.2.12"), - reason="Omnistaging cannot be disabled", -) -def test_jax_specify_shape(): - x_np = np.zeros((20, 3)) - x = SpecifyShape()(at.as_tensor_variable(x_np), (20, 3)) - x_fg = FunctionGraph([], [x]) - - compare_jax_and_py(x_fg, []) - - with config.change_flags(compute_test_value="off"): - - x = SpecifyShape()(at.as_tensor_variable(x_np), *(2, 3)) - x_fg = FunctionGraph([], [x]) - - with pytest.raises(AssertionError): - compare_jax_and_py(x_fg, []) - - -def test_jax_compile_ops(): - - x = DeepCopyOp()(at.as_tensor_variable(1.1)) - x_fg = FunctionGraph([], [x]) - - compare_jax_and_py(x_fg, []) - - x_np = np.zeros((20, 1, 1)) - x = Unbroadcast(0, 2)(at.as_tensor_variable(x_np)) - x_fg = FunctionGraph([], [x]) - - compare_jax_and_py(x_fg, []) - - x = ViewOp()(at.as_tensor_variable(x_np)) - x_fg = FunctionGraph([], [x]) - - compare_jax_and_py(x_fg, []) - - -def test_jax_basic(): - rng = np.random.default_rng(28494) - - x = matrix("x") - y = matrix("y") - b = vector("b") - - # `ScalarOp` - z = cosh(x**2 + y / 3.0) - - # `[Inc]Subtensor` - out = at_subtensor.set_subtensor(z[0], -10.0) - out = at_subtensor.inc_subtensor(out[0, 1], 2.0) - out = out[:5, :3] - - out_fg = FunctionGraph([x, y], [out]) - - test_input_vals = [ - np.tile(np.arange(10), (10, 1)).astype(config.floatX), - np.tile(np.arange(10, 20), (10, 1)).astype(config.floatX), - ] - (jax_res,) = compare_jax_and_py(out_fg, test_input_vals) - - # Confirm that the `Subtensor` slice operations are correct - assert jax_res.shape == (5, 3) - - # Confirm that the `IncSubtensor` operations are correct - assert jax_res[0, 0] == -10.0 - assert jax_res[0, 1] == -8.0 - - out = clip(x, y, 5) - out_fg = FunctionGraph([x, y], [out]) - compare_jax_and_py(out_fg, test_input_vals) - - out = at.diagonal(x, 0) - out_fg = FunctionGraph([x], [out]) - compare_jax_and_py( - out_fg, [np.arange(10 * 10).reshape((10, 10)).astype(config.floatX)] - ) - - out = at_slinalg.cholesky(x) - out_fg = FunctionGraph([x], [out]) - compare_jax_and_py( - out_fg, - [ - (np.eye(10) + rng.standard_normal(size=(10, 10)) * 0.01).astype( - config.floatX - ) - ], - ) - - # not sure why this isn't working yet with lower=False - out = at_slinalg.Cholesky(lower=False)(x) - out_fg = FunctionGraph([x], [out]) - compare_jax_and_py( - out_fg, - [ - (np.eye(10) + rng.standard_normal(size=(10, 10)) * 0.01).astype( - config.floatX - ) - ], - ) - - out = at_slinalg.solve(x, b) - out_fg = FunctionGraph([x, b], [out]) - compare_jax_and_py( - out_fg, - [ - np.eye(10).astype(config.floatX), - np.arange(10).astype(config.floatX), - ], - ) - - out = at.diag(b) - out_fg = FunctionGraph([b], [out]) - compare_jax_and_py(out_fg, [np.arange(10).astype(config.floatX)]) - - out = at_nlinalg.det(x) - out_fg = FunctionGraph([x], [out]) - compare_jax_and_py( - out_fg, [np.arange(10 * 10).reshape((10, 10)).astype(config.floatX)] - ) - - out = at_nlinalg.matrix_inverse(x) - out_fg = FunctionGraph([x], [out]) - compare_jax_and_py( - out_fg, - [ - (np.eye(10) + rng.standard_normal(size=(10, 10)) * 0.01).astype( - config.floatX - ) - ], - ) - - -@pytest.mark.parametrize("check_finite", [False, True]) -@pytest.mark.parametrize("lower", [False, True]) -@pytest.mark.parametrize("trans", [0, 1, 2]) -def test_jax_SolveTriangular(trans, lower, check_finite): - x = matrix("x") - b = vector("b") - - out = at_slinalg.solve_triangular( - x, - b, - trans=trans, - lower=lower, - check_finite=check_finite, - ) - out_fg = FunctionGraph([x, b], [out]) - compare_jax_and_py( - out_fg, - [ - np.eye(10).astype(config.floatX), - np.arange(10).astype(config.floatX), - ], - ) - - -@pytest.mark.parametrize( - "x, y, x_val, y_val", - [ - (scalar("x"), scalar("y"), np.array(10), np.array(20)), - (scalar("x"), vector("y"), np.array(10), np.arange(10, 20)), - ( - matrix("x"), - vector("y"), - np.arange(10 * 20).reshape((20, 10)), - np.arange(10, 20), - ), - ], -) -def test_jax_Composite(x, y, x_val, y_val): - x_s = aes.float64("x") - y_s = aes.float64("y") - - comp_op = Elemwise(Composite([x_s, y_s], [x_s + y_s * 2 + aes.exp(x_s - y_s)])) - - out = comp_op(x, y) - - out_fg = FunctionGraph([x, y], [out]) - - test_input_vals = [ - x_val.astype(config.floatX), - y_val.astype(config.floatX), - ] - _ = compare_jax_and_py(out_fg, test_input_vals) - - -def test_jax_FunctionGraph_names(): - import inspect - - from aesara.link.jax.dispatch import jax_funcify - - x = scalar("1x") - y = scalar("_") - z = scalar() - q = scalar("def") - - out_fg = FunctionGraph([x, y, z, q], [x, y, z, q], clone=False) - out_jx = jax_funcify(out_fg) - sig = inspect.signature(out_jx) - assert (x.auto_name, "_", z.auto_name, q.auto_name) == tuple(sig.parameters.keys()) - assert (1, 2, 3, 4) == out_jx(1, 2, 3, 4) - - -def test_jax_FunctionGraph_once(): - """Make sure that an output is only computed once when it's referenced multiple times.""" - from aesara.link.jax.dispatch import jax_funcify - - x = vector("x") - y = vector("y") - - class TestOp(Op): - def __init__(self): - self.called = 0 - - def make_node(self, *args): - return Apply(self, list(args), [x.type() for x in args]) - - def perform(self, inputs, outputs): - for i, inp in enumerate(inputs): - outputs[i][0] = inp[0] - - @jax_funcify.register(TestOp) - def jax_funcify_TestOp(op, **kwargs): - def func(*args, op=op): - op.called += 1 - return list(args) - - return func - - op1 = TestOp() - op2 = TestOp() - - q, r = op1(x, y) - outs = op2(q + r, q + r) - - out_fg = FunctionGraph([x, y], outs, clone=False) - assert len(out_fg.outputs) == 2 - - out_jx = jax_funcify(out_fg) - - x_val = np.r_[1, 2].astype(config.floatX) - y_val = np.r_[2, 3].astype(config.floatX) - - res = out_jx(x_val, y_val) - assert len(res) == 2 - assert op1.called == 1 - assert op2.called == 1 - - res = out_jx(x_val, y_val) - assert len(res) == 2 - assert op1.called == 2 - assert op2.called == 2 - - -def test_jax_eye(): - """Tests jaxification of the Eye operator""" - out = at.eye(3) - out_fg = FunctionGraph([], [out]) - - compare_jax_and_py(out_fg, []) - - -def test_jax_basic_multiout(): - rng = np.random.default_rng(213234) - - M = rng.normal(size=(3, 3)) - X = M.dot(M.T) - - x = matrix("x") - - outs = at_nlinalg.eig(x) - out_fg = FunctionGraph([x], outs) - - def assert_fn(x, y): - np.testing.assert_allclose(x.astype(config.floatX), y, rtol=1e-3) - - compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn) - - outs = at_nlinalg.eigh(x) - out_fg = FunctionGraph([x], outs) - compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn) - - outs = at_nlinalg.qr(x, mode="full") - out_fg = FunctionGraph([x], outs) - compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn) - - outs = at_nlinalg.qr(x, mode="reduced") - out_fg = FunctionGraph([x], outs) - compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn) - - outs = at_nlinalg.svd(x) - out_fg = FunctionGraph([x], outs) - compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn) - - -@pytest.mark.xfail( - version_parse(jax.__version__) >= version_parse("0.2.12"), - reason="Omnistaging cannot be disabled", -) -def test_jax_basic_multiout_omni(): - # Test that a single output of a multi-output `Op` can be used as input to - # another `Op` - x = dvector() - mx, amx = MaxAndArgmax([0])(x) - out = mx * amx - out_fg = FunctionGraph([x], [out]) - compare_jax_and_py(out_fg, [np.r_[1, 2]]) - - -@pytest.mark.xfail( - version_parse(jax.__version__) >= version_parse("0.2.12"), - reason="Omnistaging cannot be disabled", -) -def test_jax_scan_multiple_output(): - """Test a scan implementation of a SEIR model. - - SEIR model definition: - S[t+1] = S[t] - B[t] - E[t+1] = E[t] +B[t] - C[t] - I[t+1] = I[t+1] + C[t] - D[t] - - B[t] ~ Binom(S[t], beta) - C[t] ~ Binom(E[t], gamma) - D[t] ~ Binom(I[t], delta) - """ - - def binomln(n, k): - return gammaln(n + 1) - gammaln(k + 1) - gammaln(n - k + 1) - - def binom_log_prob(n, p, value): - return binomln(n, value) + value * log(p) + (n - value) * log(1 - p) - - # sequences - at_C = ivector("C_t") - at_D = ivector("D_t") - # outputs_info (initial conditions) - st0 = lscalar("s_t0") - et0 = lscalar("e_t0") - it0 = lscalar("i_t0") - logp_c = scalar("logp_c") - logp_d = scalar("logp_d") - # non_sequences - beta = scalar("beta") - gamma = scalar("gamma") - delta = scalar("delta") - - # TODO: Use random streams when their JAX conversions are implemented. - # trng = aesara.tensor.random.RandomStream(1234) - - def seir_one_step(ct0, dt0, st0, et0, it0, logp_c, logp_d, beta, gamma, delta): - # bt0 = trng.binomial(n=st0, p=beta) - bt0 = st0 * beta - bt0 = bt0.astype(st0.dtype) - - logp_c1 = binom_log_prob(et0, gamma, ct0).astype(logp_c.dtype) - logp_d1 = binom_log_prob(it0, delta, dt0).astype(logp_d.dtype) - - st1 = st0 - bt0 - et1 = et0 + bt0 - ct0 - it1 = it0 + ct0 - dt0 - return st1, et1, it1, logp_c1, logp_d1 - - (st, et, it, logp_c_all, logp_d_all), _ = scan( - fn=seir_one_step, - sequences=[at_C, at_D], - outputs_info=[st0, et0, it0, logp_c, logp_d], - non_sequences=[beta, gamma, delta], - ) - st.name = "S_t" - et.name = "E_t" - it.name = "I_t" - logp_c_all.name = "C_t_logp" - logp_d_all.name = "D_t_logp" - - out_fg = FunctionGraph( - [at_C, at_D, st0, et0, it0, logp_c, logp_d, beta, gamma, delta], - [st, et, it, logp_c_all, logp_d_all], - ) - - s0, e0, i0 = 100, 50, 25 - logp_c0 = np.array(0.0, dtype=config.floatX) - logp_d0 = np.array(0.0, dtype=config.floatX) - beta_val, gamma_val, delta_val = [ - np.array(val, dtype=config.floatX) for val in [0.277792, 0.135330, 0.108753] - ] - C = np.array([3, 5, 8, 13, 21, 26, 10, 3], dtype=np.int32) - D = np.array([1, 2, 3, 7, 9, 11, 5, 1], dtype=np.int32) - - test_input_vals = [ - C, - D, - s0, - e0, - i0, - logp_c0, - logp_d0, - beta_val, - gamma_val, - delta_val, - ] - compare_jax_and_py(out_fg, test_input_vals) - - -@pytest.mark.xfail( - version_parse(jax.__version__) >= version_parse("0.2.12"), - reason="Omnistaging cannot be disabled", -) -def test_jax_scan_tap_output(): - - a_at = scalar("a") - - def input_step_fn(y_tm1, y_tm3, a): - y_tm1.name = "y_tm1" - y_tm3.name = "y_tm3" - res = (y_tm1 + y_tm3) * a - res.name = "y_t" - return res - - y_scan_at, _ = scan( - fn=input_step_fn, - outputs_info=[ - { - "initial": at.as_tensor_variable( - np.r_[-1.0, 1.3, 0.0].astype(config.floatX) - ), - "taps": [-1, -3], - }, - ], - non_sequences=[a_at], - n_steps=10, - name="y_scan", - ) - y_scan_at.name = "y" - y_scan_at.owner.inputs[0].name = "y_all" - - out_fg = FunctionGraph([a_at], [y_scan_at]) - - test_input_vals = [np.array(10.0).astype(config.floatX)] - compare_jax_and_py(out_fg, test_input_vals) - - -def test_jax_Subtensors(): - # Basic indices - x_at = at.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))) - out_at = x_at[1, 2, 0] - assert isinstance(out_at.owner.op, at_subtensor.Subtensor) - out_fg = FunctionGraph([], [out_at]) - compare_jax_and_py(out_fg, []) - - out_at = x_at[1:2, 1, :] - assert isinstance(out_at.owner.op, at_subtensor.Subtensor) - out_fg = FunctionGraph([], [out_at]) - compare_jax_and_py(out_fg, []) - - # Advanced indexing - out_at = at_subtensor.advanced_subtensor1(x_at, [1, 2]) - assert isinstance(out_at.owner.op, at_subtensor.AdvancedSubtensor1) - out_fg = FunctionGraph([], [out_at]) - compare_jax_and_py(out_fg, []) - - out_at = x_at[[1, 2], [2, 3]] - assert isinstance(out_at.owner.op, at_subtensor.AdvancedSubtensor) - out_fg = FunctionGraph([], [out_at]) - compare_jax_and_py(out_fg, []) - - # Advanced and basic indexing - out_at = x_at[[1, 2], :] - assert isinstance(out_at.owner.op, at_subtensor.AdvancedSubtensor) - out_fg = FunctionGraph([], [out_at]) - compare_jax_and_py(out_fg, []) - - out_at = x_at[[1, 2], :, [3, 4]] - assert isinstance(out_at.owner.op, at_subtensor.AdvancedSubtensor) - out_fg = FunctionGraph([], [out_at]) - compare_jax_and_py(out_fg, []) - - -@pytest.mark.xfail( - version_parse(jax.__version__) >= version_parse("0.2.12"), - reason="Omnistaging cannot be disabled", -) -def test_jax_Subtensors_omni(): - x_at = at.arange(3 * 4 * 5).reshape((3, 4, 5)) - - # Boolean indices - out_at = x_at[x_at < 0] - assert isinstance(out_at.owner.op, at_subtensor.AdvancedSubtensor) - out_fg = FunctionGraph([], [out_at]) - compare_jax_and_py(out_fg, []) - - -def test_jax_IncSubtensor(): - rng = np.random.default_rng(213234) - - x_np = rng.uniform(-1, 1, size=(3, 4, 5)).astype(config.floatX) - x_at = at.constant(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(config.floatX)) - - # "Set" basic indices - st_at = at.as_tensor_variable(np.array(-10.0, dtype=config.floatX)) - out_at = at_subtensor.set_subtensor(x_at[1, 2, 3], st_at) - assert isinstance(out_at.owner.op, at_subtensor.IncSubtensor) - out_fg = FunctionGraph([], [out_at]) - compare_jax_and_py(out_fg, []) - - st_at = at.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX)) - out_at = at_subtensor.set_subtensor(x_at[:2, 0, 0], st_at) - assert isinstance(out_at.owner.op, at_subtensor.IncSubtensor) - out_fg = FunctionGraph([], [out_at]) - compare_jax_and_py(out_fg, []) - - out_at = at_subtensor.set_subtensor(x_at[0, 1:3, 0], st_at) - assert isinstance(out_at.owner.op, at_subtensor.IncSubtensor) - out_fg = FunctionGraph([], [out_at]) - compare_jax_and_py(out_fg, []) - - # "Set" advanced indices - st_at = at.as_tensor_variable( - rng.uniform(-1, 1, size=(2, 4, 5)).astype(config.floatX) - ) - out_at = at_subtensor.set_subtensor(x_at[np.r_[0, 2]], st_at) - assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor) - out_fg = FunctionGraph([], [out_at]) - compare_jax_and_py(out_fg, []) - - st_at = at.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX)) - out_at = at_subtensor.set_subtensor(x_at[[0, 2], 0, 0], st_at) - assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor) - out_fg = FunctionGraph([], [out_at]) - compare_jax_and_py(out_fg, []) - - # "Set" boolean indices - mask_at = at.constant(x_np > 0) - out_at = at_subtensor.set_subtensor(x_at[mask_at], 0.0) - assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor) - out_fg = FunctionGraph([], [out_at]) - compare_jax_and_py(out_fg, []) - - # "Increment" basic indices - st_at = at.as_tensor_variable(np.array(-10.0, dtype=config.floatX)) - out_at = at_subtensor.inc_subtensor(x_at[1, 2, 3], st_at) - assert isinstance(out_at.owner.op, at_subtensor.IncSubtensor) - out_fg = FunctionGraph([], [out_at]) - compare_jax_and_py(out_fg, []) - - st_at = at.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX)) - out_at = at_subtensor.inc_subtensor(x_at[:2, 0, 0], st_at) - assert isinstance(out_at.owner.op, at_subtensor.IncSubtensor) - out_fg = FunctionGraph([], [out_at]) - compare_jax_and_py(out_fg, []) - - out_at = at_subtensor.set_subtensor(x_at[0, 1:3, 0], st_at) - assert isinstance(out_at.owner.op, at_subtensor.IncSubtensor) - out_fg = FunctionGraph([], [out_at]) - compare_jax_and_py(out_fg, []) - - # "Increment" advanced indices - st_at = at.as_tensor_variable( - rng.uniform(-1, 1, size=(2, 4, 5)).astype(config.floatX) - ) - out_at = at_subtensor.inc_subtensor(x_at[np.r_[0, 2]], st_at) - assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor) - out_fg = FunctionGraph([], [out_at]) - compare_jax_and_py(out_fg, []) - - st_at = at.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX)) - out_at = at_subtensor.inc_subtensor(x_at[[0, 2], 0, 0], st_at) - assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor) - out_fg = FunctionGraph([], [out_at]) - compare_jax_and_py(out_fg, []) - - # "Increment" boolean indices - mask_at = at.constant(x_np > 0) - out_at = at_subtensor.set_subtensor(x_at[mask_at], 1.0) - assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor) - out_fg = FunctionGraph([], [out_at]) - compare_jax_and_py(out_fg, []) - - -def test_jax_IncSubtensors_unsupported(): - rng = np.random.default_rng(213234) - x_np = rng.uniform(-1, 1, size=(3, 4, 5)).astype(config.floatX) - x_at = at.constant(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(config.floatX)) - - mask_at = at.as_tensor(x_np) > 0 - out_at = at_subtensor.set_subtensor(x_at[mask_at], 0.0) - assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor) - out_fg = FunctionGraph([], [out_at]) - with pytest.raises( - NonConcreteBooleanIndexError, match="Array boolean indices must be concrete" - ): - compare_jax_and_py(out_fg, []) - - mask_at = at.as_tensor_variable(x_np) > 0 - out_at = at_subtensor.set_subtensor(x_at[mask_at], 1.0) - assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor) - out_fg = FunctionGraph([], [out_at]) - with pytest.raises( - NonConcreteBooleanIndexError, match="Array boolean indices must be concrete" - ): - compare_jax_and_py(out_fg, []) - - st_at = at.as_tensor_variable(x_np[[0, 2], 0, :3]) - out_at = at_subtensor.set_subtensor(x_at[[0, 2], 0, :3], st_at) - assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor) - out_fg = FunctionGraph([], [out_at]) - with pytest.raises(IndexError, match="Array slice indices must have static"): - compare_jax_and_py(out_fg, []) - - st_at = at.as_tensor_variable(x_np[[0, 2], 0, :3]) - out_at = at_subtensor.inc_subtensor(x_at[[0, 2], 0, :3], st_at) - assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor) - out_fg = FunctionGraph([], [out_at]) - with pytest.raises(IndexError, match="Array slice indices must have static"): - compare_jax_and_py(out_fg, []) - - -def test_jax_ifelse(): - - true_vals = np.r_[1, 2, 3] - false_vals = np.r_[-1, -2, -3] - - x = ifelse(np.array(True), true_vals, false_vals) - x_fg = FunctionGraph([], [x]) - - compare_jax_and_py(x_fg, []) - - a = dscalar("a") - a.tag.test_value = np.array(0.2, dtype=config.floatX) - x = ifelse(a < 0.5, true_vals, false_vals) - x_fg = FunctionGraph([a], [x]) # I.e. False - - compare_jax_and_py(x_fg, [get_test_value(i) for i in x_fg.inputs]) - - -def test_jax_checkandraise(): - p = scalar() - p.tag.test_value = 0 - - res = assert_op(p, p < 1.0) - res_fg = FunctionGraph([p], [res]) - - with pytest.raises(NotImplementedError): - compare_jax_and_py(res_fg, [1.0]) - - -def test_jax_CAReduce(): - a_at = vector("a") - a_at.tag.test_value = np.r_[1, 2, 3].astype(config.floatX) - - x = at_sum(a_at, axis=None) - x_fg = FunctionGraph([a_at], [x]) - - compare_jax_and_py(x_fg, [np.r_[1, 2, 3].astype(config.floatX)]) - - a_at = matrix("a") - a_at.tag.test_value = np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX) - - x = at_sum(a_at, axis=0) - x_fg = FunctionGraph([a_at], [x]) - - compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)]) - - x = at_sum(a_at, axis=1) - x_fg = FunctionGraph([a_at], [x]) - - compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)]) - - a_at = matrix("a") - a_at.tag.test_value = np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX) - - x = prod(a_at, axis=0) - x_fg = FunctionGraph([a_at], [x]) - - compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)]) - - x = at_all(a_at) - x_fg = FunctionGraph([a_at], [x]) - - compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)]) - - -def test_jax_MakeVector(): - x = at.make_vector(1, 2, 3) - x_fg = FunctionGraph([], [x]) - - compare_jax_and_py(x_fg, []) - - -@pytest.mark.xfail( - version_parse(jax.__version__) >= version_parse("0.2.12"), - reason="Omnistaging cannot be disabled", -) -def test_jax_Reshape(): - a = vector("a") - x = reshape(a, (2, 2)) - x_fg = FunctionGraph([a], [x]) - compare_jax_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)]) - - # Test breaking "omnistaging" changes in JAX. - # See https://github.com/tensorflow/probability/commit/782d0c64eb774b9aac54a1c8488e4f1f96fbbc68 - x = reshape(a, (a.shape[0] // 2, a.shape[0] // 2)) - x_fg = FunctionGraph([a], [x]) - compare_jax_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)]) - - -@pytest.mark.xfail(reason="jax.numpy.arange requires concrete inputs") -def test_jax_Reshape_nonconcrete(): - a = vector("a") - b = iscalar("b") - x = reshape(a, (b, b)) - x_fg = FunctionGraph([a, b], [x]) - compare_jax_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX), 2]) - - -def test_jax_Dimshuffle(): - a_at = matrix("a") - - x = a_at.T - x_fg = FunctionGraph([a_at], [x]) - compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)]) - - x = a_at.dimshuffle([0, 1, "x"]) - x_fg = FunctionGraph([a_at], [x]) - compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)]) - - a_at = tensor(dtype=config.floatX, shape=[False, True]) - x = a_at.dimshuffle((0,)) - x_fg = FunctionGraph([a_at], [x]) - compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)]) - - a_at = tensor(dtype=config.floatX, shape=[False, True]) - x = at_elemwise.DimShuffle([False, True], (0,))(a_at) - x_fg = FunctionGraph([a_at], [x]) - compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)]) - - -def test_jax_Join(): - a = matrix("a") - b = matrix("b") - - x = at.join(0, a, b) - x_fg = FunctionGraph([a, b], [x]) - compare_jax_and_py( - x_fg, - [ - np.c_[[1.0, 2.0, 3.0]].astype(config.floatX), - np.c_[[4.0, 5.0, 6.0]].astype(config.floatX), - ], - ) - compare_jax_and_py( - x_fg, - [ - np.c_[[1.0, 2.0, 3.0]].astype(config.floatX), - np.c_[[4.0, 5.0]].astype(config.floatX), - ], - ) - - x = at.join(1, a, b) - x_fg = FunctionGraph([a, b], [x]) - compare_jax_and_py( - x_fg, - [ - np.c_[[1.0, 2.0, 3.0]].astype(config.floatX), - np.c_[[4.0, 5.0, 6.0]].astype(config.floatX), - ], - ) - compare_jax_and_py( - x_fg, - [ - np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX), - np.c_[[5.0, 6.0]].astype(config.floatX), - ], - ) - - -def test_jax_variadic_Scalar(): - mu = vector("mu", dtype=config.floatX) - mu.tag.test_value = np.r_[0.1, 1.1].astype(config.floatX) - tau = vector("tau", dtype=config.floatX) - tau.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX) - - res = -tau * mu - - fgraph = FunctionGraph([mu, tau], [res]) - - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) - - res = -tau * (tau - mu) ** 2 - - fgraph = FunctionGraph([mu, tau], [res]) - - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) - - -def test_jax_logp(): - - mu = vector("mu") - mu.tag.test_value = np.r_[0.0, 0.0].astype(config.floatX) - tau = vector("tau") - tau.tag.test_value = np.r_[1.0, 1.0].astype(config.floatX) - sigma = vector("sigma") - sigma.tag.test_value = (1.0 / get_test_value(tau)).astype(config.floatX) - value = vector("value") - value.tag.test_value = np.r_[0.1, -10].astype(config.floatX) - - logp = (-tau * (value - mu) ** 2 + log(tau / np.pi / 2.0)) / 2.0 - conditions = [sigma > 0] - alltrue = at_all([at_all(1 * val) for val in conditions]) - normal_logp = at.switch(alltrue, logp, -np.inf) - - fgraph = FunctionGraph([mu, tau, sigma, value], [normal_logp]) - - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) - - -def test_jax_multioutput(): - x = vector("x") - x.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX) - y = vector("y") - y.tag.test_value = np.r_[3.0, 4.0].astype(config.floatX) - - w = cosh(x**2 + y / 3.0) - v = cosh(x / 3.0 + y**2) - - fgraph = FunctionGraph([x, y], [w, v]) - - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) - - -def test_nnet(): - x = vector("x") - x.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX) - - out = sigmoid(x) - fgraph = FunctionGraph([x], [out]) - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) - - out = at_nnet.ultra_fast_sigmoid(x) - fgraph = FunctionGraph([x], [out]) - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) - - out = softplus(x) - fgraph = FunctionGraph([x], [out]) - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) - - -@pytest.mark.parametrize("axis", [None, 0, 1]) -def test_softmax(axis): - x = matrix("x") - x.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3) - out = at_nnet.softmax(x, axis=axis) - fgraph = FunctionGraph([x], [out]) - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) - - -@pytest.mark.parametrize("axis", [None, 0, 1]) -def test_logsoftmax(axis): - x = matrix("x") - x.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3) - out = at_nnet.logsoftmax(x, axis=axis) - fgraph = FunctionGraph([x], [out]) - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) - - -@pytest.mark.parametrize("axis", [None, 0, 1]) -def test_softmax_grad(axis): - dy = matrix("dy") - dy.tag.test_value = np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX) - sm = matrix("sm") - sm.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3) - out = SoftmaxGrad(axis=axis)(dy, sm) - fgraph = FunctionGraph([dy, sm], [out]) - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) - - -@pytest.mark.xfail( - version_parse(jax.__version__) >= version_parse("0.2.12"), - reason="Omnistaging cannot be disabled", -) -def test_tensor_basics(): - y = vector("y") - y.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX) - x = vector("x") - x.tag.test_value = np.r_[3.0, 4.0].astype(config.floatX) - A = matrix("A") - A.tag.test_value = np.empty((2, 2), dtype=config.floatX) - alpha = scalar("alpha") - alpha.tag.test_value = np.array(3.0, dtype=config.floatX) - beta = scalar("beta") - beta.tag.test_value = np.array(5.0, dtype=config.floatX) - - # This should be converted into a `Gemv` `Op` when the non-JAX compatible - # optimizations are turned on; however, when using JAX mode, it should - # leave the expression alone. - out = y.dot(alpha * A).dot(x) + beta * y - fgraph = FunctionGraph([y, x, A, alpha, beta], [out]) - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) - - out = maximum(y, x) - fgraph = FunctionGraph([y, x], [out]) - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) - - out = at_max(y) - fgraph = FunctionGraph([y], [out]) - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) - - -@pytest.mark.xfail(reason="jax.numpy.arange requires concrete inputs") -def test_arange_nonconcrete(): - - a = scalar("a") - a.tag.test_value = 10 - - out = at.arange(a) - fgraph = FunctionGraph([a], [out]) - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) - - -@pytest.mark.xfail(reason="jax.numpy.arange requires concrete inputs") -def test_unique_nonconcrete(): - a = matrix("a") - a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2)) - - out = at_extra_ops.Unique()(a) - fgraph = FunctionGraph([a], [out]) - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) - - -def test_identity(): - a = scalar("a") - a.tag.test_value = 10 - - out = aes.identity(a) - fgraph = FunctionGraph([a], [out]) - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) - - -def test_second(): - a0 = scalar("a0") - b = scalar("b") - - out = aes.second(a0, b) - fgraph = FunctionGraph([a0, b], [out]) - compare_jax_and_py(fgraph, [10.0, 5.0]) - - a1 = vector("a1") - out = at.second(a1, b) - fgraph = FunctionGraph([a1, b], [out]) - compare_jax_and_py(fgraph, [np.zeros([5], dtype=config.floatX), 5.0]) - - -def test_jax_BatchedDot(): - # tensor3 . tensor3 - a = tensor3("a") - a.tag.test_value = ( - np.linspace(-1, 1, 10 * 5 * 3).astype(config.floatX).reshape((10, 5, 3)) - ) - b = tensor3("b") - b.tag.test_value = ( - np.linspace(1, -1, 10 * 3 * 2).astype(config.floatX).reshape((10, 3, 2)) - ) - out = at_blas.BatchedDot()(a, b) - fgraph = FunctionGraph([a, b], [out]) - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) - - # A dimension mismatch should raise a TypeError for compatibility - inputs = [get_test_value(a)[:-1], get_test_value(b)] - opts = OptimizationQuery(include=[None], exclude=["cxx_only", "BlasOpt"]) - jax_mode = Mode(JAXLinker(), opts) - aesara_jax_fn = function(fgraph.inputs, fgraph.outputs, mode=jax_mode) - with pytest.raises(TypeError): - aesara_jax_fn(*inputs) - - # matrix . matrix - a = matrix("a") - a.tag.test_value = np.linspace(-1, 1, 5 * 3).astype(config.floatX).reshape((5, 3)) - b = matrix("b") - b.tag.test_value = np.linspace(1, -1, 5 * 3).astype(config.floatX).reshape((5, 3)) - out = at_blas.BatchedDot()(a, b) - fgraph = FunctionGraph([a, b], [out]) - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) - - -def test_shared(): - a = shared(np.array([1, 2, 3], dtype=config.floatX)) - - aesara_jax_fn = function([], a, mode="JAX") - jax_res = aesara_jax_fn() - - assert isinstance(jax_res, jax.interpreters.xla.DeviceArray) - np.testing.assert_allclose(jax_res, a.get_value()) - - aesara_jax_fn = function([], a * 2, mode="JAX") - jax_res = aesara_jax_fn() - - assert isinstance(jax_res, jax.interpreters.xla.DeviceArray) - np.testing.assert_allclose(jax_res, a.get_value() * 2) - - # Changed the shared value and make sure that the JAX-compiled - # function also changes. - new_a_value = np.array([3, 4, 5], dtype=config.floatX) - a.set_value(new_a_value) - - jax_res = aesara_jax_fn() - assert isinstance(jax_res, jax.interpreters.xla.DeviceArray) - np.testing.assert_allclose(jax_res, new_a_value * 2) - - -def test_extra_ops(): - a = matrix("a") - a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2)) - - out = at_extra_ops.cumsum(a, axis=0) - fgraph = FunctionGraph([a], [out]) - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) - - out = at_extra_ops.cumprod(a, axis=1) - fgraph = FunctionGraph([a], [out]) - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) - - out = at_extra_ops.diff(a, n=2, axis=1) - fgraph = FunctionGraph([a], [out]) - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) - - out = at_extra_ops.repeat(a, (3, 3), axis=1) - fgraph = FunctionGraph([a], [out]) - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) - - c = at.as_tensor(5) - - out = at_extra_ops.fill_diagonal(a, c) - fgraph = FunctionGraph([a], [out]) - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) - - with pytest.raises(NotImplementedError): - out = at_extra_ops.fill_diagonal_offset(a, c, c) - fgraph = FunctionGraph([a], [out]) - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) - - with pytest.raises(NotImplementedError): - out = at_extra_ops.Unique(axis=1)(a) - fgraph = FunctionGraph([a], [out]) - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) - - indices = np.arange(np.product((3, 4))) - out = at_extra_ops.unravel_index(indices, (3, 4), order="C") - fgraph = FunctionGraph([], out) - compare_jax_and_py( - fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False - ) - - -def set_test_value(x, v): - x.tag.test_value = v - return x - - -@pytest.mark.parametrize( - "x, shape", - [ - ( - set_test_value( - vector("x"), np.random.random(size=(2,)).astype(config.floatX) - ), - [at.as_tensor(3, dtype=np.int64), at.as_tensor(2, dtype=np.int64)], - ), - ( - set_test_value( - vector("x"), np.random.random(size=(2,)).astype(config.floatX) - ), - [at.as_tensor(3, dtype=np.int8), at.as_tensor(2, dtype=np.int64)], - ), - ], -) -def test_BroadcastTo(x, shape): - out = at_extra_ops.broadcast_to(x, shape) - fgraph = FunctionGraph(outputs=[out]) - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) - - -@pytest.mark.xfail( - version_parse(jax.__version__) >= version_parse("0.2.12"), - reason="Omnistaging cannot be disabled", -) -def test_extra_ops_omni(): - a = matrix("a") - a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2)) - - # This function also cannot take symbolic input. - c = at.as_tensor(5) - out = at_extra_ops.bartlett(c) - fgraph = FunctionGraph([], [out]) - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) - - multi_index = np.unravel_index(np.arange(np.product((3, 4))), (3, 4)) - out = at_extra_ops.ravel_multi_index(multi_index, (3, 4)) - fgraph = FunctionGraph([], [out]) - compare_jax_and_py( - fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False - ) - - # The inputs are "concrete", yet it still has problems? - out = at_extra_ops.Unique()( - at.as_tensor(np.arange(6, dtype=config.floatX).reshape((3, 2))) - ) - fgraph = FunctionGraph([], [out]) - compare_jax_and_py(fgraph, []) - - -@pytest.mark.xfail( - version_parse(jax.__version__) >= version_parse("0.2.26"), - reason="JAX samplers require concrete/static shape values?", -) -@pytest.mark.parametrize( - "at_dist, dist_params, rng, size", - [ - ( - normal, - (), - shared(np.random.RandomState(123)), - 10000, - ), - ( - normal, - (), - shared(np.random.default_rng(123)), - 10000, - ), - ], -) -def test_random_stats(at_dist, dist_params, rng, size): - # The RNG states are not 1:1, so the best we can do is check some summary - # statistics of the samples - out = normal(*dist_params, rng=rng, size=size) - fgraph = FunctionGraph([out.owner.inputs[0]], [out], clone=False) - - def assert_fn(x, y): - (x,) = x - (y,) = y - assert x.dtype.kind == y.dtype.kind - - d = 2 if config.floatX == "float64" else 1 - np.testing.assert_array_almost_equal(np.abs(x.mean()), np.abs(y.mean()), d) - - compare_jax_and_py(fgraph, [], assert_fn=assert_fn) - - -def test_random_unimplemented(): - class NonExistentRV(RandomVariable): - name = "non-existent" - ndim_supp = 0 - ndims_params = [] - dtype = "floatX" - - def __call__(self, size=None, **kwargs): - return super().__call__(size=size, **kwargs) - - def rng_fn(cls, rng, size): - return 0 - - nonexistentrv = NonExistentRV() - rng = shared(np.random.RandomState(123)) - out = nonexistentrv(rng=rng) - fgraph = FunctionGraph([out.owner.inputs[0]], [out], clone=False) - - with pytest.raises(NotImplementedError): - compare_jax_and_py(fgraph, []) - - -def test_RandomStream(): - srng = RandomStream(seed=123) - out = srng.normal() - srng.normal() - - fn = function([], out, mode=jax_mode) - jax_res_1 = fn() - jax_res_2 = fn() - - assert np.array_equal(jax_res_1, jax_res_2) - - -def test_erf(): - x = scalar("x") - out = erf(x) - fg = FunctionGraph([x], [out]) - - compare_jax_and_py(fg, [1.0]) - - -def test_erfc(): - x = scalar("x") - out = erfc(x) - fg = FunctionGraph([x], [out]) - - compare_jax_and_py(fg, [1.0]) - - -def test_erfinv(): - x = scalar("x") - out = erfinv(x) - fg = FunctionGraph([x], [out]) - - compare_jax_and_py(fg, [1.0]) - - -def test_psi(): - x = scalar("x") - out = psi(x) - fg = FunctionGraph([x], [out]) - compare_jax_and_py(fg, [3.0]) - - -def test_log1mexp(): - x = vector("x") - out = log1mexp(x) - fg = FunctionGraph([x], [out]) - - compare_jax_and_py(fg, [[-1.0, -0.75, -0.5, -0.25]]) diff --git a/tests/link/test_numba.py b/tests/link/test_numba.py deleted file mode 100644 index 5308dc48a9..0000000000 --- a/tests/link/test_numba.py +++ /dev/null @@ -1,3606 +0,0 @@ -import contextlib -import inspect -from unittest import mock - -import numba -import numpy as np -import pytest -import scipy.stats as stats - -import aesara.scalar as aes -import aesara.scalar.basic as aesb -import aesara.scalar.math as aesm -import aesara.tensor as at -import aesara.tensor.basic as atb -import aesara.tensor.inplace as ati -import aesara.tensor.math as aem -import aesara.tensor.nnet.basic as nnetb -import aesara.tensor.random.basic as aer -from aesara import config, shared -from aesara.compile.function import function -from aesara.compile.mode import Mode -from aesara.compile.ops import ViewOp, deep_copy_op -from aesara.compile.sharedvalue import SharedVariable -from aesara.graph.basic import Apply, Constant -from aesara.graph.fg import FunctionGraph -from aesara.graph.op import Op, get_test_value -from aesara.graph.optdb import OptimizationQuery -from aesara.graph.type import Type -from aesara.ifelse import ifelse -from aesara.link.numba.dispatch import basic as numba_basic -from aesara.link.numba.dispatch import numba_typify -from aesara.link.numba.linker import NumbaLinker -from aesara.raise_op import assert_op -from aesara.scalar.basic import Composite -from aesara.scan.basic import scan -from aesara.scan.utils import until -from aesara.tensor import blas -from aesara.tensor import elemwise as at_elemwise -from aesara.tensor import extra_ops, nlinalg, slinalg -from aesara.tensor import subtensor as at_subtensor -from aesara.tensor.elemwise import Elemwise -from aesara.tensor.math import All, Any, Max, Mean, Min, Prod, ProdWithoutZeros, Sum -from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, Unbroadcast - - -class MyType(Type): - def filter(self, data): - return data - - def __eq__(self, other): - return isinstance(other, MyType) - - def __hash__(self): - return hash(MyType) - - -class MyOp(Op): - def perform(self, *args): - pass - - -class MySingleOut(Op): - def make_node(self, a, b): - return Apply(self, [a, b], [a.type()]) - - def perform(self, node, inputs, outputs): - res = (inputs[0] + inputs[1]).astype(inputs[0][0].dtype) - outputs[0][0] = res - - -class MyMultiOut(Op): - nin = 2 - nout = 2 - - @staticmethod - def impl(a, b): - res1 = 2 * a - res2 = 2 * b - return [res1, res2] - - def make_node(self, a, b): - return Apply(self, [a, b], [a.type(), b.type()]) - - def perform(self, node, inputs, outputs): - res1, res2 = self.impl(inputs[0], inputs[1]) - outputs[0][0] = res1 - outputs[1][0] = res2 - - -my_multi_out = Elemwise(MyMultiOut()) -my_multi_out.ufunc = MyMultiOut.impl -my_multi_out.ufunc.nin = 2 -my_multi_out.ufunc.nout = 2 - -opts = OptimizationQuery(include=[None], exclude=["cxx_only", "BlasOpt"]) -numba_mode = Mode(NumbaLinker(), opts) -py_mode = Mode("py", opts) - -rng = np.random.default_rng(42849) - - -def set_test_value(x, v): - x.tag.test_value = v - return x - - -def compare_shape_dtype(x, y): - (x,) = x - (y,) = y - return x.shape == y.shape and x.dtype == y.dtype - - -def eval_python_only(fn_inputs, fgraph, inputs): - """Evaluate the Numba implementation in pure Python for coverage purposes.""" - - def py_tuple_setitem(t, i, v): - ll = list(t) - ll[i] = v - return tuple(ll) - - def py_to_scalar(x): - if isinstance(x, np.ndarray): - return x.item() - else: - return x - - def njit_noop(*args, **kwargs): - if len(args) == 1 and callable(args[0]): - return args[0] - else: - return lambda x: x - - def vectorize_noop(*args, **kwargs): - def wrap(fn): - # `numba.vectorize` allows an `out` positional argument. We need - # to account for that - sig = inspect.signature(fn) - nparams = len(sig.parameters) - - def inner_vec(*args): - if len(args) > nparams: - # An `out` argument has been specified for an in-place - # operation - out = args[-1] - out[...] = np.vectorize(fn)(*args[:nparams]) - return out - else: - return np.vectorize(fn)(*args) - - return inner_vec - - if len(args) == 1 and callable(args[0]): - return wrap(args[0], **kwargs) - else: - return wrap - - mocks = [ - mock.patch("numba.njit", njit_noop), - mock.patch("numba.vectorize", vectorize_noop), - mock.patch("aesara.link.numba.dispatch.basic.tuple_setitem", py_tuple_setitem), - mock.patch("aesara.link.numba.dispatch.basic.numba_njit", njit_noop), - mock.patch("aesara.link.numba.dispatch.basic.numba_vectorize", vectorize_noop), - mock.patch("aesara.link.numba.dispatch.basic.direct_cast", lambda x, dtype: x), - mock.patch("aesara.link.numba.dispatch.basic.to_scalar", py_to_scalar), - mock.patch( - "aesara.link.numba.dispatch.basic.numba.np.numpy_support.from_dtype", - lambda dtype: dtype, - ), - mock.patch("numba.np.unsafe.ndarray.to_fixed_tuple", lambda x, n: tuple(x)), - ] - - with contextlib.ExitStack() as stack: - for ctx in mocks: - stack.enter_context(ctx) - - aesara_numba_fn = function( - fn_inputs, - fgraph.outputs, - mode=numba_mode, - accept_inplace=True, - ) - _ = aesara_numba_fn(*inputs) - - -def compare_numba_and_py(fgraph, inputs, assert_fn=None): - """Function to compare python graph output and Numba compiled output for testing equality - - In the tests below computational graphs are defined in Aesara. These graphs are then passed to - this function which then compiles the graphs in both Numba and python, runs the calculation - in both and checks if the results are the same - - Parameters - ---------- - fgraph: FunctionGraph - Aesara function Graph object - inputs: iter - Inputs for function graph - assert_fn: func, opt - Assert function used to check for equality between python and Numba. If not - provided uses np.testing.assert_allclose - - """ - if assert_fn is None: - - def assert_fn(x, y): - return np.testing.assert_allclose(x, y, rtol=1e-4) and compare_shape_dtype( - x, y - ) - - fn_inputs = [i for i in fgraph.inputs if not isinstance(i, SharedVariable)] - - aesara_py_fn = function( - fn_inputs, fgraph.outputs, mode=py_mode, accept_inplace=True - ) - py_res = aesara_py_fn(*inputs) - - aesara_numba_fn = function( - fn_inputs, - fgraph.outputs, - mode=numba_mode, - accept_inplace=True, - ) - numba_res = aesara_numba_fn(*inputs) - - # Get some coverage - eval_python_only(fn_inputs, fgraph, inputs) - - if len(fgraph.outputs) > 1: - for j, p in zip(numba_res, py_res): - assert_fn(j, p) - else: - assert_fn(numba_res, py_res) - - return numba_res - - -@pytest.mark.parametrize( - "v, expected, force_scalar, not_implemented", - [ - (MyType(), None, False, True), - (aes.float32, numba.types.float32, False, False), - (at.fscalar, numba.types.Array(numba.types.float32, 0, "A"), False, False), - (at.fscalar, numba.types.float32, True, False), - (at.lvector, numba.types.int64[:], False, False), - (at.dmatrix, numba.types.float64[:, :], False, False), - (at.dmatrix, numba.types.float64, True, False), - ], -) -def test_get_numba_type(v, expected, force_scalar, not_implemented): - cm = ( - contextlib.suppress() - if not not_implemented - else pytest.raises(NotImplementedError) - ) - with cm: - res = numba_basic.get_numba_type(v, force_scalar=force_scalar) - assert res == expected - - -@pytest.mark.parametrize( - "v, expected, force_scalar", - [ - (Apply(MyOp(), [], []), numba.types.void(), False), - (Apply(MyOp(), [], []), numba.types.void(), True), - ( - Apply(MyOp(), [at.lvector()], []), - numba.types.void(numba.types.int64[:]), - False, - ), - (Apply(MyOp(), [at.lvector()], []), numba.types.void(numba.types.int64), True), - ( - Apply(MyOp(), [at.dmatrix(), aes.float32()], [at.dmatrix()]), - numba.types.float64[:, :](numba.types.float64[:, :], numba.types.float32), - False, - ), - ( - Apply(MyOp(), [at.dmatrix(), aes.float32()], [at.dmatrix()]), - numba.types.float64(numba.types.float64, numba.types.float32), - True, - ), - ( - Apply(MyOp(), [at.dmatrix(), aes.float32()], [at.dmatrix(), aes.int32()]), - numba.types.Tuple([numba.types.float64[:, :], numba.types.int32])( - numba.types.float64[:, :], numba.types.float32 - ), - False, - ), - ( - Apply(MyOp(), [at.dmatrix(), aes.float32()], [at.dmatrix(), aes.int32()]), - numba.types.Tuple([numba.types.float64, numba.types.int32])( - numba.types.float64, numba.types.float32 - ), - True, - ), - ], -) -def test_create_numba_signature(v, expected, force_scalar): - res = numba_basic.create_numba_signature(v, force_scalar=force_scalar) - assert res == expected - - -@pytest.mark.parametrize( - "input, wrapper_fn, check_fn", - [ - ( - np.random.RandomState(1), - numba_typify, - lambda x, y: np.all(x.get_state()[1] == y.get_state()[1]), - ) - ], -) -def test_box_unbox(input, wrapper_fn, check_fn): - input = wrapper_fn(input) - - pass_through = numba.njit(lambda x: x) - res = pass_through(input) - - assert isinstance(res, type(input)) - assert check_fn(res, input) - - -@pytest.mark.parametrize( - "inputs, input_vals, output_fn, exc", - [ - ( - [at.vector()], - [rng.uniform(size=100).astype(config.floatX)], - lambda x: at.gammaln(x), - None, - ), - ( - [at.vector()], - [rng.standard_normal(100).astype(config.floatX)], - lambda x: at.sigmoid(x), - None, - ), - ( - [at.vector()], - [rng.standard_normal(100).astype(config.floatX)], - lambda x: at.log1mexp(x), - None, - ), - ( - [at.vector()], - [rng.standard_normal(100).astype(config.floatX)], - lambda x: at.erf(x), - None, - ), - ( - [at.vector()], - [rng.standard_normal(100).astype(config.floatX)], - lambda x: at.erfc(x), - None, - ), - ( - [at.vector() for i in range(4)], - [rng.standard_normal(100).astype(config.floatX) for i in range(4)], - lambda x, y, x1, y1: (x + y) * (x1 + y1) * y, - None, - ), - ( - [at.matrix(), at.scalar()], - [rng.normal(size=(2, 2)).astype(config.floatX), 0.0], - lambda a, b: at.switch(a, b, a), - None, - ), - ( - [at.scalar(), at.scalar()], - [ - np.array(1.0, dtype=config.floatX), - np.array(1.0, dtype=config.floatX), - ], - lambda x, y: ati.add_inplace(deep_copy_op(x), deep_copy_op(y)), - None, - ), - ( - [at.vector(), at.vector()], - [ - rng.standard_normal(100).astype(config.floatX), - rng.standard_normal(100).astype(config.floatX), - ], - lambda x, y: ati.add_inplace(deep_copy_op(x), deep_copy_op(y)), - None, - ), - ( - [at.vector(), at.vector()], - [ - rng.standard_normal(100).astype(config.floatX), - rng.standard_normal(100).astype(config.floatX), - ], - lambda x, y: my_multi_out(x, y), - NotImplementedError, - ), - ], -) -def test_Elemwise(inputs, input_vals, output_fn, exc): - - outputs = output_fn(*inputs) - - out_fg = FunctionGraph( - outputs=[outputs] if not isinstance(outputs, list) else outputs - ) - - cm = contextlib.suppress() if exc is None else pytest.raises(exc) - with cm: - compare_numba_and_py(out_fg, input_vals) - - -@pytest.mark.parametrize( - "inputs, input_values, scalar_fn", - [ - ( - [at.scalar("x"), at.scalar("y"), at.scalar("z")], - [ - np.array(10, dtype=config.floatX), - np.array(20, dtype=config.floatX), - np.array(30, dtype=config.floatX), - ], - lambda x, y, z: aes.add(x, y, z), - ), - ( - [at.scalar("x"), at.scalar("y"), at.scalar("z")], - [ - np.array(10, dtype=config.floatX), - np.array(20, dtype=config.floatX), - np.array(30, dtype=config.floatX), - ], - lambda x, y, z: aes.mul(x, y, z), - ), - ( - [at.scalar("x"), at.scalar("y")], - [ - np.array(10, dtype=config.floatX), - np.array(20, dtype=config.floatX), - ], - lambda x, y: x + y * 2 + aes.exp(x - y), - ), - ], -) -def test_Composite(inputs, input_values, scalar_fn): - composite_inputs = [aes.float64(i.name) for i in inputs] - comp_op = Elemwise(Composite(composite_inputs, [scalar_fn(*composite_inputs)])) - out_fg = FunctionGraph(inputs, [comp_op(*inputs)]) - compare_numba_and_py(out_fg, input_values) - - -@pytest.mark.parametrize( - "x, indices", - [ - (at.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), (1,)), - ( - at.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), - (slice(None)), - ), - (at.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), (1, 2, 0)), - ( - at.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), - (slice(1, 2), 1, slice(None)), - ), - ], -) -def test_Subtensor(x, indices): - """Test NumPy's basic indexing.""" - out_at = x[indices] - assert isinstance(out_at.owner.op, at_subtensor.Subtensor) - out_fg = FunctionGraph([], [out_at]) - compare_numba_and_py(out_fg, []) - - -@pytest.mark.parametrize( - "x, indices", - [ - (at.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2],)), - ], -) -def test_AdvancedSubtensor1(x, indices): - """Test NumPy's advanced indexing in one dimension.""" - out_at = at_subtensor.advanced_subtensor1(x, *indices) - assert isinstance(out_at.owner.op, at_subtensor.AdvancedSubtensor1) - out_fg = FunctionGraph([], [out_at]) - compare_numba_and_py(out_fg, []) - - -@pytest.mark.parametrize( - "x, indices", - [ - (at.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), ([1, 2], [2, 3])), - ( - at.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), - ([1, 2], slice(None), [3, 4]), - ), - ], -) -def test_AdvancedSubtensor(x, indices): - """Test NumPy's advanced indexing in more than one dimension.""" - out_at = x[indices] - assert isinstance(out_at.owner.op, at_subtensor.AdvancedSubtensor) - out_fg = FunctionGraph([], [out_at]) - compare_numba_and_py(out_fg, []) - - -@pytest.mark.parametrize( - "x, y, indices", - [ - ( - at.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), - at.as_tensor(np.array(10)), - (1,), - ), - ( - at.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), - at.as_tensor(rng.poisson(size=(4, 5))), - (slice(None)), - ), - ( - at.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), - at.as_tensor(np.array(10)), - (1, 2, 0), - ), - ( - at.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), - at.as_tensor(rng.poisson(size=(1, 5))), - (slice(1, 2), 1, slice(None)), - ), - ], -) -def test_IncSubtensor(x, y, indices): - out_at = at.set_subtensor(x[indices], y) - assert isinstance(out_at.owner.op, at_subtensor.IncSubtensor) - out_fg = FunctionGraph([], [out_at]) - compare_numba_and_py(out_fg, []) - - out_at = at.inc_subtensor(x[indices], y) - assert isinstance(out_at.owner.op, at_subtensor.IncSubtensor) - out_fg = FunctionGraph([], [out_at]) - compare_numba_and_py(out_fg, []) - - x_at = x.type() - out_at = at.set_subtensor(x_at[indices], y, inplace=True) - assert isinstance(out_at.owner.op, at_subtensor.IncSubtensor) - out_fg = FunctionGraph([x_at], [out_at]) - compare_numba_and_py(out_fg, [x.data]) - - -@pytest.mark.parametrize( - "x, y, indices", - [ - ( - at.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), - at.as_tensor(rng.poisson(size=(2, 4, 5))), - ([1, 2],), - ), - ], -) -def test_AdvancedIncSubtensor1(x, y, indices): - out_at = at_subtensor.advanced_set_subtensor1(x, y, *indices) - assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor1) - out_fg = FunctionGraph([], [out_at]) - compare_numba_and_py(out_fg, []) - - out_at = at_subtensor.advanced_inc_subtensor1(x, y, *indices) - assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor1) - out_fg = FunctionGraph([], [out_at]) - compare_numba_and_py(out_fg, []) - - x_at = x.type() - out_at = at_subtensor.AdvancedIncSubtensor1(inplace=True)(x_at, y, *indices) - assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor1) - out_fg = FunctionGraph([x_at], [out_at]) - compare_numba_and_py(out_fg, [x.data]) - - -@pytest.mark.parametrize( - "x, y, indices", - [ - ( - at.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), - at.as_tensor(rng.poisson(size=(2, 5))), - ([1, 2], [2, 3]), - ), - ( - at.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), - at.as_tensor(rng.poisson(size=(2, 4))), - ([1, 2], slice(None), [3, 4]), - ), - ], -) -def test_AdvancedIncSubtensor(x, y, indices): - out_at = at.set_subtensor(x[indices], y) - assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor) - out_fg = FunctionGraph([], [out_at]) - compare_numba_and_py(out_fg, []) - - out_at = at.inc_subtensor(x[indices], y) - assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor) - out_fg = FunctionGraph([], [out_at]) - compare_numba_and_py(out_fg, []) - - x_at = x.type() - out_at = at.set_subtensor(x_at[indices], y) - # Inplace isn't really implemented for `AdvancedIncSubtensor`, so we just - # hack it on here - out_at.owner.op.inplace = True - assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor) - out_fg = FunctionGraph([x_at], [out_at]) - compare_numba_and_py(out_fg, [x.data]) - - -@pytest.mark.parametrize( - "x, i", - [ - (np.zeros((20, 3)), 1), - ], -) -def test_Shape(x, i): - g = Shape()(at.as_tensor_variable(x)) - g_fg = FunctionGraph([], [g]) - - compare_numba_and_py(g_fg, []) - - g = Shape_i(i)(at.as_tensor_variable(x)) - g_fg = FunctionGraph([], [g]) - - compare_numba_and_py(g_fg, []) - - -@pytest.mark.parametrize( - "v, shape", - [ - (0.0, (2, 3)), - (1.1, (2, 3)), - (set_test_value(at.scalar("a"), np.array(10.0, dtype=config.floatX)), (20,)), - (set_test_value(at.vector("a"), np.ones(10, dtype=config.floatX)), (20, 10)), - ], -) -def test_Alloc(v, shape): - g = at.alloc(v, *shape) - g_fg = FunctionGraph(outputs=[g]) - - (numba_res,) = compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, (SharedVariable, Constant)) - ], - ) - - assert numba_res.shape == shape - - -def test_AllocEmpty(): - - x = at.empty((2, 3), dtype="float32") - x_fg = FunctionGraph([], [x]) - - # We cannot compare the values in the arrays, only the shapes and dtypes - compare_numba_and_py(x_fg, [], assert_fn=compare_shape_dtype) - - -@pytest.mark.parametrize( - "v, offset", - [ - (set_test_value(at.vector(), np.arange(10, dtype=config.floatX)), 0), - (set_test_value(at.vector(), np.arange(10, dtype=config.floatX)), 1), - (set_test_value(at.vector(), np.arange(10, dtype=config.floatX)), -1), - ], -) -def test_AllocDiag(v, offset): - g = atb.AllocDiag(offset=offset)(v) - g_fg = FunctionGraph(outputs=[g]) - - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, (SharedVariable, Constant)) - ], - ) - - -@pytest.mark.parametrize( - "v, new_order", - [ - # `{'drop': [], 'shuffle': [], 'augment': [0, 1]}` - ( - set_test_value( - at.lscalar(name="a"), - np.array(1, dtype=np.int64), - ), - ("x", "x"), - ), - # I.e. `a_at.T` - # `{'drop': [], 'shuffle': [1, 0], 'augment': []}` - ( - set_test_value( - at.matrix("a"), np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX) - ), - (1, 0), - ), - # `{'drop': [], 'shuffle': [0, 1], 'augment': [2]}` - ( - set_test_value( - at.matrix("a"), np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX) - ), - (1, 0, "x"), - ), - # `{'drop': [1], 'shuffle': [2, 0], 'augment': [0, 2, 4]}` - ( - set_test_value( - at.tensor(config.floatX, [False, True, False], name="a"), - np.array([[[1.0, 2.0]], [[3.0, 4.0]]], dtype=config.floatX), - ), - ("x", 2, "x", 0, "x"), - ), - # I.e. `a_at.dimshuffle((0,))` - # `{'drop': [1], 'shuffle': [0], 'augment': []}` - ( - set_test_value( - at.tensor(config.floatX, [False, True], name="a"), - np.array([[1.0], [2.0], [3.0], [4.0]], dtype=config.floatX), - ), - (0,), - ), - ( - set_test_value( - at.tensor(config.floatX, [False, True], name="a"), - np.array([[1.0], [2.0], [3.0], [4.0]], dtype=config.floatX), - ), - (0,), - ), - ( - set_test_value( - at.tensor(config.floatX, [True, True, True], name="a"), - np.array([[[1.0]]], dtype=config.floatX), - ), - (), - ), - ], -) -def test_Dimshuffle(v, new_order): - g = at_elemwise.DimShuffle(v.broadcastable, new_order)(v) - g_fg = FunctionGraph(outputs=[g]) - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, (SharedVariable, Constant)) - ], - ) - - -@pytest.mark.parametrize( - "v", [set_test_value(aes.float64(), np.array(1.0, dtype="float64"))] -) -def test_TensorFromScalar(v): - g = atb.TensorFromScalar()(v) - g_fg = FunctionGraph(outputs=[g]) - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, (SharedVariable, Constant)) - ], - ) - - -@pytest.mark.parametrize( - "v", - [ - set_test_value(at.scalar(), np.array(1.0, dtype=config.floatX)), - ], -) -def test_ScalarFromTensor(v): - g = atb.ScalarFromTensor()(v) - g_fg = FunctionGraph(outputs=[g]) - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, (SharedVariable, Constant)) - ], - ) - - -def test_Unbroadcast(): - v = set_test_value(at.row(), np.array([[1.0, 2.0]], dtype=config.floatX)) - g = Unbroadcast(0)(v) - g_fg = FunctionGraph(outputs=[g]) - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, (SharedVariable, Constant)) - ], - ) - - -@pytest.mark.parametrize( - "v, dtype", - [ - (set_test_value(at.fscalar(), np.array(1.0, dtype="float32")), aesb.float64), - (set_test_value(at.dscalar(), np.array(1.0, dtype="float64")), aesb.float32), - ], -) -def test_Cast(v, dtype): - g = aesb.Cast(dtype)(v) - g_fg = FunctionGraph(outputs=[g]) - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, (SharedVariable, Constant)) - ], - ) - - -@pytest.mark.parametrize( - "v, dtype", - [ - (set_test_value(at.iscalar(), np.array(10, dtype="int32")), aesb.float64), - ], -) -def test_Inv(v, dtype): - g = aesb.inv(v) - g_fg = FunctionGraph(outputs=[g]) - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, (SharedVariable, Constant)) - ], - ) - - -@pytest.mark.parametrize( - "v, shape, ndim", - [ - (set_test_value(at.vector(), np.array([4], dtype=config.floatX)), (), 0), - (set_test_value(at.vector(), np.arange(4, dtype=config.floatX)), (2, 2), 2), - ( - set_test_value(at.vector(), np.arange(4, dtype=config.floatX)), - set_test_value(at.lvector(), np.array([2, 2], dtype="int64")), - 2, - ), - ], -) -def test_Reshape(v, shape, ndim): - g = Reshape(ndim)(v, shape) - g_fg = FunctionGraph(outputs=[g]) - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, (SharedVariable, Constant)) - ], - ) - - -def test_Reshape_scalar(): - v = at.vector() - v.tag.test_value = np.array([1.0], dtype=config.floatX) - g = Reshape(1)(v[0], (1,)) - g_fg = FunctionGraph(outputs=[g]) - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, (SharedVariable, Constant)) - ], - ) - - -@pytest.mark.parametrize( - "v, shape, fails", - [ - ( - set_test_value(at.matrix(), np.array([[1.0]], dtype=config.floatX)), - (1, 1), - False, - ), - ( - set_test_value(at.matrix(), np.array([[1.0, 2.0]], dtype=config.floatX)), - (1, 1), - True, - ), - ( - set_test_value(at.matrix(), np.array([[1.0, 2.0]], dtype=config.floatX)), - (1, None), - False, - ), - ], -) -def test_SpecifyShape(v, shape, fails): - g = SpecifyShape()(v, *shape) - g_fg = FunctionGraph(outputs=[g]) - cm = contextlib.suppress() if not fails else pytest.raises(AssertionError) - with cm: - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, (SharedVariable, Constant)) - ], - ) - - -@pytest.mark.parametrize( - "v", - [ - set_test_value(at.vector(), np.arange(4, dtype=config.floatX)), - ], -) -def test_ViewOp(v): - g = ViewOp()(v) - g_fg = FunctionGraph(outputs=[g]) - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, (SharedVariable, Constant)) - ], - ) - - -@pytest.mark.parametrize( - "x, y", - [ - ( - set_test_value(at.lvector(), np.arange(4, dtype="int64")), - set_test_value(at.dvector(), np.arange(4, dtype="float64")), - ), - ( - set_test_value(at.dmatrix(), np.arange(4, dtype="float64").reshape((2, 2))), - set_test_value(at.lscalar(), np.array(4, dtype="int64")), - ), - ], -) -def test_Second(x, y): - # We use the `Elemwise`-wrapped version of `Second` - g = at.second(x, y) - g_fg = FunctionGraph(outputs=[g]) - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, (SharedVariable, Constant)) - ], - ) - - -@pytest.mark.parametrize( - "v, min, max", - [ - (set_test_value(at.scalar(), np.array(10, dtype=config.floatX)), 3.0, 7.0), - (set_test_value(at.scalar(), np.array(1, dtype=config.floatX)), 3.0, 7.0), - (set_test_value(at.scalar(), np.array(10, dtype=config.floatX)), 7.0, 3.0), - ], -) -def test_Clip(v, min, max): - g = aes.clip(v, min, max) - g_fg = FunctionGraph(outputs=[g]) - - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, (SharedVariable, Constant)) - ], - ) - - -def test_scalar_Elemwise_Clip(): - a = at.scalar("a") - b = at.scalar("b") - - z = at.switch(1, a, b) - c = at.clip(z, 1, 3) - c_fg = FunctionGraph(outputs=[c]) - - compare_numba_and_py(c_fg, [1, 1]) - - -@pytest.mark.parametrize( - "vals, dtype", - [ - ( - ( - set_test_value(at.scalar(), np.array(1, dtype=config.floatX)), - set_test_value(at.scalar(), np.array(2, dtype=config.floatX)), - set_test_value(at.scalar(), np.array(3, dtype=config.floatX)), - ), - config.floatX, - ), - ( - ( - set_test_value(at.dscalar(), np.array(1, dtype=np.float64)), - set_test_value(at.lscalar(), np.array(3, dtype=np.int32)), - ), - "float64", - ), - ( - (set_test_value(at.iscalar(), np.array(1, dtype=np.int32)),), - "float64", - ), - ( - (set_test_value(at.scalar(dtype=bool), True),), - bool, - ), - ], -) -def test_MakeVector(vals, dtype): - g = atb.MakeVector(dtype)(*vals) - g_fg = FunctionGraph(outputs=[g]) - - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, (SharedVariable, Constant)) - ], - ) - - -@pytest.mark.parametrize( - "start, stop, step, dtype", - [ - ( - set_test_value(at.lscalar(), np.array(1)), - set_test_value(at.lscalar(), np.array(10)), - set_test_value(at.lscalar(), np.array(3)), - config.floatX, - ), - ], -) -def test_ARange(start, stop, step, dtype): - g = atb.ARange(dtype)(start, stop, step) - g_fg = FunctionGraph(outputs=[g]) - - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, (SharedVariable, Constant)) - ], - ) - - -@pytest.mark.parametrize( - "careduce_fn, axis, v", - [ - ( - lambda x, axis=None, dtype=None, acc_dtype=None: Sum( - axis=axis, dtype=dtype, acc_dtype=acc_dtype - )(x), - 0, - set_test_value(at.vector(), np.arange(3, dtype=config.floatX)), - ), - ( - lambda x, axis=None, dtype=None, acc_dtype=None: All(axis)(x), - 0, - set_test_value(at.vector(), np.arange(3, dtype=config.floatX)), - ), - ( - lambda x, axis=None, dtype=None, acc_dtype=None: Any(axis)(x), - 0, - set_test_value(at.vector(), np.arange(3, dtype=config.floatX)), - ), - ( - lambda x, axis=None, dtype=None, acc_dtype=None: Mean(axis)(x), - 0, - set_test_value(at.vector(), np.arange(3, dtype=config.floatX)), - ), - ( - lambda x, axis=None, dtype=None, acc_dtype=None: Mean(axis)(x), - 0, - set_test_value( - at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2)) - ), - ), - ( - lambda x, axis=None, dtype=None, acc_dtype=None: Sum( - axis=axis, dtype=dtype, acc_dtype=acc_dtype - )(x), - 0, - set_test_value( - at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2)) - ), - ), - ( - lambda x, axis=None, dtype=None, acc_dtype=None: Sum( - axis=axis, dtype=dtype, acc_dtype=acc_dtype - )(x), - (0, 1), - set_test_value( - at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2)) - ), - ), - ( - lambda x, axis=None, dtype=None, acc_dtype=None: Sum( - axis=axis, dtype=dtype, acc_dtype=acc_dtype - )(x), - (1, 0), - set_test_value( - at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2)) - ), - ), - ( - lambda x, axis=None, dtype=None, acc_dtype=None: Sum( - axis=axis, dtype=dtype, acc_dtype=acc_dtype - )(x), - None, - set_test_value( - at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2)) - ), - ), - ( - lambda x, axis=None, dtype=None, acc_dtype=None: Sum( - axis=axis, dtype=dtype, acc_dtype=acc_dtype - )(x), - 1, - set_test_value( - at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2)) - ), - ), - ( - lambda x, axis=None, dtype=None, acc_dtype=None: Prod( - axis=axis, dtype=dtype, acc_dtype=acc_dtype - )(x), - 0, - set_test_value(at.vector(), np.arange(3, dtype=config.floatX)), - ), - ( - lambda x, axis=None, dtype=None, acc_dtype=None: ProdWithoutZeros( - axis=axis, dtype=dtype, acc_dtype=acc_dtype - )(x), - 0, - set_test_value(at.vector(), np.arange(3, dtype=config.floatX)), - ), - ( - lambda x, axis=None, dtype=None, acc_dtype=None: Prod( - axis=axis, dtype=dtype, acc_dtype=acc_dtype - )(x), - 0, - set_test_value( - at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2)) - ), - ), - ( - lambda x, axis=None, dtype=None, acc_dtype=None: Prod( - axis=axis, dtype=dtype, acc_dtype=acc_dtype - )(x), - 1, - set_test_value( - at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2)) - ), - ), - ( - lambda x, axis=None, dtype=None, acc_dtype=None: Max(axis)(x), - None, - set_test_value( - at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2)) - ), - ), - ( - lambda x, axis=None, dtype=None, acc_dtype=None: Max(axis)(x), - None, - set_test_value( - at.lmatrix(), np.arange(3 * 2, dtype=np.int64).reshape((3, 2)) - ), - ), - ( - lambda x, axis=None, dtype=None, acc_dtype=None: Min(axis)(x), - None, - set_test_value( - at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2)) - ), - ), - ( - lambda x, axis=None, dtype=None, acc_dtype=None: Min(axis)(x), - None, - set_test_value( - at.lmatrix(), np.arange(3 * 2, dtype=np.int64).reshape((3, 2)) - ), - ), - ], -) -def test_CAReduce(careduce_fn, axis, v): - g = careduce_fn(v, axis=axis) - g_fg = FunctionGraph(outputs=[g]) - - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, (SharedVariable, Constant)) - ], - ) - - -@pytest.mark.parametrize( - "vals, axis", - [ - ( - ( - set_test_value( - at.matrix(), rng.normal(size=(1, 2)).astype(config.floatX) - ), - set_test_value( - at.matrix(), rng.normal(size=(1, 2)).astype(config.floatX) - ), - ), - 0, - ), - ( - ( - set_test_value( - at.matrix(), rng.normal(size=(2, 1)).astype(config.floatX) - ), - set_test_value( - at.matrix(), rng.normal(size=(3, 1)).astype(config.floatX) - ), - ), - 0, - ), - ( - ( - set_test_value( - at.matrix(), rng.normal(size=(1, 2)).astype(config.floatX) - ), - set_test_value( - at.matrix(), rng.normal(size=(1, 2)).astype(config.floatX) - ), - ), - 1, - ), - ( - ( - set_test_value( - at.matrix(), rng.normal(size=(2, 2)).astype(config.floatX) - ), - set_test_value( - at.matrix(), rng.normal(size=(2, 1)).astype(config.floatX) - ), - ), - 1, - ), - ], -) -def test_Join(vals, axis): - g = at.join(axis, *vals) - g_fg = FunctionGraph(outputs=[g]) - - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, (SharedVariable, Constant)) - ], - ) - - -def test_Join_view(): - vals = ( - set_test_value(at.matrix(), rng.normal(size=(2, 2)).astype(config.floatX)), - set_test_value(at.matrix(), rng.normal(size=(2, 2)).astype(config.floatX)), - ) - g = atb.Join(view=1)(1, *vals) - g_fg = FunctionGraph(outputs=[g]) - - with pytest.raises(NotImplementedError): - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, (SharedVariable, Constant)) - ], - ) - - -@pytest.mark.parametrize( - "n_splits, axis, values, sizes", - [ - ( - 0, - 0, - set_test_value(at.vector(), rng.normal(size=20).astype(config.floatX)), - set_test_value(at.vector(dtype="int64"), []), - ), - ( - 5, - 0, - set_test_value(at.vector(), rng.normal(size=5).astype(config.floatX)), - set_test_value( - at.vector(dtype="int64"), rng.multinomial(5, np.ones(5) / 5) - ), - ), - ( - 5, - 0, - set_test_value(at.vector(), rng.normal(size=10).astype(config.floatX)), - set_test_value( - at.vector(dtype="int64"), rng.multinomial(10, np.ones(5) / 5) - ), - ), - ( - 5, - -1, - set_test_value(at.matrix(), rng.normal(size=(11, 7)).astype(config.floatX)), - set_test_value( - at.vector(dtype="int64"), rng.multinomial(7, np.ones(5) / 5) - ), - ), - ( - 5, - -2, - set_test_value(at.matrix(), rng.normal(size=(11, 7)).astype(config.floatX)), - set_test_value( - at.vector(dtype="int64"), rng.multinomial(11, np.ones(5) / 5) - ), - ), - ], -) -def test_Split(n_splits, axis, values, sizes): - g = at.split(values, sizes, n_splits, axis=axis) - assert len(g) == n_splits - if n_splits == 0: - return - g_fg = FunctionGraph(outputs=[g] if n_splits == 1 else g) - - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, (SharedVariable, Constant)) - ], - ) - - -@pytest.mark.parametrize( - "val, offset", - [ - ( - set_test_value( - at.matrix(), np.arange(10 * 10, dtype=config.floatX).reshape((10, 10)) - ), - 0, - ), - ( - set_test_value(at.vector(), np.arange(10, dtype=config.floatX)), - 0, - ), - ], -) -def test_ExtractDiag(val, offset): - g = at.diag(val, offset) - g_fg = FunctionGraph(outputs=[g]) - - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, (SharedVariable, Constant)) - ], - ) - - -@pytest.mark.parametrize( - "n, m, k, dtype", - [ - (set_test_value(at.lscalar(), np.array(1, dtype=np.int64)), None, 0, None), - ( - set_test_value(at.lscalar(), np.array(1, dtype=np.int64)), - set_test_value(at.lscalar(), np.array(2, dtype=np.int64)), - 0, - "float32", - ), - ( - set_test_value(at.lscalar(), np.array(1, dtype=np.int64)), - set_test_value(at.lscalar(), np.array(2, dtype=np.int64)), - 1, - "int64", - ), - ], -) -def test_Eye(n, m, k, dtype): - g = at.eye(n, m, k, dtype=dtype) - g_fg = FunctionGraph(outputs=[g]) - - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, (SharedVariable, Constant)) - ], - ) - - -@pytest.mark.parametrize( - "inputs, op, exc", - [ - ( - [ - set_test_value( - at.matrix(), rng.random(size=(2, 3)).astype(config.floatX) - ), - set_test_value(at.lmatrix(), rng.poisson(size=(2, 3))), - ], - MySingleOut, - UserWarning, - ), - ( - [ - set_test_value( - at.matrix(), rng.random(size=(2, 3)).astype(config.floatX) - ), - set_test_value(at.lmatrix(), rng.poisson(size=(2, 3))), - ], - MyMultiOut, - UserWarning, - ), - ], -) -def test_perform(inputs, op, exc): - - g = op()(*inputs) - - if isinstance(g, list): - g_fg = FunctionGraph(outputs=g) - else: - g_fg = FunctionGraph(outputs=[g]) - - cm = contextlib.suppress() if exc is None else pytest.warns(exc) - with cm: - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, (SharedVariable, Constant)) - ], - ) - - -def test_perform_params(): - """This tests for `Op.perform` implementations that require the `params` arguments.""" - - x = at.vector() - x.tag.test_value = np.array([1.0, 2.0], dtype=config.floatX) - - out = assert_op(x, np.array(True)) - - if not isinstance(out, (list, tuple)): - out = [out] - - out_fg = FunctionGraph([x], out) - - with pytest.warns(UserWarning, match=".*object mode.*"): - compare_numba_and_py(out_fg, [get_test_value(i) for i in out_fg.inputs]) - - -def test_perform_type_convert(): - """This tests the use of `Type.filter` in `objmode`. - - The `Op.perform` takes a single input that it returns as-is, but it gets a - native scalar and it's supposed to return an `np.ndarray`. - """ - - x = at.vector() - x.tag.test_value = np.array([1.0, 2.0], dtype=config.floatX) - - out = assert_op(x.sum(), np.array(True)) - - if not isinstance(out, (list, tuple)): - out = [out] - - out_fg = FunctionGraph([x], out) - - with pytest.warns(UserWarning, match=".*object mode.*"): - compare_numba_and_py(out_fg, [get_test_value(i) for i in out_fg.inputs]) - - -@pytest.mark.parametrize( - "val", - [ - set_test_value(at.lscalar(), np.array(6, dtype="int64")), - ], -) -def test_Bartlett(val): - g = extra_ops.bartlett(val) - g_fg = FunctionGraph(outputs=[g]) - - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, (SharedVariable, Constant)) - ], - ) - - -@pytest.mark.parametrize( - "val, axis, mode", - [ - ( - set_test_value( - at.matrix(), np.arange(3, dtype=config.floatX).reshape((3, 1)) - ), - 1, - "add", - ), - ( - set_test_value( - at.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2)) - ), - 0, - "add", - ), - ( - set_test_value( - at.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2)) - ), - 1, - "add", - ), - ( - set_test_value( - at.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2)) - ), - 0, - "mul", - ), - ( - set_test_value( - at.matrix(), np.arange(6, dtype=config.floatX).reshape((3, 2)) - ), - 1, - "mul", - ), - ], -) -def test_CumOp(val, axis, mode): - g = extra_ops.CumOp(axis=axis, mode=mode)(val) - g_fg = FunctionGraph(outputs=[g]) - - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, (SharedVariable, Constant)) - ], - ) - - -@pytest.mark.parametrize( - "a, val", - [ - ( - set_test_value(at.lmatrix(), np.zeros((10, 2), dtype="int64")), - set_test_value(at.lscalar(), np.array(1, dtype="int64")), - ) - ], -) -def test_FillDiagonal(a, val): - g = extra_ops.FillDiagonal()(a, val) - g_fg = FunctionGraph(outputs=[g]) - - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, (SharedVariable, Constant)) - ], - ) - - -@pytest.mark.parametrize( - "a, val, offset", - [ - ( - set_test_value(at.lmatrix(), np.zeros((10, 2), dtype="int64")), - set_test_value(at.lscalar(), np.array(1, dtype="int64")), - set_test_value(at.lscalar(), np.array(-1, dtype="int64")), - ), - ( - set_test_value(at.lmatrix(), np.zeros((10, 2), dtype="int64")), - set_test_value(at.lscalar(), np.array(1, dtype="int64")), - set_test_value(at.lscalar(), np.array(0, dtype="int64")), - ), - ( - set_test_value(at.lmatrix(), np.zeros((10, 3), dtype="int64")), - set_test_value(at.lscalar(), np.array(1, dtype="int64")), - set_test_value(at.lscalar(), np.array(1, dtype="int64")), - ), - ], -) -def test_FillDiagonalOffset(a, val, offset): - g = extra_ops.FillDiagonalOffset()(a, val, offset) - g_fg = FunctionGraph(outputs=[g]) - - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, (SharedVariable, Constant)) - ], - ) - - -@pytest.mark.parametrize( - "arr, shape, mode, order, exc", - [ - ( - tuple(set_test_value(at.lscalar(), v) for v in np.array([0])), - set_test_value(at.lvector(), np.array([2])), - "raise", - "C", - None, - ), - ( - tuple(set_test_value(at.lscalar(), v) for v in np.array([0, 0, 3])), - set_test_value(at.lvector(), np.array([2, 3, 4])), - "raise", - "C", - None, - ), - ( - tuple( - set_test_value(at.lvector(), v) - for v in np.array([[0, 1], [2, 0], [1, 3]]) - ), - set_test_value(at.lvector(), np.array([2, 3, 4])), - "raise", - "C", - None, - ), - ( - tuple( - set_test_value(at.lvector(), v) - for v in np.array([[0, 1], [2, 0], [1, 3]]) - ), - set_test_value(at.lvector(), np.array([2, 3, 4])), - "raise", - "F", - NotImplementedError, - ), - ( - tuple( - set_test_value(at.lvector(), v) - for v in np.array([[0, 1, 2], [2, 0, 3], [1, 3, 5]]) - ), - set_test_value(at.lvector(), np.array([2, 3, 4])), - "raise", - "C", - ValueError, - ), - ( - tuple( - set_test_value(at.lvector(), v) - for v in np.array([[0, 1, 2], [2, 0, 3], [1, 3, 5]]) - ), - set_test_value(at.lvector(), np.array([2, 3, 4])), - "wrap", - "C", - None, - ), - ( - tuple( - set_test_value(at.lvector(), v) - for v in np.array([[0, 1, 2], [2, 0, 3], [1, 3, 5]]) - ), - set_test_value(at.lvector(), np.array([2, 3, 4])), - "clip", - "C", - None, - ), - ], -) -def test_RavelMultiIndex(arr, shape, mode, order, exc): - g = extra_ops.RavelMultiIndex(mode, order)(*(arr + (shape,))) - g_fg = FunctionGraph(outputs=[g]) - - cm = contextlib.suppress() if exc is None else pytest.raises(exc) - with cm: - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, (SharedVariable, Constant)) - ], - ) - - -@pytest.mark.parametrize( - "x, repeats, axis, exc", - [ - ( - set_test_value(at.lscalar(), np.array(1, dtype="int64")), - set_test_value(at.lscalar(), np.array(0, dtype="int64")), - None, - None, - ), - ( - set_test_value(at.lmatrix(), np.zeros((2, 2), dtype="int64")), - set_test_value(at.lscalar(), np.array(1, dtype="int64")), - None, - None, - ), - ( - set_test_value(at.lvector(), np.arange(2, dtype="int64")), - set_test_value(at.lvector(), np.array([1, 1], dtype="int64")), - None, - None, - ), - ( - set_test_value(at.lmatrix(), np.zeros((2, 2), dtype="int64")), - set_test_value(at.lscalar(), np.array(1, dtype="int64")), - 0, - UserWarning, - ), - ], -) -def test_Repeat(x, repeats, axis, exc): - g = extra_ops.Repeat(axis)(x, repeats) - g_fg = FunctionGraph(outputs=[g]) - - cm = contextlib.suppress() if exc is None else pytest.warns(exc) - with cm: - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, (SharedVariable, Constant)) - ], - ) - - -@pytest.mark.parametrize( - "x, axis, return_index, return_inverse, return_counts, exc", - [ - ( - set_test_value(at.lscalar(), np.array(1, dtype="int64")), - None, - False, - False, - False, - None, - ), - ( - set_test_value(at.lvector(), np.array([1, 1, 2], dtype="int64")), - None, - False, - False, - False, - None, - ), - ( - set_test_value(at.lmatrix(), np.array([[1, 1], [2, 2]], dtype="int64")), - None, - False, - False, - False, - None, - ), - ( - set_test_value( - at.lmatrix(), np.array([[1, 1], [1, 1], [2, 2]], dtype="int64") - ), - 0, - False, - False, - False, - UserWarning, - ), - ( - set_test_value( - at.lmatrix(), np.array([[1, 1], [1, 1], [2, 2]], dtype="int64") - ), - 0, - True, - True, - True, - UserWarning, - ), - ], -) -def test_Unique(x, axis, return_index, return_inverse, return_counts, exc): - g = extra_ops.Unique(return_index, return_inverse, return_counts, axis)(x) - - if isinstance(g, list): - g_fg = FunctionGraph(outputs=g) - else: - g_fg = FunctionGraph(outputs=[g]) - - cm = contextlib.suppress() if exc is None else pytest.warns(exc) - with cm: - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, (SharedVariable, Constant)) - ], - ) - - -@pytest.mark.parametrize( - "arr, shape, order, exc", - [ - ( - set_test_value(at.lvector(), np.array([9, 15, 1], dtype="int64")), - at.as_tensor([2, 3, 4]), - "C", - None, - ), - ( - set_test_value(at.lvector(), np.array([1, 0], dtype="int64")), - at.as_tensor([2]), - "C", - None, - ), - ( - set_test_value(at.lvector(), np.array([9, 15, 1], dtype="int64")), - at.as_tensor([2, 3, 4]), - "F", - NotImplementedError, - ), - ], -) -def test_UnravelIndex(arr, shape, order, exc): - g = extra_ops.UnravelIndex(order)(arr, shape) - - if isinstance(g, list): - g_fg = FunctionGraph(outputs=g) - else: - g_fg = FunctionGraph(outputs=[g]) - - cm = contextlib.suppress() if exc is None else pytest.raises(exc) - with cm: - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, (SharedVariable, Constant)) - ], - ) - - -@pytest.mark.parametrize( - "a, v, side, sorter, exc", - [ - ( - set_test_value(at.vector(), np.array([1.0, 2.0, 3.0], dtype=config.floatX)), - set_test_value(at.matrix(), rng.random((3, 2)).astype(config.floatX)), - "left", - None, - None, - ), - pytest.param( - set_test_value( - at.vector(), - np.array([0.29769574, 0.71649186, 0.20475563]).astype(config.floatX), - ), - set_test_value( - at.matrix(), - np.array( - [ - [0.18847123, 0.39659508], - [0.56220006, 0.57428752], - [0.86720994, 0.44522637], - ] - ).astype(config.floatX), - ), - "left", - None, - None, - marks=pytest.mark.xfail( - reason="This won't work until https://github.com/numba/numba/pull/7005 is merged" - ), - ), - ( - set_test_value(at.vector(), np.array([1.0, 2.0, 3.0], dtype=config.floatX)), - set_test_value(at.matrix(), rng.random((3, 2)).astype(config.floatX)), - "right", - set_test_value(at.lvector(), np.array([0, 2, 1])), - UserWarning, - ), - ], -) -def test_Searchsorted(a, v, side, sorter, exc): - g = extra_ops.SearchsortedOp(side)(a, v, sorter) - g_fg = FunctionGraph(outputs=[g]) - - cm = contextlib.suppress() if exc is None else pytest.warns(exc) - with cm: - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, (SharedVariable, Constant)) - ], - ) - - -@pytest.mark.parametrize( - "x, shape", - [ - ( - set_test_value(at.vector(), rng.random(size=(2,)).astype(config.floatX)), - [set_test_value(at.lscalar(), np.array(v)) for v in [3, 2]], - ), - ( - set_test_value(at.vector(), rng.random(size=(2,)).astype(config.floatX)), - [at.as_tensor(3, dtype=np.int64), at.as_tensor(2, dtype=np.int64)], - ), - ( - set_test_value(at.vector(), rng.random(size=(2,)).astype(config.floatX)), - at.as_tensor([set_test_value(at.lscalar(), np.array(v)) for v in [3, 2]]), - ), - ( - set_test_value(at.vector(), rng.random(size=(2,)).astype(config.floatX)), - [at.as_tensor(3, dtype=np.int8), at.as_tensor(2, dtype=np.int64)], - ), - ], -) -def test_BroadcastTo(x, shape): - g = extra_ops.BroadcastTo()(x, shape) - g_fg = FunctionGraph(outputs=[g]) - - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, (SharedVariable, Constant)) - ], - ) - - -@pytest.mark.parametrize( - "x, y, exc", - [ - ( - set_test_value(at.matrix(), rng.random(size=(3, 2)).astype(config.floatX)), - set_test_value(at.vector(), rng.random(size=(2,)).astype(config.floatX)), - None, - ), - ( - set_test_value( - at.matrix(dtype="float64"), rng.random(size=(3, 2)).astype("float64") - ), - set_test_value( - at.vector(dtype="float32"), rng.random(size=(2,)).astype("float32") - ), - None, - ), - ( - set_test_value(at.lmatrix(), rng.poisson(size=(3, 2))), - set_test_value(at.fvector(), rng.random(size=(2,)).astype("float32")), - None, - ), - ( - set_test_value(at.lvector(), rng.random(size=(2,)).astype(np.int64)), - set_test_value(at.lvector(), rng.random(size=(2,)).astype(np.int64)), - None, - ), - ], -) -def test_Dot(x, y, exc): - g = aem.Dot()(x, y) - g_fg = FunctionGraph(outputs=[g]) - - cm = contextlib.suppress() if exc is None else pytest.warns(exc) - with cm: - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, (SharedVariable, Constant)) - ], - ) - - -@pytest.mark.parametrize( - "dy, sm, axis, exc", - [ - ( - set_test_value( - at.matrix(), np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX) - ), - set_test_value(at.matrix(), rng.random(size=(2, 3)).astype(config.floatX)), - None, - None, - ), - ( - set_test_value( - at.matrix(), np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX) - ), - set_test_value(at.matrix(), rng.random(size=(2, 3)).astype(config.floatX)), - 0, - None, - ), - ( - set_test_value( - at.matrix(), np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX) - ), - set_test_value(at.matrix(), rng.random(size=(2, 3)).astype(config.floatX)), - 1, - None, - ), - ], -) -def test_SoftmaxGrad(dy, sm, axis, exc): - g = nnetb.SoftmaxGrad(axis=axis)(dy, sm) - g_fg = FunctionGraph(outputs=[g]) - - cm = contextlib.suppress() if exc is None else pytest.warns(exc) - with cm: - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, (SharedVariable, Constant)) - ], - ) - - -@pytest.mark.parametrize( - "x, axis, exc", - [ - ( - set_test_value(at.vector(), rng.random(size=(2,)).astype(config.floatX)), - None, - None, - ), - ( - set_test_value(at.matrix(), rng.random(size=(2, 3)).astype(config.floatX)), - None, - None, - ), - ( - set_test_value(at.matrix(), rng.random(size=(2, 3)).astype(config.floatX)), - 0, - None, - ), - ], -) -def test_Softmax(x, axis, exc): - g = nnetb.Softmax(axis=axis)(x) - g_fg = FunctionGraph(outputs=[g]) - - cm = contextlib.suppress() if exc is None else pytest.warns(exc) - with cm: - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, (SharedVariable, Constant)) - ], - ) - - -@pytest.mark.parametrize( - "x, axis, exc", - [ - ( - set_test_value(at.vector(), rng.random(size=(2,)).astype(config.floatX)), - None, - None, - ), - ( - set_test_value(at.matrix(), rng.random(size=(2, 3)).astype(config.floatX)), - 0, - None, - ), - ( - set_test_value(at.matrix(), rng.random(size=(2, 3)).astype(config.floatX)), - 1, - None, - ), - ], -) -def test_LogSoftmax(x, axis, exc): - g = nnetb.LogSoftmax(axis=axis)(x) - g_fg = FunctionGraph(outputs=[g]) - - cm = contextlib.suppress() if exc is None else pytest.warns(exc) - with cm: - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, (SharedVariable, Constant)) - ], - ) - - -@pytest.mark.parametrize( - "x, exc", - [ - ( - set_test_value(aes.float64(), np.array(0.0, dtype="float64")), - None, - ), - ( - set_test_value(aes.float64(), np.array(-32.0, dtype="float64")), - None, - ), - ( - set_test_value(aes.float64(), np.array(-40.0, dtype="float64")), - None, - ), - ( - set_test_value(aes.float64(), np.array(32.0, dtype="float64")), - None, - ), - ( - set_test_value(aes.float64(), np.array(40.0, dtype="float64")), - None, - ), - ( - set_test_value(aes.int64(), np.array(32, dtype="int64")), - None, - ), - ], -) -def test_Softplus(x, exc): - g = aesm.Softplus(aes.upgrade_to_float)(x) - g_fg = FunctionGraph(outputs=[g]) - - cm = contextlib.suppress() if exc is None else pytest.warns(exc) - with cm: - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, (SharedVariable, Constant)) - ], - ) - - -@pytest.mark.parametrize( - "x, axes, exc", - [ - ( - set_test_value(at.dscalar(), np.array(0.0, dtype="float64")), - [], - None, - ), - ( - set_test_value(at.dvector(), rng.random(size=(3,)).astype("float64")), - [0], - None, - ), - ( - set_test_value(at.dmatrix(), rng.random(size=(3, 2)).astype("float64")), - [0], - None, - ), - ( - set_test_value(at.dmatrix(), rng.random(size=(3, 2)).astype("float64")), - [0, 1], - None, - ), - ], -) -def test_MaxAndArgmax(x, axes, exc): - g = aem.MaxAndArgmax(axes)(x) - - if isinstance(g, list): - g_fg = FunctionGraph(outputs=g) - else: - g_fg = FunctionGraph(outputs=[g]) - - cm = contextlib.suppress() if exc is None else pytest.warns(exc) - with cm: - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, (SharedVariable, Constant)) - ], - ) - - -@pytest.mark.parametrize( - "x, lower, exc", - [ - ( - set_test_value( - at.dmatrix(), - (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), - ), - True, - None, - ), - ( - set_test_value( - at.lmatrix(), - (lambda x: x.T.dot(x))( - rng.integers(1, 10, size=(3, 3)).astype("int64") - ), - ), - True, - None, - ), - ( - set_test_value( - at.dmatrix(), - (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), - ), - False, - UserWarning, - ), - ], -) -def test_Cholesky(x, lower, exc): - g = slinalg.Cholesky(lower)(x) - - if isinstance(g, list): - g_fg = FunctionGraph(outputs=g) - else: - g_fg = FunctionGraph(outputs=[g]) - - cm = contextlib.suppress() if exc is None else pytest.warns(exc) - with cm: - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, (SharedVariable, Constant)) - ], - ) - - -@pytest.mark.parametrize( - "A, x, lower, exc", - [ - ( - set_test_value( - at.dmatrix(), - (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), - ), - set_test_value(at.dvector(), rng.random(size=(3,)).astype("float64")), - "gen", - None, - ), - ( - set_test_value( - at.lmatrix(), - (lambda x: x.T.dot(x))( - rng.integers(1, 10, size=(3, 3)).astype("int64") - ), - ), - set_test_value(at.dvector(), rng.random(size=(3,)).astype("float64")), - "gen", - None, - ), - ], -) -def test_Solve(A, x, lower, exc): - g = slinalg.Solve(lower)(A, x) - - if isinstance(g, list): - g_fg = FunctionGraph(outputs=g) - else: - g_fg = FunctionGraph(outputs=[g]) - - cm = contextlib.suppress() if exc is None else pytest.warns(exc) - with cm: - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, (SharedVariable, Constant)) - ], - ) - - -@pytest.mark.parametrize( - "A, x, lower, exc", - [ - ( - set_test_value( - at.dmatrix(), - (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), - ), - set_test_value(at.dvector(), rng.random(size=(3,)).astype("float64")), - "sym", - UserWarning, - ), - ], -) -def test_SolveTriangular(A, x, lower, exc): - g = slinalg.SolveTriangular(lower)(A, x) - - if isinstance(g, list): - g_fg = FunctionGraph(outputs=g) - else: - g_fg = FunctionGraph(outputs=[g]) - - cm = contextlib.suppress() if exc is None else pytest.warns(exc) - with cm: - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, (SharedVariable, Constant)) - ], - ) - - -@pytest.mark.parametrize( - "x, exc", - [ - ( - set_test_value( - at.dmatrix(), - (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), - ), - None, - ), - ( - set_test_value( - at.lmatrix(), - (lambda x: x.T.dot(x))(rng.poisson(size=(3, 3)).astype("int64")), - ), - None, - ), - ], -) -def test_Det(x, exc): - g = nlinalg.Det()(x) - g_fg = FunctionGraph(outputs=[g]) - - cm = contextlib.suppress() if exc is None else pytest.warns(exc) - with cm: - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, (SharedVariable, Constant)) - ], - ) - - -# We were seeing some weird results in CI where the following two almost -# sign-swapped results were being return from Numba and Python, respectively. -# The issue might be related to https://github.com/numba/numba/issues/4519. -# Regardless, I was not able to reproduce anything like it locally after -# extensive testing. -x = np.array( - [ - [-0.60407637, -0.71177603, -0.35842241], - [-0.07735968, 0.50000561, -0.86256007], - [-0.7931628, 0.49332471, 0.35710434], - ], - dtype=np.float64, -) - -y = np.array( - [ - [0.60407637, 0.71177603, -0.35842241], - [0.07735968, -0.50000561, -0.86256007], - [0.7931628, -0.49332471, 0.35710434], - ], - dtype=np.float64, -) - - -@pytest.mark.parametrize( - "x, exc", - [ - ( - set_test_value( - at.dmatrix(), - (lambda x: x.T.dot(x))(x), - ), - None, - ), - ( - set_test_value( - at.dmatrix(), - (lambda x: x.T.dot(x))(y), - ), - None, - ), - ( - set_test_value( - at.lmatrix(), - (lambda x: x.T.dot(x))( - rng.integers(1, 10, size=(3, 3)).astype("int64") - ), - ), - None, - ), - ], -) -def test_Eig(x, exc): - g = nlinalg.Eig()(x) - - if isinstance(g, list): - g_fg = FunctionGraph(outputs=g) - else: - g_fg = FunctionGraph(outputs=[g]) - - cm = contextlib.suppress() if exc is None else pytest.warns(exc) - with cm: - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, (SharedVariable, Constant)) - ], - ) - - -@pytest.mark.parametrize( - "x, uplo, exc", - [ - ( - set_test_value( - at.dmatrix(), - (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), - ), - "L", - None, - ), - ( - set_test_value( - at.lmatrix(), - (lambda x: x.T.dot(x))( - rng.integers(1, 10, size=(3, 3)).astype("int64") - ), - ), - "U", - UserWarning, - ), - ], -) -def test_Eigh(x, uplo, exc): - g = nlinalg.Eigh(uplo)(x) - - if isinstance(g, list): - g_fg = FunctionGraph(outputs=g) - else: - g_fg = FunctionGraph(outputs=[g]) - - cm = contextlib.suppress() if exc is None else pytest.warns(exc) - with cm: - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, (SharedVariable, Constant)) - ], - ) - - -@pytest.mark.parametrize( - "op, x, exc, op_args", - [ - ( - nlinalg.MatrixInverse, - set_test_value( - at.dmatrix(), - (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), - ), - None, - (), - ), - ( - nlinalg.MatrixInverse, - set_test_value( - at.lmatrix(), - (lambda x: x.T.dot(x))( - rng.integers(1, 10, size=(3, 3)).astype("int64") - ), - ), - None, - (), - ), - ( - nlinalg.Inv, - set_test_value( - at.dmatrix(), - (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), - ), - None, - (), - ), - ( - nlinalg.Inv, - set_test_value( - at.lmatrix(), - (lambda x: x.T.dot(x))( - rng.integers(1, 10, size=(3, 3)).astype("int64") - ), - ), - None, - (), - ), - ( - nlinalg.MatrixPinv, - set_test_value( - at.dmatrix(), - (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), - ), - None, - (True,), - ), - ( - nlinalg.MatrixPinv, - set_test_value( - at.lmatrix(), - (lambda x: x.T.dot(x))( - rng.integers(1, 10, size=(3, 3)).astype("int64") - ), - ), - None, - (False,), - ), - ], -) -def test_matrix_inverses(op, x, exc, op_args): - g = op(*op_args)(x) - g_fg = FunctionGraph(outputs=[g]) - - cm = contextlib.suppress() if exc is None else pytest.warns(exc) - with cm: - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, (SharedVariable, Constant)) - ], - ) - - -@pytest.mark.parametrize( - "x, mode, exc", - [ - ( - set_test_value( - at.dmatrix(), - (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), - ), - "reduced", - None, - ), - ( - set_test_value( - at.dmatrix(), - (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), - ), - "r", - None, - ), - ( - set_test_value( - at.lmatrix(), - (lambda x: x.T.dot(x))( - rng.integers(1, 10, size=(3, 3)).astype("int64") - ), - ), - "reduced", - None, - ), - ( - set_test_value( - at.lmatrix(), - (lambda x: x.T.dot(x))( - rng.integers(1, 10, size=(3, 3)).astype("int64") - ), - ), - "complete", - UserWarning, - ), - ], -) -def test_QRFull(x, mode, exc): - g = nlinalg.QRFull(mode)(x) - - if isinstance(g, list): - g_fg = FunctionGraph(outputs=g) - else: - g_fg = FunctionGraph(outputs=[g]) - - cm = contextlib.suppress() if exc is None else pytest.warns(exc) - with cm: - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, (SharedVariable, Constant)) - ], - ) - - -@pytest.mark.parametrize( - "x, full_matrices, compute_uv, exc", - [ - ( - set_test_value( - at.dmatrix(), - (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), - ), - True, - True, - None, - ), - ( - set_test_value( - at.dmatrix(), - (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), - ), - False, - True, - None, - ), - ( - set_test_value( - at.lmatrix(), - (lambda x: x.T.dot(x))( - rng.integers(1, 10, size=(3, 3)).astype("int64") - ), - ), - True, - True, - None, - ), - ( - set_test_value( - at.lmatrix(), - (lambda x: x.T.dot(x))( - rng.integers(1, 10, size=(3, 3)).astype("int64") - ), - ), - True, - False, - UserWarning, - ), - ], -) -def test_SVD(x, full_matrices, compute_uv, exc): - g = nlinalg.SVD(full_matrices, compute_uv)(x) - - if isinstance(g, list): - g_fg = FunctionGraph(outputs=g) - else: - g_fg = FunctionGraph(outputs=[g]) - - cm = contextlib.suppress() if exc is None else pytest.warns(exc) - with cm: - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, (SharedVariable, Constant)) - ], - ) - - -@pytest.mark.parametrize( - "x, y, exc", - [ - ( - set_test_value( - at.dmatrix(), - rng.random(size=(3, 3)).astype("float64"), - ), - set_test_value( - at.dmatrix(), - rng.random(size=(3, 3)).astype("float64"), - ), - None, - ), - ( - set_test_value( - at.dmatrix(), - rng.random(size=(3, 3)).astype("float64"), - ), - set_test_value( - at.lmatrix(), - rng.poisson(size=(3, 3)).astype("int64"), - ), - None, - ), - ], -) -def test_BatchedDot(x, y, exc): - g = blas.BatchedDot()(x, y) - - if isinstance(g, list): - g_fg = FunctionGraph(outputs=g) - else: - g_fg = FunctionGraph(outputs=[g]) - - cm = contextlib.suppress() if exc is None else pytest.warns(exc) - with cm: - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, (SharedVariable, Constant)) - ], - ) - - -def test_shared(): - a = shared(np.array([1, 2, 3], dtype=config.floatX)) - - aesara_numba_fn = function([], a, mode="NUMBA") - numba_res = aesara_numba_fn() - - np.testing.assert_allclose(numba_res, a.get_value()) - - aesara_numba_fn = function([], a * 2, mode="NUMBA") - numba_res = aesara_numba_fn() - - np.testing.assert_allclose(numba_res, a.get_value() * 2) - - # Changed the shared value and make sure that the Numba-compiled function - # also changes. - new_a_value = np.array([3, 4, 5], dtype=config.floatX) - a.set_value(new_a_value) - - numba_res = aesara_numba_fn() - np.testing.assert_allclose(numba_res, new_a_value * 2) - - -@pytest.mark.parametrize( - "rv_op, dist_args, size", - [ - ( - aer.normal, - [ - set_test_value( - at.dvector(), - np.array([1.0, 2.0], dtype=np.float64), - ), - set_test_value( - at.dscalar(), - np.array(1.0, dtype=np.float64), - ), - ], - at.as_tensor([3, 2]), - ), - ( - aer.uniform, - [ - set_test_value( - at.dvector(), - np.array([1.0, 2.0], dtype=np.float64), - ), - set_test_value( - at.dscalar(), - np.array(1.0, dtype=np.float64), - ), - ], - at.as_tensor([3, 2]), - ), - ( - aer.triangular, - [ - set_test_value( - at.dscalar(), - np.array(-5.0, dtype=np.float64), - ), - set_test_value( - at.dscalar(), - np.array(1.0, dtype=np.float64), - ), - set_test_value( - at.dscalar(), - np.array(5.0, dtype=np.float64), - ), - ], - at.as_tensor([3, 2]), - ), - ( - aer.lognormal, - [ - set_test_value( - at.dvector(), - np.array([1.0, 2.0], dtype=np.float64), - ), - set_test_value( - at.dscalar(), - np.array(1.0, dtype=np.float64), - ), - ], - at.as_tensor([3, 2]), - ), - pytest.param( - aer.pareto, - [ - set_test_value( - at.dvector(), - np.array([1.0, 2.0], dtype=np.float64), - ), - ], - at.as_tensor([3, 2]), - marks=pytest.mark.xfail(reason="Not implemented"), - ), - ( - aer.exponential, - [ - set_test_value( - at.dvector(), - np.array([1.0, 2.0], dtype=np.float64), - ), - ], - at.as_tensor([3, 2]), - ), - ( - aer.weibull, - [ - set_test_value( - at.dvector(), - np.array([1.0, 2.0], dtype=np.float64), - ), - ], - at.as_tensor([3, 2]), - ), - ( - aer.logistic, - [ - set_test_value( - at.dvector(), - np.array([1.0, 2.0], dtype=np.float64), - ), - set_test_value( - at.dscalar(), - np.array(1.0, dtype=np.float64), - ), - ], - at.as_tensor([3, 2]), - ), - ( - aer.geometric, - [ - set_test_value( - at.dvector(), - np.array([0.3, 0.4], dtype=np.float64), - ), - ], - at.as_tensor([3, 2]), - ), - ( - aer.hypergeometric, - [ - set_test_value( - at.lscalar(), - np.array(7, dtype=np.int64), - ), - set_test_value( - at.lscalar(), - np.array(8, dtype=np.int64), - ), - set_test_value( - at.lscalar(), - np.array(15, dtype=np.int64), - ), - ], - at.as_tensor([3, 2]), - ), - ( - aer.wald, - [ - set_test_value( - at.dvector(), - np.array([1.0, 2.0], dtype=np.float64), - ), - set_test_value( - at.dscalar(), - np.array(1.0, dtype=np.float64), - ), - ], - at.as_tensor([3, 2]), - ), - ( - aer.laplace, - [ - set_test_value( - at.dvector(), - np.array([1.0, 2.0], dtype=np.float64), - ), - set_test_value( - at.dscalar(), - np.array(1.0, dtype=np.float64), - ), - ], - at.as_tensor([3, 2]), - ), - ( - aer.binomial, - [ - set_test_value( - at.lvector(), - np.array([1, 2], dtype=np.int64), - ), - set_test_value( - at.dscalar(), - np.array(0.9, dtype=np.float64), - ), - ], - at.as_tensor([3, 2]), - ), - ( - aer.normal, - [ - set_test_value( - at.lvector(), - np.array([1, 2], dtype=np.int64), - ), - set_test_value( - at.dscalar(), - np.array(1.0, dtype=np.float64), - ), - ], - at.as_tensor(tuple(set_test_value(at.lscalar(), v) for v in [3, 2])), - ), - ( - aer.poisson, - [ - set_test_value( - at.dvector(), - np.array([1.0, 2.0], dtype=np.float64), - ), - ], - None, - ), - ( - aer.halfnormal, - [ - set_test_value( - at.lvector(), - np.array([1, 2], dtype=np.int64), - ), - set_test_value( - at.dscalar(), - np.array(1.0, dtype=np.float64), - ), - ], - None, - ), - ( - aer.bernoulli, - [ - set_test_value( - at.dvector(), - np.array([0.1, 0.9], dtype=np.float64), - ), - ], - None, - ), - ( - aer.randint, - [ - set_test_value( - at.lscalar(), - np.array(0, dtype=np.int64), - ), - set_test_value( - at.lscalar(), - np.array(5, dtype=np.int64), - ), - ], - at.as_tensor([3, 2]), - ), - pytest.param( - aer.multivariate_normal, - [ - set_test_value( - at.dmatrix(), - np.array([[1, 2], [3, 4]], dtype=np.float64), - ), - set_test_value( - at.tensor("float64", [True, False, False]), - np.eye(2)[None, ...], - ), - ], - at.as_tensor(tuple(set_test_value(at.lscalar(), v) for v in [4, 3, 2])), - marks=pytest.mark.xfail(reason="Not implemented"), - ), - ], - ids=str, -) -def test_aligned_RandomVariable(rv_op, dist_args, size): - """Tests for Numba samplers that are one-to-one with Aesara's/NumPy's samplers.""" - rng = shared(np.random.RandomState(29402)) - g = rv_op(*dist_args, size=size, rng=rng) - g_fg = FunctionGraph(outputs=[g]) - - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, (SharedVariable, Constant)) - ], - ) - - -@pytest.mark.parametrize( - "rv_op, dist_args, base_size, cdf_name, params_conv", - [ - ( - aer.beta, - [ - set_test_value( - at.dvector(), - np.array([1.0, 2.0], dtype=np.float64), - ), - set_test_value( - at.dscalar(), - np.array(1.0, dtype=np.float64), - ), - ], - (2,), - "beta", - lambda *args: args, - ), - ( - aer.gamma, - [ - set_test_value( - at.dvector(), - np.array([1.0, 2.0], dtype=np.float64), - ), - set_test_value( - at.dscalar(), - np.array(1.0, dtype=np.float64), - ), - ], - (2,), - "gamma", - lambda a, b: (a, 0.0, b), - ), - ( - aer.cauchy, - [ - set_test_value( - at.dvector(), - np.array([1.0, 2.0], dtype=np.float64), - ), - set_test_value( - at.dscalar(), - np.array(1.0, dtype=np.float64), - ), - ], - (2,), - "cauchy", - lambda *args: args, - ), - ( - aer.chisquare, - [ - set_test_value( - at.dvector(), - np.array([1.0, 2.0], dtype=np.float64), - ) - ], - (2,), - "chi2", - lambda *args: args, - ), - ( - aer.gumbel, - [ - set_test_value( - at.dvector(), - np.array([1.0, 2.0], dtype=np.float64), - ), - set_test_value( - at.dscalar(), - np.array(1.0, dtype=np.float64), - ), - ], - (2,), - "gumbel_r", - lambda *args: args, - ), - ( - aer.negative_binomial, - [ - set_test_value( - at.lvector(), - np.array([100, 200], dtype=np.int64), - ), - set_test_value( - at.dscalar(), - np.array(0.09, dtype=np.float64), - ), - ], - (2,), - "nbinom", - lambda *args: args, - ), - pytest.param( - aer.vonmises, - [ - set_test_value( - at.dvector(), - np.array([-0.5, 0.5], dtype=np.float64), - ), - set_test_value( - at.dscalar(), - np.array(1.0, dtype=np.float64), - ), - ], - (2,), - "vonmises_line", - lambda mu, kappa: (kappa, mu), - marks=pytest.mark.xfail( - reason=( - "Numba's parameterization of `vonmises` does not match NumPy's." - "See https://github.com/numba/numba/issues/7886" - ) - ), - ), - ], -) -def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_conv): - """Tests for Numba samplers that are not one-to-one with Aesara's/NumPy's samplers.""" - rng = shared(np.random.RandomState(29402)) - g = rv_op(*dist_args, size=(2000,) + base_size, rng=rng) - g_fn = function(dist_args, g, mode=numba_mode) - samples = g_fn( - *[ - i.tag.test_value - for i in g_fn.maker.fgraph.inputs - if not isinstance(i, (SharedVariable, Constant)) - ] - ) - - bcast_dist_args = np.broadcast_arrays(*[i.tag.test_value for i in dist_args]) - - for idx in np.ndindex(*base_size): - cdf_params = params_conv(*tuple(arg[idx] for arg in bcast_dist_args)) - test_res = stats.cramervonmises( - samples[(Ellipsis,) + idx], cdf_name, args=cdf_params - ) - assert test_res.pvalue > 0.1 - - -@pytest.mark.parametrize( - "dist_args, size, cm", - [ - pytest.param( - [ - set_test_value( - at.dvector(), - np.array([100000, 1, 1], dtype=np.float64), - ), - ], - None, - contextlib.suppress(), - ), - pytest.param( - [ - set_test_value( - at.dmatrix(), - np.array( - [[100000, 1, 1], [1, 100000, 1], [1, 1, 100000]], - dtype=np.float64, - ), - ), - ], - (10, 3), - contextlib.suppress(), - ), - pytest.param( - [ - set_test_value( - at.dmatrix(), - np.array( - [[100000, 1, 1], [1, 100000, 1], [1, 1, 100000]], - dtype=np.float64, - ), - ), - ], - (10, 4), - pytest.raises(ValueError, match="Parameters shape.*"), - ), - ], -) -def test_CategoricalRV(dist_args, size, cm): - rng = shared(np.random.RandomState(29402)) - g = aer.categorical(*dist_args, size=size, rng=rng) - g_fg = FunctionGraph(outputs=[g]) - - with cm: - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, (SharedVariable, Constant)) - ], - ) - - -@pytest.mark.parametrize( - "a, size, cm", - [ - pytest.param( - set_test_value( - at.dvector(), - np.array([100000, 1, 1], dtype=np.float64), - ), - None, - contextlib.suppress(), - ), - pytest.param( - set_test_value( - at.dmatrix(), - np.array( - [[100000, 1, 1], [1, 100000, 1], [1, 1, 100000]], - dtype=np.float64, - ), - ), - (10, 3), - contextlib.suppress(), - ), - pytest.param( - set_test_value( - at.dmatrix(), - np.array( - [[100000, 1, 1], [1, 100000, 1], [1, 1, 100000]], - dtype=np.float64, - ), - ), - (10, 4), - pytest.raises(ValueError, match="Parameters shape.*"), - ), - ], -) -def test_DirichletRV(a, size, cm): - rng = shared(np.random.RandomState(29402)) - g = aer.dirichlet(a, size=size, rng=rng) - g_fn = function([a], g, mode=numba_mode) - - with cm: - a_val = a.tag.test_value - - # For coverage purposes only... - eval_python_only([a], FunctionGraph(outputs=[g], clone=False), [a_val]) - - all_samples = [] - for i in range(1000): - samples = g_fn(a_val) - all_samples.append(samples) - - exp_res = a_val / a_val.sum(-1) - res = np.mean(all_samples, axis=tuple(range(0, a_val.ndim - 1))) - assert np.allclose(res, exp_res, atol=1e-4) - - -def test_RandomState_updates(): - rng = shared(np.random.RandomState(1)) - rng_new = shared(np.random.RandomState(2)) - - x = at.random.normal(size=10, rng=rng) - res = function([], x, updates={rng: rng_new}, mode=numba_mode)() - - ref = np.random.RandomState(2).normal(size=10) - assert np.allclose(res, ref) - - -def test_random_Generator(): - rng = shared(np.random.default_rng(29402)) - g = aer.normal(rng=rng) - g_fg = FunctionGraph(outputs=[g]) - - with pytest.raises(TypeError): - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, (SharedVariable, Constant)) - ], - ) - - -def test_scan_multiple_output(): - """Test a scan implementation of a SEIR model. - - SEIR model definition: - S[t+1] = S[t] - B[t] - E[t+1] = E[t] +B[t] - C[t] - I[t+1] = I[t+1] + C[t] - D[t] - - B[t] ~ Binom(S[t], beta) - C[t] ~ Binom(E[t], gamma) - D[t] ~ Binom(I[t], delta) - """ - - def binomln(n, k): - return at.exp(n + 1) - at.exp(k + 1) - at.exp(n - k + 1) - - def binom_log_prob(n, p, value): - return binomln(n, value) + value * at.exp(p) + (n - value) * at.exp(1 - p) - - # sequences - at_C = at.ivector("C_t") - at_D = at.ivector("D_t") - # outputs_info (initial conditions) - st0 = at.lscalar("s_t0") - et0 = at.lscalar("e_t0") - it0 = at.lscalar("i_t0") - logp_c = at.scalar("logp_c") - logp_d = at.scalar("logp_d") - # non_sequences - beta = at.scalar("beta") - gamma = at.scalar("gamma") - delta = at.scalar("delta") - - def seir_one_step(ct0, dt0, st0, et0, it0, logp_c, logp_d, beta, gamma, delta): - bt0 = st0 * beta - bt0 = bt0.astype(st0.dtype) - - logp_c1 = binom_log_prob(et0, gamma, ct0).astype(logp_c.dtype) - logp_d1 = binom_log_prob(it0, delta, dt0).astype(logp_d.dtype) - - st1 = st0 - bt0 - et1 = et0 + bt0 - ct0 - it1 = it0 + ct0 - dt0 - return st1, et1, it1, logp_c1, logp_d1 - - (st, et, it, logp_c_all, logp_d_all), _ = scan( - fn=seir_one_step, - sequences=[at_C, at_D], - outputs_info=[st0, et0, it0, logp_c, logp_d], - non_sequences=[beta, gamma, delta], - ) - st.name = "S_t" - et.name = "E_t" - it.name = "I_t" - logp_c_all.name = "C_t_logp" - logp_d_all.name = "D_t_logp" - - out_fg = FunctionGraph( - [at_C, at_D, st0, et0, it0, logp_c, logp_d, beta, gamma, delta], - [st, et, it, logp_c_all, logp_d_all], - ) - - s0, e0, i0 = 100, 50, 25 - logp_c0 = np.array(0.0, dtype=config.floatX) - logp_d0 = np.array(0.0, dtype=config.floatX) - beta_val, gamma_val, delta_val = [ - np.array(val, dtype=config.floatX) for val in [0.277792, 0.135330, 0.108753] - ] - C = np.array([3, 5, 8, 13, 21, 26, 10, 3], dtype=np.int32) - D = np.array([1, 2, 3, 7, 9, 11, 5, 1], dtype=np.int32) - - test_input_vals = [ - C, - D, - s0, - e0, - i0, - logp_c0, - logp_d0, - beta_val, - gamma_val, - delta_val, - ] - compare_numba_and_py(out_fg, test_input_vals) - - -@config.change_flags(compute_test_value="raise") -def test_scan_tap_output(): - - a_at = at.scalar("a") - a_at.tag.test_value = 10.0 - - b_at = at.arange(11).astype(config.floatX) - b_at.name = "b" - - c_at = at.arange(20, 31, dtype=config.floatX) - c_at.name = "c" - - def input_step_fn(b, b2, c, x_tm1, y_tm1, y_tm3, a): - x_tm1.name = "x_tm1" - y_tm1.name = "y_tm1" - y_tm3.name = "y_tm3" - y_t = (y_tm1 + y_tm3) * a + b + b2 - z_t = y_t * c - x_t = x_tm1 + 1 - x_t.name = "x_t" - y_t.name = "y_t" - return x_t, y_t, at.fill((10,), z_t) - - scan_res, _ = scan( - fn=input_step_fn, - sequences=[ - { - "input": b_at, - "taps": [-1, -2], - }, - { - "input": c_at, - "taps": [-2], - }, - ], - outputs_info=[ - { - "initial": at.as_tensor_variable(0.0, dtype=config.floatX), - "taps": [-1], - }, - { - "initial": at.as_tensor_variable( - np.r_[-1.0, 1.3, 0.0].astype(config.floatX) - ), - "taps": [-1, -3], - }, - None, - ], - non_sequences=[a_at], - n_steps=5, - name="yz_scan", - strict=True, - ) - - out_fg = FunctionGraph([a_at, b_at, c_at], scan_res) - - test_input_vals = [ - np.array(10.0).astype(config.floatX), - np.arange(11, dtype=config.floatX), - np.arange(20, 31, dtype=config.floatX), - ] - compare_numba_and_py(out_fg, test_input_vals) - - -def test_scan_while(): - def power_of_2(previous_power, max_value): - return previous_power * 2, until(previous_power * 2 > max_value) - - max_value = at.scalar() - values, _ = scan( - power_of_2, - outputs_info=at.constant(1.0), - non_sequences=max_value, - n_steps=1024, - ) - - out_fg = FunctionGraph([max_value], [values]) - - test_input_vals = [ - np.array(45).astype(config.floatX), - ] - compare_numba_and_py(out_fg, test_input_vals) - - -@pytest.mark.parametrize( - "inputs, cond_fn, true_vals, false_vals", - [ - ([], lambda: np.array(True), np.r_[1, 2, 3], np.r_[-1, -2, -3]), - ( - [set_test_value(at.dscalar(), np.array(0.2, dtype=np.float64))], - lambda x: x < 0.5, - np.r_[1, 2, 3], - np.r_[-1, -2, -3], - ), - ( - [ - set_test_value(at.dscalar(), np.array(0.3, dtype=np.float64)), - set_test_value(at.dscalar(), np.array(0.5, dtype=np.float64)), - ], - lambda x, y: x > y, - x, - y, - ), - ( - [ - set_test_value(at.dvector(), np.array([0.3, 0.1], dtype=np.float64)), - set_test_value(at.dvector(), np.array([0.5, 0.9], dtype=np.float64)), - ], - lambda x, y: at.all(x > y), - x, - y, - ), - ( - [ - set_test_value(at.dvector(), np.array([0.3, 0.1], dtype=np.float64)), - set_test_value(at.dvector(), np.array([0.5, 0.9], dtype=np.float64)), - ], - lambda x, y: at.all(x > y), - [x, 2 * x], - [y, 3 * y], - ), - ( - [ - set_test_value(at.dvector(), np.array([0.5, 0.9], dtype=np.float64)), - set_test_value(at.dvector(), np.array([0.3, 0.1], dtype=np.float64)), - ], - lambda x, y: at.all(x > y), - [x, 2 * x], - [y, 3 * y], - ), - ], -) -def test_IfElse(inputs, cond_fn, true_vals, false_vals): - - out = ifelse(cond_fn(*inputs), true_vals, false_vals) - - if not isinstance(out, list): - out = [out] - - out_fg = FunctionGraph(inputs, out) - - compare_numba_and_py(out_fg, [get_test_value(i) for i in out_fg.inputs]) - - -@pytest.mark.xfail(reason="https://github.com/numba/numba/issues/7409") -def test_config_options_parallel(): - x = at.dvector() - - with config.change_flags(numba__vectorize_target="parallel"): - aesara_numba_fn = function([x], x * 2, mode=numba_mode) - numba_mul_fn = aesara_numba_fn.vm.jit_fn.py_func.__globals__["mul"] - assert numba_mul_fn.targetoptions["parallel"] is True - - -def test_config_options_fastmath(): - x = at.dvector() - - with config.change_flags(numba__fastmath=True): - aesara_numba_fn = function([x], x * 2, mode=numba_mode) - numba_mul_fn = aesara_numba_fn.vm.jit_fn.py_func.__globals__["mul"] - assert numba_mul_fn.targetoptions["fastmath"] is True - - -def test_config_options_cached(): - x = at.dvector() - - with config.change_flags(numba__cache=True): - aesara_numba_fn = function([x], x * 2, mode=numba_mode) - numba_mul_fn = aesara_numba_fn.vm.jit_fn.py_func.__globals__["mul"] - assert not isinstance( - numba_mul_fn._dispatcher.cache, numba.core.caching.NullCache - ) - - with config.change_flags(numba__cache=False): - aesara_numba_fn = function([x], x * 2, mode=numba_mode) - numba_mul_fn = aesara_numba_fn.vm.jit_fn.py_func.__globals__["mul"] - assert isinstance(numba_mul_fn._dispatcher.cache, numba.core.caching.NullCache) - - -def test_scalar_return_value_conversion(): - r"""Make sure that we convert \"native\" scalars to `ndarray`\s in the graph outputs.""" - x = at.scalar(name="x") - x_fn = function( - [x], - 2 * x, - mode=numba_mode, - ) - assert isinstance(x_fn(1.0), np.ndarray) diff --git a/tests/sandbox/test_multinomial_wo_replacement.py b/tests/sandbox/test_multinomial_wo_replacement.py index b03e1567f7..a40952c446 100644 --- a/tests/sandbox/test_multinomial_wo_replacement.py +++ b/tests/sandbox/test_multinomial_wo_replacement.py @@ -157,7 +157,7 @@ def test_select_distinct(self): p = fmatrix() n = iscalar() - with pytest.warns(DeprecationWarning): + with pytest.deprecated_call(): m = th_rng.multinomial_wo_replacement(pvals=p, n=n) f = function([p, n], m, allow_input_downcast=True) @@ -181,7 +181,7 @@ def test_fail_select_alot(self): p = fmatrix() n = iscalar() - with pytest.warns(DeprecationWarning): + with pytest.deprecated_call(): m = th_rng.multinomial_wo_replacement(pvals=p, n=n) f = function([p, n], m, allow_input_downcast=True) diff --git a/tests/sandbox/test_rng_mrg.py b/tests/sandbox/test_rng_mrg.py index 5c61fc974a..4d46a924de 100644 --- a/tests/sandbox/test_rng_mrg.py +++ b/tests/sandbox/test_rng_mrg.py @@ -1,3 +1,4 @@ +import contextlib import os import sys import time @@ -332,12 +333,20 @@ def test_broadcastable(): # the sizes of them are implicitly defined with "pvals" argument. if distribution in [R.multinomial, R.multinomial_wo_replacement]: # check when all dimensions are constant - uu = distribution(pvals=pvals_1) - assert uu.broadcastable == (False, True) + context_mgr = ( + pytest.deprecated_call() + if distribution == R.multinomial_wo_replacement + else contextlib.suppress() + ) + + with context_mgr: + uu = distribution(pvals=pvals_1) + assert uu.broadcastable == (False, True) # check when some dimensions are aesara variables - uu = distribution(pvals=pvals_2) - assert uu.broadcastable == (False, True) + with context_mgr: + uu = distribution(pvals=pvals_2) + assert uu.broadcastable == (False, True) else: # check when all dimensions are constant uu = distribution(size=size1) @@ -1109,9 +1118,10 @@ def basic_target_parameter_test(x): basic_target_parameter_test( srng.choice(p=pvals.astype("float32"), replace=False, target="cpu") ) - basic_target_parameter_test( - srng.multinomial_wo_replacement(pvals=pvals.astype("float32"), target="cpu") - ) + with pytest.deprecated_call(): + basic_target_parameter_test( + srng.multinomial_wo_replacement(pvals=pvals.astype("float32"), target="cpu") + ) @config.change_flags(compute_test_value="off") diff --git a/tests/scalar/test_basic.py b/tests/scalar/test_basic.py index 255fb5cbaa..8609d9e843 100644 --- a/tests/scalar/test_basic.py +++ b/tests/scalar/test_basic.py @@ -488,3 +488,17 @@ def test_mean(mode): z = mean() z_fn = aesara.function([], z, mode=mode) assert z_fn() == 0 + + +def test_shape(): + a = float32("a") + assert isinstance(a.type, ScalarType) + assert a.shape.type.ndim == 1 + assert a.shape.type.shape == (0,) + assert a.shape.type.dtype == "int64" + + b = constant(2, name="b") + assert isinstance(b.type, ScalarType) + assert b.shape.type.ndim == 1 + assert b.shape.type.shape == (0,) + assert b.shape.type.dtype == "int64" diff --git a/tests/scan/test_basic.py b/tests/scan/test_basic.py index e0bdb9e7d7..a378c2f932 100644 --- a/tests/scan/test_basic.py +++ b/tests/scan/test_basic.py @@ -32,7 +32,7 @@ from aesara.graph.basic import Apply, ancestors, equal_computations from aesara.graph.fg import FunctionGraph from aesara.graph.op import Op -from aesara.graph.opt import MergeOptimizer +from aesara.graph.rewriting.basic import MergeOptimizer from aesara.graph.utils import MissingInputError from aesara.misc.safe_asarray import _asarray from aesara.raise_op import assert_op @@ -824,7 +824,7 @@ def test_can_merge(self): assert scan_c is not scan_a g = FunctionGraph([x, y, c], [2 * scan_a, 2 * scan_b, 2 * scan_c], clone=False) - MergeOptimizer().optimize(g) + MergeOptimizer().rewrite(g) scan_a_out, scan_b_out, scan_c_out = g.outputs assert scan_a_out is scan_b_out diff --git a/tests/scan/test_printing.py b/tests/scan/test_printing.py index 4046555563..8ff4175147 100644 --- a/tests/scan/test_printing.py +++ b/tests/scan/test_printing.py @@ -618,7 +618,7 @@ def no_shared_fn(n, x_tm1, M): forall_inplace,cpu,scan_fn} [id A] (outer_out_sit_sot-0) >Elemwise{Composite{Switch(LT(i0, i1), i2, i0)}} [id I] (inner_out_sit_sot-0) > |TensorConstant{0} [id J] - > |Subtensor{int64, int64, int64} [id K] + > |Subtensor{int64, int64, uint8} [id K] > | |*2- [id L] -> [id H] (inner_in_non_seqs-0) > | |ScalarFromTensor [id M] > | | |*0- [id N] -> [id C] (inner_in_seqs-0) diff --git a/tests/scan/test_opt.py b/tests/scan/test_rewriting.py similarity index 99% rename from tests/scan/test_opt.py rename to tests/scan/test_rewriting.py index 40ac651bf2..eacde13a66 100644 --- a/tests/scan/test_opt.py +++ b/tests/scan/test_rewriting.py @@ -12,7 +12,7 @@ from aesara.graph.basic import clone_replace, equal_computations from aesara.graph.fg import FunctionGraph from aesara.scan.op import Scan -from aesara.scan.opt import ScanInplaceOptimizer, ScanMerge +from aesara.scan.rewriting import ScanInplaceOptimizer, ScanMerge from aesara.scan.utils import until from aesara.tensor.blas import Dot22 from aesara.tensor.elemwise import Elemwise diff --git a/tests/sparse/test_basic.py b/tests/sparse/test_basic.py index 22923c9403..a3fd87ec47 100644 --- a/tests/sparse/test_basic.py +++ b/tests/sparse/test_basic.py @@ -84,7 +84,12 @@ _is_sparse_variable, _mtypes, ) -from aesara.sparse.opt import CSMGradC, StructuredDotCSC, UsmmCscDense +from aesara.sparse.rewriting import ( + AddSD_ccode, + CSMGradC, + StructuredDotCSC, + UsmmCscDense, +) from aesara.tensor.basic import MakeVector from aesara.tensor.elemwise import DimShuffle, Elemwise from aesara.tensor.math import sum as at_sum @@ -491,7 +496,7 @@ def test_add_sd(self): sp.sparse.csr_matrix(random_lil((10, 40), config.floatX, 3)), np.random.standard_normal((10, 40)).astype(config.floatX), ], - (AddSD, sparse.opt.AddSD_ccode), + (AddSD, AddSD_ccode), ) def test_mul_ss(self): diff --git a/tests/sparse/test_opt.py b/tests/sparse/test_rewriting.py similarity index 95% rename from tests/sparse/test_opt.py rename to tests/sparse/test_rewriting.py index 08525e2484..4cc29894bb 100644 --- a/tests/sparse/test_opt.py +++ b/tests/sparse/test_rewriting.py @@ -1,14 +1,12 @@ -import pytest - - -sp = pytest.importorskip("scipy", minversion="0.7.0") - import numpy as np +import pytest +import scipy as sp import aesara from aesara import sparse from aesara.compile.mode import Mode, get_default_mode from aesara.configdefaults import config +from aesara.sparse.rewriting import SamplingDotCSR, sd_csc from aesara.tensor.basic import as_tensor_variable from aesara.tensor.math import sum as at_sum from aesara.tensor.type import ivector, matrix, vector @@ -38,7 +36,7 @@ def test_local_csm_properties_csm(): f(v.data, v.indices, v.indptr, v.shape) -@pytest.mark.skip(reason="Opt disabled as it don't support unsorted indices") +@pytest.mark.skip(reason="Rewrite disabled as it don't support unsorted indices") @pytest.mark.skipif( not aesara.config.cxx, reason="G++ not available, so we need to skip this test." ) @@ -143,7 +141,7 @@ def test_local_sampling_dot_csr(): # SamplingDotCSR's C implementation needs blas, so it should not # be inserted assert not any( - isinstance(node.op, sparse.opt.SamplingDotCSR) + isinstance(node.op, SamplingDotCSR) for node in f.maker.fgraph.toposort() ) @@ -174,6 +172,6 @@ def test_sd_csc(): nrows = as_tensor_variable(np.int32(A.shape[0])) b = as_tensor_variable(b) - res = aesara.sparse.opt.sd_csc(a_val, a_ind, a_ptr, nrows, b).eval() + res = sd_csc(a_val, a_ind, a_ptr, nrows, b).eval() utt.assert_allclose(res, target) diff --git a/tests/sparse/test_type.py b/tests/sparse/test_type.py index eb6ce331ce..5843e9c938 100644 --- a/tests/sparse/test_type.py +++ b/tests/sparse/test_type.py @@ -1,21 +1,52 @@ import pytest +import scipy as sp from aesara.sparse import matrix as sp_matrix from aesara.sparse.type import SparseTensorType from aesara.tensor import dmatrix -def test_clone(): - st = SparseTensorType("csr", "float64") +def test_SparseTensorType_constructor(): + st = SparseTensorType("csc", "float64") + assert st.format == "csc" + assert st.shape == (None, None) + + st = SparseTensorType("bsr", "float64", shape=(None, 1)) + assert st.format == "bsr" + assert st.shape == (None, 1) + + with pytest.raises(ValueError): + SparseTensorType("blah", "float64") + + +def test_SparseTensorType_clone(): + st = SparseTensorType("csr", "float64", shape=(3, None)) assert st == st.clone() + st_clone = st.clone(format="csc") + assert st_clone.format == "csc" + assert st_clone.dtype == st.dtype + assert st_clone.shape == st.shape + + st_clone = st.clone(shape=(2, 1)) + assert st_clone.format == st.format + assert st_clone.dtype == st.dtype + assert st_clone.shape == (2, 1) -def test_Sparse_convert_variable(): + +def test_SparseTensorType_convert_variable(): x = dmatrix(name="x") y = sp_matrix("csc", dtype="float64", name="y") z = sp_matrix("csr", dtype="float64", name="z") assert y.type.convert_variable(z) is None + assert z.type.convert_variable(y) is None + + res = y.type.convert_variable(x) + assert res.type == y.type + + res = z.type.convert_variable(x) + assert res.type == z.type # TODO FIXME: This is a questionable result, because `x.type` is associated # with a dense `Type`, but, since `TensorType` is a base class of `Sparse`, @@ -23,6 +54,30 @@ def test_Sparse_convert_variable(): # want to do that. assert x.type.convert_variable(y) is y - # TODO FIXME: We should be able to do this. - with pytest.raises(NotImplementedError): - y.type.convert_variable(x) + +def test_SparseTensorType_filter(): + y = sp_matrix("csc", dtype="float64", name="y") + z = sp_matrix("csr", dtype="float64", name="z") + w = sp_matrix("csr", dtype="float32", name="z") + + with pytest.raises(TypeError, match="Expected an array-like"): + y.type.filter(dmatrix()) + + x = sp.sparse.csc_matrix(sp.sparse.eye(5, 3)) + x_res = y.type.filter(x) + assert x is x_res + + x_res = z.type.filter(x) + assert x_res.format == "csr" + + with pytest.raises(TypeError): + x_res = z.type.filter(x, strict=True) + + x_res = w.type.filter(x, allow_downcast=True) + assert x_res.dtype == "float32" + + x_res = z.type.filter(x.astype("float32"), allow_downcast=True) + assert x_res.dtype == "float64" + + with pytest.raises(TypeError, match=".*dtype but got.*"): + w.type.filter(x) diff --git a/tests/tensor/nnet/test_abstract_conv.py b/tests/tensor/nnet/test_abstract_conv.py index b306c8b274..31a3df7aa3 100644 --- a/tests/tensor/nnet/test_abstract_conv.py +++ b/tests/tensor/nnet/test_abstract_conv.py @@ -5,7 +5,7 @@ import aesara.tensor as at from aesara.compile.mode import Mode from aesara.configdefaults import config -from aesara.graph.opt import check_stack_trace +from aesara.graph.rewriting.basic import check_stack_trace from aesara.tensor.nnet import abstract_conv as conv from aesara.tensor.nnet import conv2d_transpose, corr, corr3d from aesara.tensor.nnet.abstract_conv import ( diff --git a/tests/tensor/nnet/test_basic.py b/tests/tensor/nnet/test_basic.py index 278325b488..f96dd8f5ef 100644 --- a/tests/tensor/nnet/test_basic.py +++ b/tests/tensor/nnet/test_basic.py @@ -10,7 +10,7 @@ from aesara.configdefaults import config from aesara.gradient import grad from aesara.graph.fg import FunctionGraph -from aesara.graph.opt import check_stack_trace +from aesara.graph.rewriting.basic import check_stack_trace from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise from aesara.tensor.math import ( Argmax, @@ -175,9 +175,11 @@ def f(a, b): utt.verify_grad(f, [rng.random((3, 4)), rng.random((4))]) def test_broadcast(self): - # test that we don't raise an error during optimization for no good - # reason as softmax_with_bias don't support correctly some/all - # broadcasted inputs pattern + """ + Test that we don't raise an error during rewriting for no good reason + as `softmax_with_bias` don't support correctly some/all broadcasted + inputs pattern. + """ initial_W = np.asarray( [[0.1, 0.1, 0.1], [0.1, 0.1, 0.1], [0.1, 0.1, 0.1]], dtype=config.floatX, @@ -240,7 +242,7 @@ def f(a): rng = np.random.default_rng(utt.fetch_seed()) utt.verify_grad(f, [rng.random((4,))]) - def test_matrix_perform_and_opt(self): + def test_matrix_perform_and_rewrite(self): m = config.mode m = aesara.compile.get_mode(m) m.check_isfinite = False @@ -280,11 +282,12 @@ def test_matrix_perform_and_opt(self): assert not np.any(np.isnan(grad_)) @pytest.mark.parametrize("axis", [None, 0, -1]) - def test_local_logsoftmax_opt(self, axis): - # Test the Logsoftmax substitution - # - # Check that Log(Softmax(x)) is substituted with Logsoftmax(x). Note that - # only the forward pass is checked (i.e., doesn't check the gradient) + def test_local_logsoftmax_rewrite(self, axis): + """Test the `Logsoftmax` substitution. + + Check that ``Log(Softmax(x))`` is substituted with ``Logsoftmax(x)``. Note that + only the forward pass is checked (i.e., doesn't check the gradient) + """ x = matrix("x") sm = softmax(x, axis=axis) @@ -294,18 +297,19 @@ def test_local_logsoftmax_opt(self, axis): assert check_stack_trace(f, ops_to_check=LogSoftmax) @pytest.mark.parametrize("axis", [None, 0, -1]) - def test_local_logsoftmax_grad_opt(self, axis): - # Test the Logsoftmax's grad substitution. - # - # Check that Log(Softmax(x))'s grad is substituted with Logsoftmax(x)'s - # grad and that the new operation does not explode for big inputs. - # Note that only the grad is checked. + def test_local_logsoftmax_grad_rewrite(self, axis): + """Test the `Logsoftmax`'s grad substitution. + + Check that ``Log(Softmax(x))``'s grad is substituted with ``Logsoftmax(x)``'s + grad and that the new operation does not explode for big inputs. + Note that only the grad is checked. + """ m = config.mode m = aesara.compile.get_mode(m) m.check_isfinite = False # some inputs that are large to make the gradient explode in the non - # optimized case + # rewritten case rng = np.random.default_rng(utt.fetch_seed()) a = np.exp(10 * rng.random((5, 10)).astype(config.floatX)) @@ -321,9 +325,10 @@ def myfunc(x): assert check_stack_trace(f, ops_to_check="all") def test_logsoftmax_grad_true_div_elemwise(self): - # Checks that the gradient of an expression similar to a log(softmax) - # but with a different elemwise operation than true_div is not - # optimized. + """ + Checks that the gradient of an expression similar to a ``log(softmax)`` but + with a different elemwise operation than true_div is not rewritten. + """ x = matrix("x") y = log(softmax(x)) @@ -340,7 +345,7 @@ def test_logsoftmax_grad_true_div_elemwise(self): ) fgraph = FunctionGraph([x], [new_g]) - optdb.query(OPT_FAST_RUN).optimize(fgraph) + optdb.query(OPT_FAST_RUN).rewrite(fgraph) assert softmax_grad_legacy in [n.op for n in fgraph.toposort()] @@ -638,7 +643,7 @@ def test_infer_shape(self): CrossentropyCategorical1Hot, ) - def test_softmax_optimizations(self): + def test_softmax_rewrites(self): x = matrix("x") one_of_n = lvector("one_of_n") op = crossentropy_categorical_1hot @@ -647,10 +652,10 @@ def test_softmax_optimizations(self): fgraph = FunctionGraph([x, one_of_n], [op(softmax_legacy(x), one_of_n)]) assert fgraph.outputs[0].owner.op == op - optdb.query(OPT_FAST_RUN).optimize(fgraph) + optdb.query(OPT_FAST_RUN).rewrite(fgraph) assert fgraph.outputs[0].owner.op == crossentropy_softmax_argmax_1hot_with_bias - def test_softmax_optimizations_w_bias(self): + def test_softmax_rewrites_w_bias(self): x = matrix("x") b = vector("b") one_of_n = lvector("one_of_n") @@ -659,12 +664,12 @@ def test_softmax_optimizations_w_bias(self): fgraph = FunctionGraph([x, b, one_of_n], [op(softmax_legacy(x + b), one_of_n)]) assert fgraph.outputs[0].owner.op == op - optdb.query(OPT_FAST_RUN).optimize(fgraph) + optdb.query(OPT_FAST_RUN).rewrite(fgraph) assert len(fgraph.toposort()) == 1 assert fgraph.outputs[0].owner.op == crossentropy_softmax_argmax_1hot_with_bias - def test_softmax_optimizations_w_bias2(self): + def test_softmax_rewrites_w_bias2(self): x = matrix("x") b = vector("b") c = vector("c") @@ -676,12 +681,12 @@ def test_softmax_optimizations_w_bias2(self): ) assert fgraph.outputs[0].owner.op == op - optdb.query(OPT_FAST_RUN).optimize(fgraph) + optdb.query(OPT_FAST_RUN).rewrite(fgraph) assert len(fgraph.toposort()) == 2 assert fgraph.outputs[0].owner.op == crossentropy_softmax_argmax_1hot_with_bias - def test_softmax_grad_optimizations(self): + def test_softmax_grad_rewrites(self): x = matrix("x") one_of_n = lvector("one_of_n") op = crossentropy_categorical_1hot @@ -694,7 +699,7 @@ def test_softmax_grad_optimizations(self): ops_to_check=[crossentropy_softmax_1hot_with_bias_dx, softmax_legacy], ) - optdb.query(OPT_FAST_RUN).optimize(fgraph) + optdb.query(OPT_FAST_RUN).rewrite(fgraph) ops = {node.op for node in fgraph.toposort()} assert crossentropy_softmax_argmax_1hot_with_bias not in ops @@ -717,7 +722,7 @@ def test_get_rid_of_advanced_indexing_version_of_xent(self): for expr in expressions: fgraph = FunctionGraph([x, y], [expr]) - optdb.query(OPT_FAST_RUN).optimize(fgraph) + optdb.query(OPT_FAST_RUN).rewrite(fgraph) ops = [node.op for node in fgraph.toposort()] assert len(ops) == 4 @@ -726,7 +731,7 @@ def test_get_rid_of_advanced_indexing_version_of_xent(self): # Also verify the gradient wrt x fgraph = FunctionGraph([x, y], [grad(expr, x)]) - optdb.query(OPT_FAST_RUN).optimize(fgraph) + optdb.query(OPT_FAST_RUN).rewrite(fgraph) ops = [node.op for node in fgraph.toposort()] assert len(ops) == 2 @@ -734,7 +739,7 @@ def test_get_rid_of_advanced_indexing_version_of_xent(self): assert softmax_legacy in ops assert softmax_grad_legacy not in ops - # Test that a biased softmax is optimized correctly + # Test that a biased softmax is rewritten correctly bias_expressions = [ at_sum(-log(softmax(x + b)[at.arange(y.shape[0]), y])), -at_sum(log(softmax(b + x)[at.arange(y.shape[0]), y])), @@ -744,14 +749,14 @@ def test_get_rid_of_advanced_indexing_version_of_xent(self): for expr in bias_expressions: fgraph = FunctionGraph([x, b, y], [expr, x]) - optdb.query(OPT_FAST_RUN).optimize(fgraph) + optdb.query(OPT_FAST_RUN).rewrite(fgraph) ops = [node.op for node in fgraph.toposort()] assert len(ops) == 2 # [big_op, sum] assert crossentropy_softmax_argmax_1hot_with_bias in ops fgraph = FunctionGraph([x, b, y], [grad(expr, x)]) - optdb.query(OPT_FAST_RUN).optimize(fgraph) + optdb.query(OPT_FAST_RUN).rewrite(fgraph) ops = [node.op for node in fgraph.toposort()] assert len(ops) == 2 @@ -770,7 +775,7 @@ def test_get_rid_of_advanced_indexing_version_of_xent(self): for expr in mean_expressions: fgraph = FunctionGraph([x, y], [expr]) - optdb.query(OPT_FAST_RUN).optimize(fgraph) + optdb.query(OPT_FAST_RUN).rewrite(fgraph) ops = [node.op for node in fgraph.toposort()] assert len(ops) == 6 @@ -778,7 +783,7 @@ def test_get_rid_of_advanced_indexing_version_of_xent(self): assert not [1 for o in ops if isinstance(o, AdvancedSubtensor)] fgraph = FunctionGraph([x, y], [grad(expr, x)]) - optdb.query(OPT_FAST_RUN).optimize(fgraph) + optdb.query(OPT_FAST_RUN).rewrite(fgraph) ops = [node.op for node in fgraph.toposort()] assert len(ops) == 5 @@ -798,7 +803,7 @@ def test_get_rid_of_advanced_indexing_version_of_xent(self): for expr in mean_bias_expressions: fgraph = FunctionGraph([x, b, y], [expr]) - optdb.query(OPT_FAST_RUN).optimize(fgraph) + optdb.query(OPT_FAST_RUN).rewrite(fgraph) ops = [node.op for node in fgraph.toposort()] assert len(ops) == 4 @@ -806,7 +811,7 @@ def test_get_rid_of_advanced_indexing_version_of_xent(self): assert not [1 for o in ops if isinstance(o, AdvancedSubtensor)] fgraph = FunctionGraph([x, b, y], [grad(expr, x)]) - optdb.query(OPT_FAST_RUN).optimize(fgraph) + optdb.query(OPT_FAST_RUN).rewrite(fgraph) ops = [node.op for node in fgraph.toposort()] assert len(ops) == 5 @@ -827,7 +832,7 @@ def test_xent_thing_int32(self): for expr in expressions: fgraph = FunctionGraph([x, y], [expr]) - optdb.query(OPT_FAST_RUN).optimize(fgraph) + optdb.query(OPT_FAST_RUN).rewrite(fgraph) ops = [node.op for node in fgraph.toposort()] assert len(ops) == 5 @@ -836,7 +841,7 @@ def test_xent_thing_int32(self): # Also verify the gradient wrt x fgraph = FunctionGraph([x, y], [grad(expr, x)]) - optdb.query(OPT_FAST_RUN).optimize(fgraph) + optdb.query(OPT_FAST_RUN).rewrite(fgraph) ops = [node.op for node in fgraph.toposort()] assert len(ops) == 3 @@ -888,7 +893,7 @@ def validate_grad_graph(func): for expr in expressions: fgraph = FunctionGraph([x, y, a], [expr]) - optdb.query(OPT_FAST_RUN).optimize(fgraph) + optdb.query(OPT_FAST_RUN).rewrite(fgraph) assert 5 <= len(fgraph.toposort()) <= 10 @@ -898,7 +903,7 @@ def validate_grad_graph(func): # Verify the gradient wrt x fgraph = FunctionGraph([x, y, a], [grad(expr, x)]) - optdb.query(OPT_FAST_RUN).optimize(fgraph) + optdb.query(OPT_FAST_RUN).rewrite(fgraph) assert 3 <= len(fgraph.toposort()) <= 6 @@ -911,7 +916,7 @@ def validate_grad_graph(func): fgraph = FunctionGraph( [x, y, a], [grad(expr, x, known_grads={expr: a * x.sum()})] ) - optdb.query(OPT_FAST_RUN).optimize(fgraph) + optdb.query(OPT_FAST_RUN).rewrite(fgraph) assert 6 <= len(fgraph.toposort()) <= 8 @@ -927,7 +932,7 @@ def test_argmax_pushdown(): # test that the max_and_argmax is pushed down if the max is not used out = max_and_argmax(sm(exp(tanh(sigmoid(x)))), axis=-1)[1] fgraph = FunctionGraph([x], [out]) - optdb.query(OPT_FAST_RUN).optimize(fgraph) + optdb.query(OPT_FAST_RUN).rewrite(fgraph) # print 'AFTER' # for node in fgraph.toposort(): @@ -942,7 +947,7 @@ def test_argmax_pushdown(): assert hasattr(fgraph.outputs[0].tag, "trace") - optdb.query(OPT_FAST_RUN).optimize(fgraph) + optdb.query(OPT_FAST_RUN).rewrite(fgraph) # print 'AFTER' # for node in fgraph.toposort(): @@ -963,7 +968,7 @@ def test_argmax_pushdown_bias(): out = argmax(softmax_with_bias(x, b), axis=-1) fgraph = FunctionGraph([x, b], [out]) - optdb.query(OPT_FAST_RUN).optimize(fgraph) + optdb.query(OPT_FAST_RUN).rewrite(fgraph) types_to_check = (DimShuffle, Elemwise, Argmax) assert len(fgraph.toposort()) == 3 @@ -977,7 +982,7 @@ def test_argmax_pushdown_bias(): out = max_and_argmax(softmax_with_bias(x, b), axis=-1)[0] fgraph = FunctionGraph([x, b], [out]) - optdb.query(OPT_FAST_RUN).optimize(fgraph) + optdb.query(OPT_FAST_RUN).rewrite(fgraph) assert len(fgraph.toposort()) == 2 assert isinstance(fgraph.toposort()[0].op, SoftmaxWithBias) @@ -987,10 +992,9 @@ def test_argmax_pushdown_bias(): def test_asymptotic_32(): - # This test makes sure that our functions behave sensibly when - # huge values are present + """Test that our functions behave sensibly when huge values are present.""" - # TODO: consider adding the optimization of crossentropy into the current + # TODO: consider adding the rewrite of crossentropy into the current # mode for the purpose of running this test for dtype in "float32", "float64": @@ -1027,15 +1031,17 @@ def test_asymptotic_32(): assert gxval[0, 1] == 0.25 -class TestSoftmaxOpt: - # Test that expressions of softmax in terms of exponentiated things - # divided by row sums are replaced by softmax expressions. - # - # Softmax_grad isn't that interesting as an Op, but it has the signature - # we look for when trying to insert CrossEntropySoftmax... grad. So, for - # now, we add softmax_grad to graphs. In the future, we may modify the - # CrossEntropySoftmax...grad to look for the more basic pattern. - # +class TestSoftmaxRewrite: + """ + Test that expressions of softmax in terms of exponentiated things + divided by row sums are replaced by softmax expressions. + + `Softmax_grad` isn't that interesting as an Op, but it has the signature + we look for when trying to insert `CrossEntropySoftmax` grad. So, for + now, we add `softmax_grad` to graphs. In the future, we may modify the + `CrossEntropySoftmax` grad to look for the more basic pattern. + + """ def setup_method(self): self.mode = aesara.compile.mode.get_default_mode() @@ -1086,7 +1092,7 @@ def test_basic_keepdims(self, axis): c_val = rng.random((3, 4, 5)).astype(config.floatX) assert np.allclose(f(c_val), sp.softmax(c_val, axis=axis)) - @pytest.mark.skip(reason="Optimization not enabled for the moment") + @pytest.mark.skip(reason="Rewrite not enabled for the moment") def test_grad(self): c = matrix() p_y = exp(c) / exp(c).sum(axis=1).dimshuffle(0, "x") @@ -1116,7 +1122,7 @@ def test_transpose_basic(self): assert len(f_ops) == 1 assert isinstance(f_ops[0], Softmax) - @pytest.mark.skip(reason="Optimization not enabled for the moment") + @pytest.mark.skip(reason="Rewrite not enabled for the moment") def test_transpose_grad(self): # this should be a transposed softmax c = matrix() @@ -1139,7 +1145,7 @@ def test_1d_basic(self): assert len(f_ops) == 1 assert isinstance(f_ops[0], Softmax) - @pytest.mark.skip(reason="Optimization not enabled for the moment") + @pytest.mark.skip(reason="Rewrite not enabled for the moment") def test_1D_grad(self): c = vector() p_y = exp(c) / exp(c).sum() @@ -1207,12 +1213,12 @@ def test_stabilize_log_softmax(): f = aesara.function([x], z, mode=mode) assert check_stack_trace(f, ops_to_check="all") - # check that the softmax has been optimized out + # Check that the softmax has been rewritten for node in f.maker.fgraph.toposort(): assert not isinstance(node.op, y.owner.op.__class__) - # call the function so debug mode can verify the optimized - # version matches the unoptimized version + # Call the function so debug mode can verify the rewritten version matches + # the un-rewritten version rng = np.random.default_rng(utt.fetch_seed()) f(np.cast[config.floatX](rng.random((2, 3)))) @@ -1222,25 +1228,25 @@ def test_relu(): rng = np.random.default_rng(utt.fetch_seed()) X = rng.standard_normal((20, 30)).astype(config.floatX) - # test the base case, without custom alpha value + # Test the base case, without custom alpha value y = relu(x).eval({x: X}) assert np.allclose(y, np.maximum(X, 0)) - # test for different constant alpha values (also outside of [0, 1]) + # Test for different constant alpha values (also outside of [0, 1]) for alpha in 0, 0.3, 1, 2, -0.3, -1, -2: y = relu(x, alpha).eval({x: X}) assert np.allclose(y, np.where(X > 0, X, alpha * X)) - # test for variable alpha (scalar, vector and matrix) + # Test for variable alpha (scalar, vector and matrix) for alpha in scalar(), vector(), matrix(): - # create value for alpha (correct ndim and broadcastable against X) + # Create value for alpha (correct ndim and broadcastable against X) A = np.array( rng.standard_normal(X.shape[::-1][: alpha.ndim][::-1]), dtype=config.floatX ) y = relu(x, alpha).eval({x: X, alpha: A}) assert np.allclose(y, np.where(X > 0, X, A * X), rtol=3e-5) - # test that for alpha of ndarray don't cause upcast. + # Test that an alpha of type `ndarray` doesn't generate an upcast x = matrix("x", dtype="float32") X = rng.standard_normal((20, 30)).astype("float32") alpha = np.asarray(0.123, dtype="float32") @@ -1251,8 +1257,7 @@ def test_relu(): def test_h_softmax(): - # Tests the output dimensions of the h_softmax when a target is provided or - # not. + """Tests the output dimensions of the `h_softmax` when a target is provided or not.""" input_size = 4 batch_size = 2 diff --git a/tests/tensor/nnet/test_conv3d2d.py b/tests/tensor/nnet/test_conv3d2d.py index 180b9dd96a..f717bc17f0 100644 --- a/tests/tensor/nnet/test_conv3d2d.py +++ b/tests/tensor/nnet/test_conv3d2d.py @@ -11,7 +11,7 @@ import tests.unittest_tools as utt from aesara.compile.sharedvalue import shared -from aesara.graph.opt import check_stack_trace +from aesara.graph.rewriting.basic import check_stack_trace from aesara.tensor.nnet.conv3d2d import ( DiagonalSubtensor, IncDiagonalSubtensor, diff --git a/tests/tensor/nnet/test_opt.py b/tests/tensor/nnet/test_rewriting.py similarity index 96% rename from tests/tensor/nnet/test_opt.py rename to tests/tensor/nnet/test_rewriting.py index 4aebc6c8ff..e3b6020d27 100644 --- a/tests/tensor/nnet/test_opt.py +++ b/tests/tensor/nnet/test_rewriting.py @@ -1,5 +1,5 @@ import aesara -from aesara.graph.opt import check_stack_trace +from aesara.graph.rewriting.basic import check_stack_trace from aesara.tensor.nnet.blocksparse import ( sparse_block_dot, sparse_block_gemv, diff --git a/tests/tensor/nnet/test_sigm.py b/tests/tensor/nnet/test_sigm.py index bd2bb11a45..ec1dbe5993 100644 --- a/tests/tensor/nnet/test_sigm.py +++ b/tests/tensor/nnet/test_sigm.py @@ -4,7 +4,7 @@ import aesara from aesara.compile.mode import get_default_mode, get_mode from aesara.configdefaults import config -from aesara.graph.opt import check_stack_trace +from aesara.graph.rewriting.basic import check_stack_trace from aesara.scalar.basic import Composite from aesara.tensor.elemwise import Elemwise from aesara.tensor.inplace import sigmoid_inplace diff --git a/tests/tensor/random/test_basic.py b/tests/tensor/random/test_basic.py index 0bec385fda..f5366ce9e7 100644 --- a/tests/tensor/random/test_basic.py +++ b/tests/tensor/random/test_basic.py @@ -1,4 +1,5 @@ import pickle +import re from copy import copy import numpy as np @@ -13,8 +14,7 @@ from aesara.graph.basic import Constant, Variable, graph_inputs from aesara.graph.fg import FunctionGraph from aesara.graph.op import get_test_value -from aesara.graph.optdb import OptimizationQuery -from aesara.tensor.basic_opt import ShapeFeature +from aesara.graph.rewriting.db import RewriteDatabaseQuery from aesara.tensor.random.basic import ( bernoulli, beta, @@ -55,12 +55,13 @@ wald, weibull, ) +from aesara.tensor.rewriting.shape import ShapeFeature from aesara.tensor.type import iscalar, scalar, tensor from tests.unittest_tools import create_aesara_param -opts = OptimizationQuery(include=[None], exclude=["cxx_only", "BlasOpt"]) -py_mode = Mode("py", opts) +rewrites_query = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"]) +py_mode = Mode("py", rewrites_query) def fixed_scipy_rvs(rvs_name): @@ -1229,10 +1230,15 @@ def test_multinomial_rng(): (10, 2, 3), lambda *args, **kwargs: np.tile(np.arange(3).astype(np.int64), (10, 2, 1)), ), + ( + np.full((4, 1, 3), [100000, 1, 1], dtype=config.floatX), + (4, 2), + lambda *args, **kwargs: np.zeros((4, 2), dtype=np.int64), + ), ], ) def test_categorical_samples(p, size, test_fn): - p = p / p.sum(axis=-1) + p = p / p.sum(axis=-1, keepdims=True) rng = np.random.default_rng(232) compare_sample_values( @@ -1251,7 +1257,20 @@ def test_categorical_basic(): rng = np.random.default_rng() with pytest.raises(ValueError): - categorical.rng_fn(rng, p, size=10) + # The independent dimension of p has shape=(3,) which cannot be + # broadcasted to (10,) + categorical.rng_fn(rng, p, size=(10,)) + + msg = re.escape("`size` is incompatible with the shape of `p`") + with pytest.raises(ValueError, match=msg): + # The independent dimension of p has shape=(3,) which cannot be + # broadcasted to (1,) + categorical.rng_fn(rng, p, size=(1,)) + + with pytest.raises(ValueError, match=msg): + # The independent dimensions of p have shape=(1, 3) which cannot be + # broadcasted to (3,) + categorical.rng_fn(rng, p[None], size=(3,)) def test_randint_samples(): @@ -1302,17 +1321,28 @@ def test_choice_samples(): with pytest.raises(NotImplementedError): choice._supp_shape_from_params(np.asarray(5)) + compare_sample_values(choice, np.asarray(5)) compare_sample_values(choice, np.asarray([5])) compare_sample_values(choice, np.array([1.0, 5.0], dtype=config.floatX)) compare_sample_values(choice, np.asarray([5]), 3) - with pytest.raises(ValueError): - compare_sample_values(choice, np.array([[1, 2], [3, 4]])) + compare_sample_values(choice, np.array([[1, 2], [3, 4]])) + compare_sample_values(choice, np.array([[1, 2], [3, 4]]), p=[0.4, 0.6]) compare_sample_values(choice, [1, 2, 3], 1) + compare_sample_values( choice, [1, 2, 3], 1, p=at.as_tensor([1 / 3.0, 1 / 3.0, 1 / 3.0]) ) + + # p must be 1-dimensional. + # TODO: The exception is raised at runtime but could be raised at compile + # time in some situations using static shape analysis. + with pytest.raises(ValueError): + rng = np.random.default_rng() + rng_at = shared(rng, borrow=True) + choice(a=[1, 2], p=at.as_tensor([[0.1, 0.9], [0.3, 0.7]]), rng=rng_at).eval() + compare_sample_values(choice, [1, 2, 3], (10, 2), replace=True) compare_sample_values(choice, at.as_tensor_variable([1, 2, 3]), 2, replace=True) diff --git a/tests/tensor/random/test_opt.py b/tests/tensor/random/test_rewriting.py similarity index 91% rename from tests/tensor/random/test_opt.py rename to tests/tensor/random/test_rewriting.py index f9695457cd..f8df26b037 100644 --- a/tests/tensor/random/test_opt.py +++ b/tests/tensor/random/test_rewriting.py @@ -7,8 +7,8 @@ from aesara.compile.mode import Mode from aesara.graph.basic import Constant from aesara.graph.fg import FunctionGraph -from aesara.graph.opt import EquilibriumOptimizer -from aesara.graph.optdb import OptimizationQuery +from aesara.graph.rewriting.basic import EquilibriumGraphRewriter +from aesara.graph.rewriting.db import RewriteDatabaseQuery from aesara.tensor.elemwise import DimShuffle from aesara.tensor.random.basic import ( dirichlet, @@ -19,7 +19,7 @@ uniform, ) from aesara.tensor.random.op import RandomVariable -from aesara.tensor.random.opt import ( +from aesara.tensor.random.rewriting import ( local_dimshuffle_rv_lift, local_rv_size_lift, local_subtensor_rv_lift, @@ -28,10 +28,12 @@ from aesara.tensor.type import iscalar, vector -no_mode = Mode("py", OptimizationQuery(include=[], exclude=[])) +no_mode = Mode("py", RewriteDatabaseQuery(include=[], exclude=[])) -def apply_local_opt_to_rv(opt, op_fn, dist_op, dist_params, size, rng, name=None): +def apply_local_rewrite_to_rv( + rewrite, op_fn, dist_op, dist_params, size, rng, name=None +): dist_params_at = [] for p in dist_params: p_at = at.as_tensor(p).type() @@ -50,20 +52,20 @@ def apply_local_opt_to_rv(opt, op_fn, dist_op, dist_params, size, rng, name=None p for p in dist_params_at + size_at if not isinstance(p, (slice, Constant)) ] - mode = Mode("py", EquilibriumOptimizer([opt], max_use_ratio=100)) + mode = Mode("py", EquilibriumGraphRewriter([rewrite], max_use_ratio=100)) - f_opt = function( + f_rewritten = function( f_inputs, dist_st, mode=mode, ) - (new_out,) = f_opt.maker.fgraph.outputs + (new_out,) = f_rewritten.maker.fgraph.outputs - return new_out, f_inputs, dist_st, f_opt + return new_out, f_inputs, dist_st, f_rewritten -def test_inplace_optimization(): +def test_inplace_rewrites(): out = normal(0, 1) out.owner.inputs[0].default_update = out.owner.outputs[0] @@ -87,7 +89,7 @@ def test_inplace_optimization(): assert np.array_equal(new_out.owner.inputs[1].data, []) -def test_inplace_optimization_extra_props(): +def test_inplace_rewrites_extra_props(): class Test(RandomVariable): name = "test" ndim_supp = 0 @@ -183,7 +185,7 @@ def rng_fn(self, rng, sigma, size): def test_local_rv_size_lift(dist_op, dist_params, size): rng = shared(np.random.default_rng(1233532), borrow=False) - new_out, f_inputs, dist_st, f_opt = apply_local_opt_to_rv( + new_out, f_inputs, dist_st, f_rewritten = apply_local_rewrite_to_rv( local_rv_size_lift, lambda rv: rv, dist_op, @@ -349,7 +351,7 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol): rng = shared(np.random.default_rng(1233532), borrow=False) - new_out, f_inputs, dist_st, f_opt = apply_local_opt_to_rv( + new_out, f_inputs, dist_st, f_rewritten = apply_local_rewrite_to_rv( local_dimshuffle_rv_lift, lambda rv: rv.dimshuffle(ds_order), dist_op, @@ -377,9 +379,9 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol): arg_values = [p.get_test_value() for p in f_inputs] res_base = f_base(*arg_values) - res_opt = f_opt(*arg_values) + res_rewritten = f_rewritten(*arg_values) - np.testing.assert_allclose(res_base, res_opt, rtol=rtol) + np.testing.assert_allclose(res_base, res_rewritten, rtol=rtol) @pytest.mark.parametrize( @@ -472,7 +474,7 @@ def test_Subtensor_lift(indices, lifted, dist_op, dist_params, size): i_at.tag.test_value = i indices_at += (i_at,) - new_out, f_inputs, dist_st, f_opt = apply_local_opt_to_rv( + new_out, f_inputs, dist_st, f_rewritten = apply_local_rewrite_to_rv( local_subtensor_rv_lift, lambda rv: rv[indices_at], dist_op, @@ -502,9 +504,9 @@ def test_Subtensor_lift(indices, lifted, dist_op, dist_params, size): arg_values = [p.get_test_value() for p in f_inputs] res_base = f_base(*arg_values) - res_opt = f_opt(*arg_values) + res_rewritten = f_rewritten(*arg_values) - np.testing.assert_allclose(res_base, res_opt, rtol=1e-3) + np.testing.assert_allclose(res_base, res_rewritten, rtol=1e-3) def test_Subtensor_lift_restrictions(): @@ -519,7 +521,7 @@ def test_Subtensor_lift_restrictions(): z = x - y fg = FunctionGraph([rng], [z], clone=False) - _ = EquilibriumOptimizer([local_subtensor_rv_lift], max_use_ratio=100).apply(fg) + _ = EquilibriumGraphRewriter([local_subtensor_rv_lift], max_use_ratio=100).apply(fg) subtensor_node = fg.outputs[0].owner.inputs[1].owner.inputs[0].owner assert subtensor_node == y.owner @@ -531,7 +533,7 @@ def test_Subtensor_lift_restrictions(): # We add `x` as an output to make sure that `is_rv_used_in_graph` handles # `"output"` "nodes" correctly. fg = FunctionGraph([rng], [z, x], clone=False) - EquilibriumOptimizer([local_subtensor_rv_lift], max_use_ratio=100).apply(fg) + EquilibriumGraphRewriter([local_subtensor_rv_lift], max_use_ratio=100).apply(fg) assert fg.outputs[0] == z assert fg.outputs[1] == x @@ -539,7 +541,7 @@ def test_Subtensor_lift_restrictions(): # The non-`Subtensor` client doesn't depend on the RNG state, so we can # perform the lift fg = FunctionGraph([rng], [z], clone=False) - EquilibriumOptimizer([local_subtensor_rv_lift], max_use_ratio=100).apply(fg) + EquilibriumGraphRewriter([local_subtensor_rv_lift], max_use_ratio=100).apply(fg) rv_node = fg.outputs[0].owner.inputs[1].owner.inputs[0].owner assert rv_node.op == normal @@ -557,7 +559,9 @@ def test_Dimshuffle_lift_restrictions(): z = x - y fg = FunctionGraph([rng], [z, y], clone=False) - _ = EquilibriumOptimizer([local_dimshuffle_rv_lift], max_use_ratio=100).apply(fg) + _ = EquilibriumGraphRewriter([local_dimshuffle_rv_lift], max_use_ratio=100).apply( + fg + ) dimshuffle_node = fg.outputs[0].owner.inputs[1].owner assert dimshuffle_node == y.owner @@ -569,7 +573,7 @@ def test_Dimshuffle_lift_restrictions(): # We add `x` as an output to make sure that `is_rv_used_in_graph` handles # `"output"` "nodes" correctly. fg = FunctionGraph([rng], [z, x], clone=False) - EquilibriumOptimizer([local_dimshuffle_rv_lift], max_use_ratio=100).apply(fg) + EquilibriumGraphRewriter([local_dimshuffle_rv_lift], max_use_ratio=100).apply(fg) assert fg.outputs[0] == z assert fg.outputs[1] == x @@ -577,7 +581,7 @@ def test_Dimshuffle_lift_restrictions(): # The non-`Dimshuffle` client doesn't depend on the RNG state, so we can # perform the lift fg = FunctionGraph([rng], [z], clone=False) - EquilibriumOptimizer([local_dimshuffle_rv_lift], max_use_ratio=100).apply(fg) + EquilibriumGraphRewriter([local_dimshuffle_rv_lift], max_use_ratio=100).apply(fg) rv_node = fg.outputs[0].owner.inputs[1].owner assert rv_node.op == normal @@ -613,7 +617,7 @@ def test_Dimshuffle_lift_rename(ds_order, lifted, dist_op, dist_params, size, rt rng = shared(np.random.default_rng(1233532), borrow=False) - new_out, *_ = apply_local_opt_to_rv( + new_out, *_ = apply_local_rewrite_to_rv( local_dimshuffle_rv_lift, lambda rv: rv.dimshuffle(ds_order), dist_op, diff --git a/tests/tensor/random/test_utils.py b/tests/tensor/random/test_utils.py index 722809aed8..18a7650147 100644 --- a/tests/tensor/random/test_utils.py +++ b/tests/tensor/random/test_utils.py @@ -3,7 +3,7 @@ from aesara import config, function from aesara.compile.mode import Mode -from aesara.graph.optdb import OptimizationQuery +from aesara.graph.rewriting.db import RewriteDatabaseQuery from aesara.tensor.random.utils import RandomStream, broadcast_params from aesara.tensor.type import matrix, tensor from tests import unittest_tools as utt @@ -11,8 +11,8 @@ @pytest.fixture(scope="module", autouse=True) def set_aesara_flags(): - opts = OptimizationQuery(include=[None], exclude=[]) - py_mode = Mode("py", opts) + rewrites_query = RewriteDatabaseQuery(include=[None], exclude=[]) + py_mode = Mode("py", rewrites_query) with config.change_flags(mode=py_mode, compute_test_value="warn"): yield diff --git a/tests/tensor/rewriting/__init__.py b/tests/tensor/rewriting/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/tensor/rewriting/test_basic.py b/tests/tensor/rewriting/test_basic.py new file mode 100644 index 0000000000..0301909c82 --- /dev/null +++ b/tests/tensor/rewriting/test_basic.py @@ -0,0 +1,1894 @@ +import copy + +import numpy as np +import pytest + +import aesara +import aesara.scalar as aes +import aesara.tensor as at +from aesara import shared +from aesara.compile import optdb +from aesara.compile.function import function +from aesara.compile.mode import get_default_mode, get_mode +from aesara.compile.ops import DeepCopyOp, deep_copy_op +from aesara.configdefaults import config +from aesara.graph.fg import FunctionGraph +from aesara.graph.rewriting.basic import check_stack_trace, out2in +from aesara.graph.rewriting.db import RewriteDatabaseQuery +from aesara.graph.rewriting.utils import rewrite_graph +from aesara.printing import pprint +from aesara.raise_op import Assert, CheckAndRaise +from aesara.tensor.basic import ( + Alloc, + Join, + MakeVector, + ScalarFromTensor, + Split, + TensorFromScalar, + join, + tile, +) +from aesara.tensor.elemwise import DimShuffle, Elemwise +from aesara.tensor.math import ( + add, + bitwise_and, + bitwise_or, + bitwise_xor, + dot, + eq, + exp, + floor_div, + ge, + gt, + int_div, + le, + log, + lt, + maximum, + minimum, + mul, + neq, +) +from aesara.tensor.math import pow as at_pow +from aesara.tensor.math import softplus, sqrt, sub +from aesara.tensor.math import sum as at_sum +from aesara.tensor.math import true_div +from aesara.tensor.rewriting.basic import ( + assert_op, + local_alloc_sink_dimshuffle, + local_merge_alloc, + local_useless_alloc, + local_useless_elemwise, +) +from aesara.tensor.rewriting.math import local_lift_transpose_through_dot +from aesara.tensor.rewriting.shape import ShapeFeature +from aesara.tensor.shape import ( + Reshape, + Shape_i, + SpecifyShape, + Unbroadcast, + specify_shape, + unbroadcast, +) +from aesara.tensor.subtensor import ( + AdvancedIncSubtensor1, + Subtensor, + advanced_inc_subtensor, + advanced_inc_subtensor1, + inc_subtensor, +) +from aesara.tensor.type import ( + TensorType, + dmatrix, + dscalar, + dvector, + fmatrix, + fscalar, + imatrices, + iscalar, + iscalars, + ivector, + lscalar, + lvector, + matrices, + matrix, + row, + scalar, + scalars, + tensor, + tensor3, + tensor4, + values_eq_approx_remove_nan, + vector, +) +from tests import unittest_tools as utt + + +rewrite_mode = config.mode +if rewrite_mode == "FAST_COMPILE": + rewrite_mode = "FAST_RUN" +rewrite_mode = get_mode(rewrite_mode) + +_stabilize_rewrites = RewriteDatabaseQuery(include=["fast_run"]) +_stabilize_rewrites.position_cutoff = 1.51 +_stabilize_rewrites = optdb.query(_stabilize_rewrites) + +_specialize_rewrites = RewriteDatabaseQuery(include=["fast_run"]) +_specialize_rewrites.position_cutoff = 2.01 +_specialize_rewrites = optdb.query(_specialize_rewrites) + +_fast_run_rewrites = RewriteDatabaseQuery(include=["fast_run"]) +_fast_run_rewrites = optdb.query(_fast_run_rewrites) + + +def rewrite(g, level="fast_run"): + if level == "fast_run": + _fast_run_rewrites.rewrite(g) + elif level == "specialize": + _specialize_rewrites.rewrite(g) + elif level == "stabilize": + _stabilize_rewrites.rewrite(g) + else: + raise ValueError(level) + return g + + +def test_local_useless_slice(): + # test a simple matrix + x = matrix("x") + mode_excluding = get_default_mode().excluding( + "local_useless_slice", "local_mul_canonizer" + ) + mode_including = ( + get_default_mode() + .including("local_useless_slice") + .excluding("local_mul_canonizer") + ) + + # test with and without the useless slice + o = 2 * x[0, :] + f_excluding = function([x], o, mode=mode_excluding) + f_including = function([x], o, mode=mode_including) + rng = np.random.default_rng(utt.fetch_seed()) + test_inp = rng.integers(-10, 10, (4, 4)).astype("float32") + assert all(f_including(test_inp) == f_excluding(test_inp)) + # test to see if the slice is truly gone + apply_node = f_including.maker.fgraph.toposort()[0] + subtens = apply_node.op + assert not any(isinstance(idx, slice) for idx in subtens.idx_list) + + # Now test that the stack trace is copied over properly, + # before before and after rewriting. + assert check_stack_trace(f_excluding, ops_to_check="all") + assert check_stack_trace(f_including, ops_to_check="all") + + # test a 4d tensor + z = tensor4("z") + o2 = z[1, :, :, 1] + o3 = z[0, :, :, :] + f_including_check = function([z], o2, mode=mode_including) + f_including_check_apply = function([z], o3, mode=mode_including) + + # The rewrite shouldn't apply here + apply_node = f_including_check.maker.fgraph.toposort()[0] + subtens = apply_node.op + assert [isinstance(idx, slice) for idx in subtens.idx_list].count(True) == 2 + # But it should here + apply_node = f_including_check_apply.maker.fgraph.toposort()[0] + subtens = apply_node.op + assert not any(isinstance(idx, slice) for idx in subtens.idx_list) + + # Finally, test that the stack trace is copied over properly, + # before before and after rewriting. + assert check_stack_trace(f_including_check, ops_to_check=Subtensor) + assert check_stack_trace(f_including_check_apply, ops_to_check=Subtensor) + + +def test_local_useless_fill(): + x = dvector() + y = dvector() + z = lvector() + + x_ = np.random.random((5,)) + y_ = np.random.random((5,)) + z_ = (np.random.random((5,)) * 5).astype("int64") + + # basic case + f = function([x], at.fill(x, x) * 2, mode=rewrite_mode) + assert [node.op for node in f.maker.fgraph.toposort()] == [mul] + res = f(x_) + exp_res = np.broadcast_to(x_, x_.shape) * 2 + assert np.array_equal(res, exp_res) + + # basic case + f = function([x, y], at.second(y, x) * 2, mode=rewrite_mode) + assert [node.op for node in f.maker.fgraph.toposort()] == [mul] + res = f(x_, y_) + exp_res = np.broadcast_to(x_, y_.shape) * 2 + assert np.array_equal(res, exp_res) + + # basic case + f = function([x, y], at.fill(x, y) * 2, mode=rewrite_mode) + assert [node.op for node in f.maker.fgraph.toposort()] == [mul] + res = f(x_, y_) + exp_res = np.broadcast_to(y_, x_.shape) * 2 + assert np.array_equal(res, exp_res) + + # now with different type(cast) + f = function([x, z], at.fill(z, x) * 2, mode=rewrite_mode) + assert [node.op for node in f.maker.fgraph.toposort()] == [mul] + res = f(x_, z_) + exp_res = np.broadcast_to(x_, z_.shape) * 2 + assert np.array_equal(res, exp_res) + + # now with different type(cast) + f = function([x, z], at.fill(x, z) * 2, mode=rewrite_mode) + assert [node.op for node in f.maker.fgraph.toposort()] == [mul] + res = f(x_, z_) + exp_res = np.broadcast_to(z_, x_.shape) * 2 + assert np.array_equal(res, exp_res) + + # now cutting out the input ?? + f = function([x, y], at.fill(x, y) * 2, mode=rewrite_mode) + assert [node.op for node in f.maker.fgraph.toposort()] == [mul] + res = f(x_, y_) + exp_res = np.broadcast_to(y_, x_.shape) * 2 + assert np.array_equal(res, exp_res) + + +def test_local_fill_to_alloc(): + x = dvector() + m = dmatrix() + + x_ = np.random.random((5,)) + m_ = np.random.random((5, 5)) + + y = at.fill(m, x) + + mode = rewrite_mode.including("stabilize", "local_fill_to_alloc").excluding( + "useless", "local_useless_fill" + ) + + f = function([m, x], y, mode=mode) + assert Alloc in [node.op.__class__ for node in f.maker.fgraph.toposort()] + + res = f(m_, x_) + exp_res = np.broadcast_to(x_, m_.shape) + assert np.array_equal(res, exp_res) + + y = at.fill(x, m) + + f = function([m, x], y, mode=mode) + + assert Alloc not in [node.op.__class__ for node in f.maker.fgraph.toposort()] + + res = f(m_, x_) + assert np.array_equal(res, m_) + + +class TestLocalCanonicalizeAlloc: + def setup_method(self): + self.rng = np.random.default_rng(utt.fetch_seed()) + + def test_inconsistent_constant(self): + x = at.as_tensor(self.rng.standard_normal((3, 7))) + a = at.alloc(x, 6, 7) + + assert a.owner and isinstance(a.owner.op, Alloc) + + # `local_useless_alloc` should replace the `Alloc` with an `Assert` + with pytest.raises(AssertionError): + f = function([], a, mode=rewrite_mode) + + x = at.as_tensor(self.rng.standard_normal((6, 7))) + a = at.alloc(x, 6, 7) + + f = function([], a, mode=rewrite_mode) + + # The rewrite should then be applied, and remove Alloc + assert not any( + isinstance(node.op, (Alloc, Assert)) for node in f.maker.fgraph.toposort() + ) + + def test_inconsistent_shared(self): + # These shapes don't match! + x = shared(self.rng.standard_normal((3, 7))) + a = at.alloc(x, 6, 7) + + assert a.owner and isinstance(a.owner.op, Alloc) + + f = function([], a, mode=rewrite_mode) + + # The rewrite should then be applied, and remove Alloc + assert not any(isinstance(node.op, Alloc) for node in f.maker.fgraph.toposort()) + assert any(isinstance(node.op, Assert) for node in f.maker.fgraph.toposort()) + + with pytest.raises(AssertionError): + f() + + good_x_val = self.rng.standard_normal((6, 7)) + x.set_value(good_x_val) + + assert np.array_equal(f(), good_x_val) + + def test_basic_fill(self): + x = matrix("x") + y = at.fill(x, x) + + # The rewrite `locall_fill_to_alloc` should call `at.alloc`, + # which should return `x` and not `alloc(x, ...)` + f = function([x], [y], mode=rewrite_mode.including("local_fill_to_alloc")) + assert not any(isinstance(node.op, Alloc) for node in f.maker.fgraph.toposort()) + + def test_basic_tile(self): + x = matrix("x") + y = at.tile(x, (1,) * 2) + + mode = rewrite_mode.including( + "local_dimshuffle_lift", + "local_useless_dimshuffle_in_reshape", + "local_alloc_sink_dimshuffle", + ) + f = function([x], [y], mode=mode) + + assert not any(isinstance(node.op, Alloc) for node in f.maker.fgraph.toposort()) + + @pytest.mark.parametrize( + "x, has_alloc", + [ + (at.alloc(np.ones((2,)), 1, 3, 2), True), + (at.alloc(np.array(1.0), 1, 1), False), + (at.alloc(np.ones((1, 1)), 1, 1, 2), True), + (at.alloc(np.ones((1, 1)), 1, 2), True), + ], + ) + def test_useless_alloc_with_shape_one(self, x, has_alloc): + g = FunctionGraph(outputs=[x]) + assert any(isinstance(node.op, Alloc) for node in g.toposort()) + + alloc_lift = out2in(local_alloc_sink_dimshuffle) + alloc_lift.rewrite(g) + + if has_alloc: + assert any(isinstance(node.op, Alloc) for node in g.toposort()) + else: + assert not any(isinstance(node.op, Alloc) for node in g.toposort()) + + +class TestLocalUselessIncSubtensorAlloc: + rewrite_name = "local_useless_inc_subtensor_alloc" + + def setup_method(self): + # The rewrite requires the shape feature so we need to compile in + # FAST_RUN mode. + mode = config.mode + if mode == "FAST_COMPILE": + mode = "FAST_RUN" + self.mode = get_mode(mode) + self.rng = np.random.default_rng(utt.fetch_seed()) + + def test_advanced_inc_subtensor(self): + x = vector("x") + y = scalar("y") + i = matrix("i", dtype="int64") + z = advanced_inc_subtensor(x, at.alloc(y, *i.shape), i) + mode1 = self.mode.excluding(self.rewrite_name) + mode2 = self.mode.including(self.rewrite_name) + f1 = function([x, i, y], z, mode=mode1) + f2 = function([x, i, y], z, mode=mode2) + + # the alloc op should still be there + assert ( + len([n for n in f1.maker.fgraph.toposort() if isinstance(n.op, Alloc)]) == 1 + ) + # the alloc op should have been removed + assert ( + len([n for n in f2.maker.fgraph.toposort() if isinstance(n.op, Alloc)]) == 0 + ) + + x_value = np.random.standard_normal((5)).astype(config.floatX) + y_value = np.random.standard_normal() + i_value = self.rng.integers(0, 3, size=(2, 3)) + + r1 = f1(x_value, i_value, y_value) + r2 = f2(x_value, i_value, y_value) + + utt.assert_allclose(r1, r2) + + # Check stacktrace was copied over correctly after rewrite was applied + assert check_stack_trace(f1, ops_to_check=AdvancedIncSubtensor1) + assert check_stack_trace(f2, ops_to_check=AdvancedIncSubtensor1) + + def test_advanced_inc_subtensor1(self): + x = vector("x") + y = scalar("y") + i = vector("i", dtype="int64") + z = advanced_inc_subtensor1(x, at.alloc(y, *i.shape), i) + mode1 = self.mode.excluding(self.rewrite_name) + mode2 = self.mode.including(self.rewrite_name) + f1 = function([x, i, y], z, mode=mode1) + f2 = function([x, i, y], z, mode=mode2) + + # the alloc op should still be there + assert ( + len([n for n in f1.maker.fgraph.toposort() if isinstance(n.op, Alloc)]) == 1 + ) + # the alloc op should have been removed + assert ( + len([n for n in f2.maker.fgraph.toposort() if isinstance(n.op, Alloc)]) == 0 + ) + + x_value = np.random.standard_normal((5)).astype(config.floatX) + y_value = np.random.standard_normal() + i_value = self.rng.integers(0, 3, size=2) + + r1 = f1(x_value, i_value, y_value) + r2 = f2(x_value, i_value, y_value) + + utt.assert_allclose(r1, r2) + + assert check_stack_trace(f1, ops_to_check=AdvancedIncSubtensor1) + assert check_stack_trace(f2, ops_to_check="all") + + def test_incsubtensor(self): + x = vector("x") + y = scalar("y") + i = scalar("i", dtype="int64") + z = inc_subtensor(x[:i], at.alloc(y, i)) + mode1 = self.mode.excluding(self.rewrite_name) + mode2 = self.mode.including(self.rewrite_name) + f1 = function([x, i, y], z, mode=mode1) + f2 = function([x, i, y], z, mode=mode2) + + # the alloc op should still be there + assert ( + len([n for n in f1.maker.fgraph.toposort() if isinstance(n.op, Alloc)]) == 1 + ) + # the alloc op should have been removed + assert ( + len([n for n in f2.maker.fgraph.toposort() if isinstance(n.op, Alloc)]) == 0 + ) + + x_value = np.random.standard_normal((5)).astype(config.floatX) + y_value = np.random.standard_normal() + i_value = 3 + + r1 = f1(x_value, i_value, y_value) + r2 = f2(x_value, i_value, y_value) + + utt.assert_allclose(r1, r2) + + assert check_stack_trace(f1, ops_to_check="last") + assert check_stack_trace(f2, ops_to_check="last") + + +class TestUselessCheckAndRaise: + def test_basic(self): + mode = get_default_mode().including( + "canonicalize", "local_remove_useless_assert" + ) + x = scalar() + y = scalar() + f = function([x, y], assert_op(x, eq(x, y)), mode=mode) + assert f(1, 1) == 1 + with pytest.raises(AssertionError): + f(1, 0) + + def test_local_remove_useless_1(self): + """Remove `CheckAndRaise`s when all the conditions are always true.""" + x = scalar() + fg = FunctionGraph(outputs=[assert_op(x, 1)], clone=False) + fg_res = rewrite_graph(fg, include=["canonicalize", "specialize"]) + topo = fg_res.toposort() + assert not any(isinstance(node.op, CheckAndRaise) for node in topo) + + def test_local_remove_useless_2(self): + """Remove `CheckAndRaise` conditions that are always true.""" + x = scalar() + y = scalar() + fg = FunctionGraph(outputs=[assert_op(x, y, 1)], clone=False) + fg_res = rewrite_graph(fg, include=["canonicalize", "specialize"]) + topo = fg_res.toposort() + (assert_node,) = [node for node in topo if isinstance(node.op, CheckAndRaise)] + assert assert_node.inputs == [x, y] + + def test_local_remove_useless_3(self): + """Don't remove `CheckAndRaise` conditions that are always false.""" + x = scalar() + y = scalar() + fg = FunctionGraph(outputs=[assert_op(x, y, 0)], clone=False) + fg_res = rewrite_graph(fg, include=["canonicalize", "specialize"]) + topo = fg_res.toposort() + (assert_node,) = [node for node in topo if isinstance(node.op, CheckAndRaise)] + assert assert_node.inputs[:2] == [x, y] + assert assert_node.inputs[-1].data == 0 + + +def test_local_remove_all_assert(): + r"""Remove all `Assert`\s.""" + mode = get_default_mode().including("canonicalize", "local_remove_all_assert") + + x = scalar() + y = scalar() + f = function([x, y], assert_op(x, y), mode=mode) + # Without the rewrite, this would fail + assert f(1, 0) == 1 + topo = f.maker.fgraph.toposort() + assert not any(isinstance(node.op, CheckAndRaise) for node in topo) + + mode = get_default_mode() + a = assert_op(x, eq(x, 0).any()) + f = function([x], a, mode=mode.excluding("unsafe")) + topo = f.maker.fgraph.toposort() + a_op = [n for n in topo if isinstance(n.op, Assert)] + assert len(a_op) == 1 + + +class TestTile: + def test_local_useless_tile(self): + v = vector() + m = matrix() + mode = None + if config.mode == "FAST_COMPILE": + mode = "FAST_RUN" + for var, data in [(v, [1, 2, 3]), (m, [[1, 2], [3, 4]])]: + # When len(repeat pattern) <= var.ndim, everything is removed + # for ndim in range(1, var.ndim): + for ndim in range(var.ndim + 1): + f = function([var], tile(var, (1,) * ndim), mode=mode) + topo = f.maker.fgraph.toposort() + assert len(topo) == 1 + assert isinstance(topo[0].op, DeepCopyOp) + f(data) + # In this case, the rewrite only removes nodes; + # no need to `check_stack_trace` + # When len(repeat pattern) > var.ndim, only a dimshuffle should be + # left, but there can be a DeepCopy as well + for ndim in range(var.ndim + 1, var.ndim + 3): + f = function([var], tile(var, (1,) * ndim), mode=mode) + topo = f.maker.fgraph.toposort() + assert len(topo) <= 2 + assert isinstance(topo[0].op, DimShuffle) + assert check_stack_trace(f, ops_to_check=[DimShuffle]) + f(data) + + +class TestUnbroadcast: + def setup_method(self): + self.mode = get_default_mode().including("canonicalize") + + def test_local_useless_unbroadcast(self): + x1 = tensor("float64", shape=(1, 2)) + x2 = tensor("float64", shape=(2, 1)) + unbroadcast_op = Unbroadcast(0) + + f = function([x1], unbroadcast_op(x1), mode=self.mode) + assert ( + sum(isinstance(node.op, Unbroadcast) for node in f.maker.fgraph.toposort()) + == 1 + ) + + f = function([x2], unbroadcast_op(x2), mode=self.mode) + assert ( + sum(isinstance(node.op, Unbroadcast) for node in f.maker.fgraph.toposort()) + == 0 + ) + + def test_local_unbroadcast_lift(self): + x = tensor("float64", shape=(1, 1)) + y = unbroadcast(at.exp(unbroadcast(x, 0)), 1) + + assert ( + sum( + isinstance(node.op, Unbroadcast) + for node in FunctionGraph([x], [y], copy_inputs=False).toposort() + ) + == 2 + ) + + f = function([x], y, mode=self.mode) + assert ( + sum(isinstance(node.op, Unbroadcast) for node in f.maker.fgraph.toposort()) + == 1 + ) + + np.testing.assert_almost_equal(f([[1]]), np.exp([[1]])) + + +class TestUselessElemwise: + def setup_method(self): + self.mode = get_default_mode().including("canonicalize", "local_fill_to_alloc") + + def test_eq(self): + x = dmatrix() + y = dmatrix() + f = function([x, y], eq(x, y), mode=self.mode) + vx = np.random.random((5, 4)) + vy = np.random.random((5, 4)) + f(vx, vy) + topo = f.maker.fgraph.toposort() + assert len(topo) == 1 + assert isinstance(topo[0].op, Elemwise) + assert isinstance(topo[0].op.scalar_op, aes.EQ) + f2 = function([x], eq(x, x), mode=self.mode) + assert np.all(f2(vx) == np.ones((5, 4))) + topo2 = f2.maker.fgraph.toposort() + # Shape_i{1}(), + # Shape_i{0}(), Alloc([[1]], Shape_i{0}.0, + # Shape_i{1}.0 + assert len(topo2) == 3 + assert isinstance(topo2[-1].op, Alloc) + + def test_neq(self): + x = dmatrix() + y = dmatrix() + f = function([x, y], neq(x, y), mode=self.mode) + vx = np.random.random((5, 4)) + vy = np.random.random((5, 4)) + f(vx, vy) + topo = f.maker.fgraph.toposort() + assert len(topo) == 1 + assert isinstance(topo[0].op, Elemwise) + assert isinstance(topo[0].op.scalar_op, aes.NEQ) + f2 = function([x], neq(x, x), mode=self.mode) + assert np.all(f2(vx) == np.zeros((5, 4))) + topo2 = f2.maker.fgraph.toposort() + assert len(topo2) == 3 + assert isinstance(topo2[-1].op, Alloc) + + def test_mul(self): + x = dmatrix() + y = dmatrix() + f = function([x], mul(x), mode=self.mode) + vx = np.random.random((5, 4)) + vy = np.random.random((5, 4)) + f(vx) + topo = f.maker.fgraph.toposort() + assert len(topo) == 1 + assert topo[0].op == deep_copy_op + f2 = function([x, y], mul(x, y), mode=self.mode) + assert np.all(f2(vx, vy) == vx * vy) + topo2 = f2.maker.fgraph.toposort() + assert len(topo2) == 1 + assert isinstance(topo2[0].op, Elemwise) + assert isinstance(topo2[0].op.scalar_op, aes.Mul) + + def test_add(self): + x = dmatrix() + y = dmatrix() + f = function([x], add(x), mode=self.mode) + vx = np.random.random((5, 4)) + vy = np.random.random((5, 4)) + f(vx) + topo = f.maker.fgraph.toposort() + assert len(topo) == 1 + assert topo[0].op == deep_copy_op + f2 = function([x, y], add(x, y), mode=self.mode) + assert np.all(f2(vx, vy) == vx + vy) + topo2 = f2.maker.fgraph.toposort() + assert len(topo2) == 1 + assert isinstance(topo2[0].op, Elemwise) + assert isinstance(topo2[0].op.scalar_op, aes.Add) + + def test_identity(self): + # aes.identity is used in 2 Elemwise functions: + # tensor_copy, and view + x = matrix() + f = function([x], at.tensor_copy(x), mode=self.mode) + vx = np.random.random((5, 4)).astype(config.floatX) + f(vx) + topo = f.maker.fgraph.toposort() + assert len(topo) == 1 + assert topo[0].op == deep_copy_op + + +class TestCastCast: + def setup_method(self): + mode = get_default_mode() + self.mode = mode.including("local_cast_cast") + + def test_consecutive(self): + x = fmatrix() + o = Elemwise(aes.Cast(aes.ScalarType("float64")))(x.astype("float64")) + f = function([x], o, mode=self.mode) + dx = np.random.random((5, 4)).astype("float32") + f(dx) + topo = f.maker.fgraph.toposort() + assert len(topo) == 1 + assert isinstance(topo[0].op.scalar_op, aes.basic.Cast) + + x = dmatrix() + o = Elemwise(aes.Cast(aes.ScalarType("float32")))(x.astype("float32")) + f = function([x], o, mode=self.mode) + dx = np.random.random((5, 4)) + f(dx) + topo = f.maker.fgraph.toposort() + assert len(topo) == 1 + assert isinstance(topo[0].op.scalar_op, aes.basic.Cast) + + def test_upcast(self): + # Upcast followed by any other cast + x = fmatrix() + o = Elemwise(aes.Cast(aes.ScalarType("complex128")))(x.astype("complex64")) + f = function([x], o, mode=self.mode) + dx = np.random.random((5, 4)).astype("float32") + f(dx) + topo = f.maker.fgraph.toposort() + assert len(topo) == 1 + assert isinstance(topo[0].op.scalar_op, aes.basic.Cast) + + # Upcast followed by a downcast back to the base type + x = fmatrix() + o = Elemwise(aes.Cast(aes.ScalarType("float32")))(x.astype("float64")) + f = function([x], o, mode=self.mode) + dx = np.random.random((5, 4)).astype("float32") + f(dx) + topo = f.maker.fgraph.toposort() + assert len(topo) == 1 + assert isinstance(topo[0].op, DeepCopyOp) + + # Downcast followed by an upcast back to the base type + # The rewrite shouldn't be applied + x = dmatrix() + o = Elemwise(aes.Cast(aes.ScalarType("float64")))(x.astype("float32")) + f = function([x], o, mode=self.mode) + dx = np.random.random((5, 4)) + f(dx) + topo = f.maker.fgraph.toposort() + assert ( + len(topo) == 1 and isinstance(topo[0].op.scalar_op, aes.basic.Composite) + ) or (len(topo) > 1) + + +def test_constant_folding(): + # Test that constant folding get registered at fast_compile + # An error removed that registration during the registration. + x = dvector() + mode = get_mode("FAST_COMPILE").excluding("fusion") + f = function([x], [x * 2, x + x], mode=mode) + topo = f.maker.fgraph.toposort() + assert len(topo) == 2 + + # Test that we do not crash when constant folding elemwise scalar + # as they should not generate c code. + + x = at.constant(3) + assert x.ndim == 0 + mode = get_mode("FAST_COMPILE").excluding("fusion") + f = function([], [x * 2, x + x], mode=mode) + topo = f.maker.fgraph.toposort() + assert len(topo) == 2 + assert all(isinstance(n.op, DeepCopyOp) for n in topo) + + +@pytest.mark.xfail( + reason="Aesara rewrites constants before stabilization. " + "This breaks stabilization rewrites in some cases. See #504.", + raises=AssertionError, +) +def test_constant_get_stabilized(): + # Currently Aesara enables the `constant_folding` rewrite before stabilization rewrites. + # This caused some stabilization rewrites to not be activated and that + # caused inf values to appear when they should not. + + # We can't simply move the `constant_folding` rewrite to + # specialize since this will break other rewrites. We will need to + # partially duplicate some canonicalize rewrites to fix this issue. + + x2 = scalar() + y2 = log(1 + exp(x2)) + mode = get_default_mode() + mode.check_isfinite = False + f2 = function([x2], y2, mode=mode) + + assert len(f2.maker.fgraph.toposort()) == 1 + assert f2.maker.fgraph.toposort()[0].op == softplus + assert f2(800) == 800 + + x = at.as_tensor_variable(800) + y = log(1 + exp(x)) + f = function([], y, mode=mode) + # When this error is fixed, the following line should be ok. + assert f() == 800, f() + + +class TestLocalSwitchSink: + def setup_method(self): + # condition values + self.condm = np.asarray([[0.1, 0, 1, -1], [0.0, 0.0, 0.0, 0.0], [1, 1, 1, 1]]) + self.condv = np.asarray([0.1, 0, 1, -1]) + self.conds = [0.1, 0, 1, -1] + + # x values + self.xm = np.ones((3, 4)) + self.xv = np.ones((4,)) + self.xs = 1.0 + + # expected results + self.resm = ( + [np.asarray([[1, 0, 1, 0], [0, 0, 0, 0], [1, 1, 1, 1]])] * 3 + + [np.asarray([[1, 0, 1, 0], [1, 0, 1, 0], [1, 0, 1, 0]])] + + 2 * [np.asarray([[1, 0, 1, 0]])] + + [[np.ones((3, 4)), np.zeros((3, 4)), np.ones((3, 4)), np.zeros((3, 4))]] + + [[np.ones((4,)), np.zeros((4,)), np.ones((4,)), np.zeros((4,))]] + + [[np.asarray(1.0), np.asarray(0.0), np.asarray(1.0), np.asarray(0.0)]] + ) + + self.mode = ( + get_default_mode() + .including("canonicalize", "fast_run") + .excluding("gpu", "fusion") + ) + self.mode = copy.copy(self.mode) + self.mode.check_isfinite = False + + def function_remove_nan(self, *args, **kwargs): + """ + Wrapper around function for this test. + + It disables checking for NaN removed by rewrites in `DebugMode` + (it has false positives in that case). + """ + f = function(*args, **kwargs) + + def wrapped_f(*args, **kwargs): + # This is a bit ugly since it changes the global value of + # TensorType.values_eq_approx. + old_values_eq_approx = staticmethod(TensorType.values_eq_approx) + TensorType.values_eq_approx = staticmethod(values_eq_approx_remove_nan) + try: + out = f(*args, **kwargs) + finally: + TensorType.values_eq_approx = old_values_eq_approx + return out + + return wrapped_f + + def test_local_mul_switch_sink(self): + c = dscalar() + idx = 0 + for condition in [ + (dmatrix("cond"), self.condm), + (dvector("cond"), self.condv), + (dscalar("cond"), self.conds), + ]: + for x in [ + (dmatrix("x"), self.xm), + (dvector("x"), self.xv), + (dscalar("x"), self.xs), + ]: + y = mul( + at.switch(condition[0] > 0, 1.0 * x[0], 0.0 * x[0]), + at.switch(condition[0] > 0, 1.0 * x[0], log(c) * x[0]), + ) + f = self.function_remove_nan( + [condition[0], x[0], c], [y], mode=self.mode + ) + if type(condition[1]) is list: + for i in range(len(condition[1])): + res = f(condition[1][i], x[1], -1) + assert ( + res == np.asarray(self.resm[idx][i]) + ).sum() == self.resm[idx][i].size + else: + res = f(condition[1], x[1], -1) + assert (res == np.asarray(self.resm[idx])).sum() == self.resm[ + idx + ].size + idx += 1 + + # This case caused a missed rewrite in the past. + x = dscalar("x") + y = at.switch(x < 7, x, sqrt(x - 7)) + f = self.function_remove_nan([x], aesara.gradient.grad(y, x), self.mode) + assert f(5) == 1, f(5) + + @pytest.mark.slow + def test_local_div_switch_sink(self): + c = dscalar() + idx = 0 + for condition in [ + (dmatrix("cond"), self.condm), + (dvector("cond"), self.condv), + (dscalar("cond"), self.conds), + ]: + for x in [ + (dmatrix("x"), self.xm), + (dvector("x"), self.xv), + (dscalar("x"), self.xs), + ]: + y = true_div( + at.switch(condition[0] > 0, 1.0 * x[0], 0.0 * x[0]), + at.switch(condition[0] > 0, 1.0 * x[0], log(c) * x[0]), + ) + f = self.function_remove_nan( + [condition[0], x[0], c], [y], mode=self.mode + ) + if type(condition[1]) is list: + for i in range(len(condition[1])): + res = f(condition[1][i], x[1], -1) + assert ( + res == np.asarray(self.resm[idx][i]) + ).sum() == self.resm[idx][i].size + else: + res = f(condition[1], x[1], -1) + assert (res == np.asarray(self.resm[idx])).sum() == self.resm[ + idx + ].size + idx += 1 + + +class TestLocalUselessSwitch: + def setup_method(self): + self.mode = rewrite_mode.excluding("constant_folding") + + @pytest.mark.parametrize( + "dtype1", + ["int32", "int64"], + ) + @pytest.mark.parametrize( + "dtype2", + ["int32", "int64"], + ) + @pytest.mark.parametrize( + "cond", + [0, 1, np.array([True])], + ) + def test_const(self, dtype1, dtype2, cond): + x = matrix("x", dtype=dtype1) + y = matrix("y", dtype=dtype2) + z = at.switch(cond, x, y) + f = function([x, y], z, mode=self.mode) + assert not any( + node.op + for node in f.maker.fgraph.toposort() + if ( + isinstance(node.op, Elemwise) + and isinstance(node.op.scalar_op, aes.basic.Switch) + ) + ) + vx = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype1) + vy = np.array([[7, 8, 9], [10, 11, 12]], dtype=dtype2) + np_res = np.where(cond, vx, vy) + assert np.array_equal(f(vx, vy), np_res) + + @pytest.mark.parametrize( + "dtype1", + ["int32", "int64"], + ) + def test_left_is_right(self, dtype1): + x = matrix("x", dtype=dtype1) + varc = matrix("varc", dtype=dtype1) + z1 = at.switch(1, x, x) + z0 = at.switch(0, x, x) + z2 = at.switch(varc, x, x) + f1 = function([x], z1, mode=self.mode) + f0 = function([x], z0, mode=self.mode) + f2 = function([x, varc], z2, mode=self.mode) + + topo = f1.maker.fgraph.toposort() + assert len(topo) == 1 + assert topo[0].op == deep_copy_op + + topo = f0.maker.fgraph.toposort() + assert len(topo) == 1 + assert topo[0].op == deep_copy_op + + topo = f2.maker.fgraph.toposort() + assert len(topo) == 1 + assert topo[0].op == deep_copy_op + + vx = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype1) + vc = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype1) + assert np.array_equal(f1(vx), vx) + assert np.array_equal(f0(vx), vx) + assert np.array_equal(f2(vx, vc), vx) + + @pytest.mark.parametrize( + "dtype1", + ["float32", "float64"], + ) + def test_shape_le_0(self, dtype1): + x = matrix("x", dtype=dtype1) + z0 = at.switch(le(x.shape[0], 0), 0, x.shape[0]) + f0 = function([x], z0, mode=self.mode) + assert isinstance(f0.maker.fgraph.toposort()[0].op, Shape_i) + + z1 = at.switch(le(x.shape[1], 0), 0, x.shape[1]) + f1 = function([x], z1, mode=self.mode) + assert isinstance(f1.maker.fgraph.toposort()[0].op, Shape_i) + + vx = np.random.standard_normal((0, 5)).astype(dtype1) + assert f0(vx) == 0 + assert f1(vx) == 5 + + def test_broadcasting_1(self): + # test switch(cst, matrix, row) + x = matrix("x", dtype="int32") + y = vector("y", dtype="int64") + + z = at.switch(1, x, y) + f = function([x, y], z, mode=self.mode) + + start_var = f.maker.fgraph.outputs[0].owner.inputs[0] + assert isinstance(start_var.owner.op, Elemwise) + assert isinstance(start_var.owner.op.scalar_op, aes.basic.Cast) + assert not any(node.op == at.switch for node in f.maker.fgraph.toposort()) + + vx = np.array([[1, 2, 3], [4, 5, 6]], dtype="int32") + vy = np.array([10, 11, 12], dtype="int64") + np_res = np.where(1, vx, vy) + assert np.array_equal(f(vx, vy), np_res) + + z = at.switch(0, x, y) + f = function([x, y], z, mode=self.mode) + + assert isinstance(f.maker.fgraph.outputs[0].owner.op, Alloc) + assert f.maker.fgraph.inputs[1] == f.maker.fgraph.outputs[0].owner.inputs[0] + assert not any(node.op == at.switch for node in f.maker.fgraph.toposort()) + + vx = np.array([[1, 2, 3], [4, 5, 6]], dtype="int32") + vy = np.array([10, 11, 12], dtype="int64") + np_res = np.where(0, vx, vy) + assert np.array_equal(f(vx, vy), np_res) + + def test_broadcasting_2(self): + # test switch(cst, vector, matrix) + + x = vector("x", dtype="int32") + y = matrix("y", dtype="int64") + + z = at.switch(1, x, y) + f = function([x, y], z, mode=self.mode) + + assert isinstance(f.maker.fgraph.outputs[0].owner.op, Alloc) + assert not any(node.op == at.switch for node in f.maker.fgraph.toposort()) + + vx = np.array([4, 5, 6], dtype="int32") + vy = np.array([[7, 8, 9], [10, 11, 12]], dtype="int64") + np_res = np.where(1, vx, vy) + assert np.array_equal(f(vx, vy), np_res) + + z = at.switch(0, x, y) + f = function([x, y], z, mode=self.mode) + + assert isinstance(f.maker.fgraph.outputs[0].owner.op, DeepCopyOp) + assert not any(node.op == at.switch for node in f.maker.fgraph.toposort()) + + vx = np.array([4, 5, 6], dtype="int32") + vy = np.array([[7, 8, 9], [10, 11, 12]], dtype="int64") + np_res = np.where(0, vx, vy) + assert np.array_equal(f(vx, vy), np_res) + + def test_broadcasting_3(self): + # test switch(matrix, same_vector, same_vector) + + x = matrix("x", dtype="int32") + y = vector("y", dtype="int64") + z = at.switch(x, y, y) + f = function([x, y], z, mode=self.mode) + vx = np.array([[0, 1], [1, 0]], dtype="int32") + vy = np.array([7, 8], dtype="int64") + utt.assert_allclose(f(vx, vy), np.where(vx, vy, vy)) + + assert isinstance(f.maker.fgraph.outputs[0].owner.op, Alloc) + assert not any(node.op == at.switch for node in f.maker.fgraph.toposort()) + + +class TestLocalMergeSwitchSameCond: + @pytest.mark.parametrize( + "op", + [ + add, + sub, + mul, + true_div, + int_div, + floor_div, + minimum, + maximum, + gt, + lt, + ge, + le, + eq, + neq, + at_pow, + ], + ) + def test_elemwise_float_ops(self, op): + # float Ops + mats = matrices("cabxy") + c, a, b, x, y = mats + s1 = at.switch(c, a, b) + s2 = at.switch(c, x, y) + + g = rewrite(FunctionGraph(mats, [op(s1, s2)])) + assert str(g).count("Switch") == 1 + + @pytest.mark.parametrize( + "op", + [ + bitwise_and, + bitwise_or, + bitwise_xor, + ], + ) + def test_elemwise_int_ops(self, op): + # integer Ops + mats = imatrices("cabxy") + c, a, b, x, y = mats + s1 = at.switch(c, a, b) + s2 = at.switch(c, x, y) + g = rewrite(FunctionGraph(mats, [op(s1, s2)])) + assert str(g).count("Switch") == 1 + + @pytest.mark.parametrize("op", [add, mul]) + def test_elemwise_multi_inputs(self, op): + # add/mul with more than two inputs + mats = imatrices("cabxy") + c, a, b, x, y = mats + s1 = at.switch(c, a, b) + s2 = at.switch(c, x, y) + u, v = matrices("uv") + s3 = at.switch(c, u, v) + g = rewrite(FunctionGraph(mats + [u, v], [op(s1, s2, s3)])) + assert str(g).count("Switch") == 1 + + +class TestLocalOptAlloc: + """ + TODO FIXME: These tests are incomplete; they need to `assert` something. + """ + + dtype = "float32" + + def test_sum_upcast(self): + s = lscalar() + a = at.alloc(np.asarray(5, dtype=self.dtype), s, s) + with config.change_flags(warn_float64="raise"): + f = function([s], a.sum()) + f(5) + + def test_prod_upcast(self): + s = lscalar() + a = at.alloc(np.asarray(5, dtype=self.dtype), s, s) + + with config.change_flags(warn_float64="raise"): + f = function([s], a.prod()) + f(5) + + @config.change_flags(on_opt_error="raise") + def test_sum_bool_upcast(self): + s = lscalar() + a = at.alloc(np.asarray(True, dtype="bool"), s, s) + f = function([s], a.sum()) + f(5) + # test with user specified dtype + f = function([s], a.sum(dtype=self.dtype)) + f(5) + # test only 1 axis summed + f = function([s], a.sum(axis=0, dtype=self.dtype)) + f(5) + + +class TestLocalOptAllocF16(TestLocalOptAlloc): + dtype = "float16" + + +def test_local_join_1(): + # test for vector + a = vector("a") + s = at.stack([a]) + f = function([a], s, mode=rewrite_mode) + val = f([1]) + assert np.all(val == [1]) + e = f.maker.fgraph.toposort() + assert len([n for n in e if isinstance(n.op, Join)]) == 0 + assert f.maker.fgraph.outputs[0].dtype == config.floatX + + # test for matrix join(0,a) + a = matrix("a") + s = join(0, a) + f = function([a], s, mode=rewrite_mode) + val = f([[1]]) + assert np.all(val == [[1]]) + e = f.maker.fgraph.toposort() + assert len([n for n in e if isinstance(n.op, Join)]) == 0 + assert f.maker.fgraph.outputs[0].dtype == config.floatX + + # test for matrix join(1,a) + s = join(1, a) + f = function([a], s, mode=rewrite_mode) + val = f([[1]]) + assert np.all(val == [[1]]) + e = f.maker.fgraph.toposort() + assert len([n for n in e if isinstance(n.op, Join)]) == 0 + assert f.maker.fgraph.outputs[0].dtype == config.floatX + + # test we don't apply when their is 2 inputs + s = join(1, a, a) + f = function([a], s, mode=rewrite_mode) + val = f([[1]]) + assert np.all(val == [[1]]) + e = f.maker.fgraph.toposort() + assert len([n for n in e if isinstance(n.op, Join)]) == 1 + assert f.maker.fgraph.outputs[0].dtype == config.floatX + + +def test_local_join_empty(): + # test for vector, vector, empty to vector + empty_vec = np.asarray([], dtype=config.floatX) + a = vector("a") + s = at.join(0, a, a, empty_vec) + f = function([a], s, mode=rewrite_mode) + val = f([1]) + assert np.all(val == [1]) + e = f.maker.fgraph.toposort() + assert len([n for n in e if isinstance(n.op, Join)]) == 1 + assert all( + not isinstance(n.op, Join) or len(n.inputs) == 3 + for n in e + if isinstance(n.op, Join) + ) + assert f.maker.fgraph.outputs[0].dtype == config.floatX + + # test for matrix join(1,a) + empty_mat = np.asarray([[]], dtype=config.floatX) + m = matrix("m") + s = join(1, empty_mat, m, m, m) + f = function([m], s, mode=rewrite_mode) + val = f([[1]]) + assert np.all(val == [[1]]) + e = f.maker.fgraph.toposort() + assert len([n for n in e if isinstance(n.op, Join)]) == 1 + assert all( + not isinstance(n.op, Join) or len(n.inputs) == 4 + for n in e + if isinstance(n.op, Join) + ) + assert f.maker.fgraph.outputs[0].dtype == config.floatX + # test for vector, vector, empty to matrix + # We can't rewrite this case. + s = at.stack([a, a, empty_vec]) + f = function([a], s, mode=rewrite_mode) + val = f([]) + assert np.all(val == [1]) + e = f.maker.fgraph.toposort() + assert len([n for n in e if isinstance(n.op, Join)]) == 1 + assert all( + not isinstance(n.op, Join) or len(n.inputs) == 4 + for n in e + if isinstance(n.op, Join) + ) + assert f.maker.fgraph.outputs[0].dtype == config.floatX + # test for matrix join(0,a) + # We can't rewrite this case. + s = join(0, m, np.asarray([[2.0]], dtype=config.floatX), m) + f = function([m], s, mode=rewrite_mode) + val = f([[1]]) + assert np.all(val == [[1], [2], [1]]) + e = f.maker.fgraph.toposort() + assert len([n for n in e if isinstance(n.op, Join)]) == 1 + assert all( + not isinstance(n.op, Join) or len(n.inputs) == 4 + for n in e + if isinstance(n.op, Join) + ) + assert f.maker.fgraph.outputs[0].dtype == config.floatX + + +def test_local_join_make_vector(): + a, b, c, d, e = scalars("abcde") + v = vector("v") + mv = MakeVector(config.floatX) + s = at.join(0, mv(a), v, mv(b, c), mv(d, e)) + f = function([a, b, c, d, e, v], s, mode=rewrite_mode) + val = f(1, 2, 3, 4, 6, [7, 8]) + assert np.all(val == [1, 7, 8, 2, 3, 4, 6]) + e = f.maker.fgraph.toposort() + assert len([n for n in e if isinstance(n.op, Join)]) == 1 + assert all( + not isinstance(n.op, Join) or len(n.inputs) == 4 + for n in e + if isinstance(n.op, Join) + ) + assert f.maker.fgraph.outputs[0].dtype == config.floatX + + assert check_stack_trace(f, ops_to_check="all") + + +@pytest.mark.parametrize( + "dtype", + [ + "int8", + "int16", + "int32", + "int64", + "uint8", + "uint16", + "uint32", + "uint64", + "float32", + "float64", + "complex64", + "complex128", + ], +) +def test_local_tensor_scalar_tensor(dtype): + t_type = TensorType(dtype=dtype, shape=()) + t = t_type() + s = at.scalar_from_tensor(t) + t2 = at.tensor_from_scalar(s) + + f = function([t], t2, mode=rewrite_mode) + e = f.maker.fgraph.toposort() + assert not any( + n for n in e if isinstance(n.op, (TensorFromScalar, ScalarFromTensor)) + ) + + +@pytest.mark.parametrize( + "dtype", + [ + "int8", + "int16", + "int32", + "int64", + "uint8", + "uint16", + "uint32", + "uint64", + "float32", + "float64", + "complex64", + "complex128", + ], +) +def test_local_scalar_tensor_scalar(dtype): + s_type = aes.ScalarType(dtype=dtype) + s = s_type() + t = at.tensor_from_scalar(s) + s2 = at.scalar_from_tensor(t) + + f = function([s], s2, mode=rewrite_mode) + e = f.maker.fgraph.toposort() + assert not any( + n for n in e if isinstance(n.op, (TensorFromScalar, ScalarFromTensor)) + ) + + +def test_local_useless_split(): + x = matrix("x") + splits = ivector("splits") + rewritten = at.split(x, splits, n_splits=1) + not_rewritten = at.split(x, splits, n_splits=3) + + mode = get_default_mode().including("local_useless_split") + f_rewritten = function([x, splits], rewritten, mode=mode) + f_not_rewritten = function([x, splits], not_rewritten, mode=mode) + + f_rewritten(np.random.random((4, 4)).astype(config.floatX), [4]) + f_not_rewritten(np.random.random((4, 4)).astype(config.floatX), [1, 2, 1]) + graph_rewritten = f_rewritten.maker.fgraph.toposort() + graph_not_rewritten = f_not_rewritten.maker.fgraph.toposort() + + assert isinstance(graph_rewritten[-1].op, DeepCopyOp) + assert len(graph_not_rewritten) == 1 + assert isinstance(graph_not_rewritten[0].op, Split) + + assert check_stack_trace(f_rewritten, ops_to_check=[Assert]) + assert check_stack_trace(f_not_rewritten, ops_to_check="all") + + +@pytest.mark.parametrize("i", list(range(1, 4))) +def test_local_flatten_lift(i): + x = tensor4() + out = at.flatten(exp(x), i) + assert out.ndim == i + mode = get_default_mode() + mode = mode.including("local_reshape_lift") + f = function([x], out, mode=mode) + x_np = np.random.random((5, 4, 3, 2)).astype(config.floatX) + out_np = f(x_np) + topo = f.maker.fgraph.toposort() + shape_out_np = tuple(x_np.shape[: i - 1]) + (np.prod(x_np.shape[i - 1 :]),) + assert shape_out_np == out_np.shape + + reshape_nodes = [n for n in topo if isinstance(n.op, Reshape)] + assert len(reshape_nodes) == 1 and at.is_flat(reshape_nodes[0].outputs[0], ndim=i) + assert isinstance(topo[-1].op, Elemwise) + + +class TestLiftTransposeThroughDot: + def simple_rewrite(self, g): + out2in(local_useless_elemwise).rewrite(g) + out2in(local_lift_transpose_through_dot).rewrite(g) + out2in(local_useless_elemwise).rewrite(g) + return g + + def test_matrix_matrix(self): + a, b = matrices("ab") + g = self.simple_rewrite(FunctionGraph([a, b], [dot(a, b).T])) + sg = "FunctionGraph(dot(InplaceDimShuffle{1,0}(b), InplaceDimShuffle{1,0}(a)))" + assert str(g) == sg, (str(g), sg) + assert check_stack_trace(g, ops_to_check="all") + + def test_row_matrix(self): + a = vector("a") + b = matrix("b") + g = rewrite( + FunctionGraph([a, b], [dot(a.dimshuffle("x", 0), b).T]), + level="stabilize", + ) + sg = "FunctionGraph(dot(InplaceDimShuffle{1,0}(b), InplaceDimShuffle{0,x}(a)))" + assert str(g) == sg, (str(g), sg) + assert check_stack_trace(g, ops_to_check="all") + + def test_matrix_col(self): + a = vector("a") + b = matrix("b") + g = rewrite( + FunctionGraph([a, b], [dot(b, a.dimshuffle(0, "x")).T]), + level="stabilize", + ) + sg = "FunctionGraph(dot(InplaceDimShuffle{x,0}(a), InplaceDimShuffle{1,0}(b)))" + assert str(g) == sg, (str(g), sg) + assert check_stack_trace(g, ops_to_check="all") + + +def test_local_upcast_elemwise_constant_inputs(): + s = dvector("s") + x = at_sum(log(10**s)) + f = function([s], [aesara.gradient.grad(x, s)]) + f([-42, -2.1, -1, -0.5, 0, 0.2, 1, 2, 12]) + + # This tests a corner case for which the rewrite should not be applied. + with config.change_flags(floatX="float32"): + v = lvector() + function([v], true_div(v, 2)) + + +def test_assert_op_gradient(): + x = vector("x") + assert_op = Assert() + cost = at_sum(assert_op(x, x.size < 2)) + grad = aesara.gradient.grad(cost, x) + func = function([x], grad) + + x_val = np.ones(shape=(1,), dtype=config.floatX) + assert func(x_val) == 1 + + +def test_local_merge_alloc(): + # Add this rewrite to the default mode; otherwise, FAST_COMPILE fails. + default_mode = get_default_mode() + rewrite_mode = default_mode.including("local_merge_alloc") + + x = iscalar("x") + y = iscalar("y") + y2 = iscalar("y2") + z = iscalar("z") + w = iscalar("w") + m = fscalar("m") + # case 1 + # Alloc(Alloc(m, x, 1, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w) + output = at.alloc(at.alloc(m, 1, y, 1, 1), x, y, z, w) + f = function([m, x, y, z, w], output, mode=rewrite_mode) + topo = f.maker.fgraph.toposort() + assert len(topo) == 1 + assert isinstance(topo[0].op, Alloc) + o = f(0.0, 1, 2, 3, 4) + assert o.shape == (1, 2, 3, 4) + + # case 2 + # Alloc(Alloc(m, y, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w) + output = at.alloc(at.alloc(m, y, 1, 1), x, y, z, w) + f = function([m, x, y, z, w], output, mode=rewrite_mode) + topo = f.maker.fgraph.toposort() + assert len(topo) == 1 + assert isinstance(topo[0].op, Alloc) + o = f(0.0, 1, 2, 3, 4) + assert o.shape == (1, 2, 3, 4) + + # case 3 + # Alloc(Alloc(m, y1, 1, 1), x, y2, z, w) -> + # Alloc(m, x, assert(y1, y1==y2), z, w) + output = at.alloc(at.alloc(m, y, 1, 1), x, y2, z, w) + f = function([m, x, y, y2, z, w], output, mode=rewrite_mode) + topo = f.maker.fgraph.toposort() + assert len(topo) == 3 + assert isinstance(topo[-2].op, Assert) + assert isinstance(topo[-1].op, Alloc) + o = f(0.0, 1, 2, 2, 3, 4) + assert o.shape == (1, 2, 3, 4) + with pytest.raises((AssertionError, ValueError)): + f(0.0, 1, 2, 5, 3, 4) + + +def test_local_useless_alloc(): + + useless_alloc = out2in(local_useless_alloc) + merge_alloc = out2in(local_merge_alloc) + + x = iscalar("x") + y = iscalar("y") + y2 = iscalar("y2") + z = iscalar("z") + w = iscalar("w") + m = fscalar("m") + + # case 1 + # Alloc(Alloc(m, x, 1, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w) + output = at.alloc(at.alloc(m, 1, y, 1, 1), x, y, z, w) + g = FunctionGraph([m, x, y, z, w], [output]) + + useless_alloc.rewrite(g) + merge_alloc.rewrite(g) + useless_alloc.rewrite(g) + + topo = g.toposort() + assert len(topo) == 1 + assert isinstance(topo[0].op, Alloc) + + # case 2 + # Alloc(Alloc(m, y, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w) + output = at.alloc(at.alloc(m, y, 1, 1), x, y, z, w) + g = FunctionGraph([m, x, y, z, w], [output]) + + useless_alloc.rewrite(g) + merge_alloc.rewrite(g) + useless_alloc.rewrite(g) + + topo = g.toposort() + assert len(topo) == 1 + assert isinstance(topo[0].op, Alloc) + + # case 3 + # Alloc(Alloc(m, y1, 1, 1), x, y2, z, w) -> + # Alloc(m, x, assert(y1, y1==y2), z, w) + output = at.alloc(at.alloc(m, y, 1, 1), x, y2, z, w) + g = FunctionGraph([m, x, y, y2, z, w], [output]) + + useless_alloc.rewrite(g) + merge_alloc.rewrite(g) + useless_alloc.rewrite(g) + + topo = g.toposort() + assert len(topo) == 3 + assert isinstance(topo[-2].op, Assert) + assert isinstance(topo[-1].op, Alloc) + + +def test_local_merge_consecutive_specify_shape(): + x = matrix() + s = at.as_tensor([iscalar(), iscalar()]) + y = specify_shape(specify_shape(x, s), s) + + y_fg = FunctionGraph(outputs=[y], copy_inputs=False) + y_rewritten_fg = rewrite_graph( + y_fg, + clone=False, + include=["canonicalize", "local_merge_consecutive_specify_shape"], + ) + y_rewritten = y_rewritten_fg.outputs[0] + + assert isinstance(y_rewritten.owner.op, SpecifyShape) + assert y_rewritten.owner.inputs[0] == x + + +def test_local_merge_consecutive_specify_shape2(): + x = tensor3() + s1, s2, s3, s4 = iscalars("s1", "s2", "s3", "s4") + y = specify_shape(specify_shape(x, [s1, s2, None]), [None, s3, s4]) + + y_fg = FunctionGraph(outputs=[y], copy_inputs=False) + y_rewritten_fg = rewrite_graph( + y_fg, + clone=False, + include=["canonicalize", "local_merge_consecutive_specify_shape"], + ) + y_rewritten = y_rewritten_fg.outputs[0] + + assert isinstance(y_rewritten.owner.op, SpecifyShape) + assert tuple(y_rewritten.owner.inputs) == (x, s1, s3, s4) + + +def test_printing(): + a, b = scalars("ab") + mv = MakeVector(config.floatX) + v = mv(a, b) + assert pprint(v) == "[a, b]" + + +class TestLocalElemwiseAlloc: + """ + + TODO FIXME: Remove redundant tests. + + """ + + dtype = config.floatX + + def setup_method(self): + self.fast_compile_mode = get_mode("FAST_COMPILE") + self.fast_run_mode = get_mode("FAST_RUN") + + self.vec = vector("vec", dtype=self.dtype) + self.mat = matrix("mat", dtype=self.dtype) + self.tens = tensor3("tens", dtype=self.dtype) + + self.alloc_wo_dep = at.alloc(self.vec, 2, 2) + self.alloc_wo_dep_broad = at.alloc(self.vec, 1, 2) + self.alloc_w_dep = at.alloc(self.vec, *self.mat.shape) + self.alloc_w_dep_broad = at.alloc(self.vec, 1, *self.mat.shape) + self.alloc_w_dep_broad2 = at.alloc( + self.vec, self.mat.shape[0], self.mat.shape[1], 1 + ) + self.alloc_w_dep_tens = at.alloc( + self.vec, self.tens.shape[0], self.tens.shape[1] + ) + self.tv_wo_dep = at.alloc(self.vec, 5, 5) + self.tm_wo_dep = at.alloc(self.mat, 5, 5, 5) + self.s = iscalar("s") + self.tv_w_dep = at.alloc(self.vec, self.s, self.s) + self.tm_w_dep = at.alloc(self.mat, 5, 5, 5) + self.row = row(dtype=self.dtype) + self.o = at.alloc(self.row, 5, 5) + + @staticmethod + def verify_op_count(f, count, cls): + assert ( + sum( + isinstance(elem.op, cls) + for elem in f.maker.fgraph.toposort() + if elem.op is not None + ) + == count + ) + + @pytest.mark.parametrize( + "expr, x_shape, y_shape", + [ + (lambda x, y: at.mul(at.alloc(1, *y.shape), x), (1, 2), (3, 2)), + (lambda x, y: at.mul(at.alloc(1, *y.shape), x), (1, 1), (1, 1)), + (lambda x, y: at.mul(x, at.alloc(y, 2, 3)), (1, 3), (2, 3)), + ( + lambda x, y: at.mul( + at.alloc(x, 3).dimshuffle("x", 0), y.dimshuffle("x", "x") + ), + (), + (), + ), + pytest.param( + lambda x, y: at.mul(y, at.alloc(1, x)), + (), + (), + marks=pytest.mark.xfail(reason="Not implemented"), + ), + (lambda x, y: at.mul(at.alloc(x, 15, 1), y), (15, 1), (15, 1)), + (lambda x, y: at.mul(at.alloc(x, 15, 2), y), (15, 2), (15, 2)), + ( + lambda x, y: at.mul(at.alloc(x, 15, 1), at.alloc(y, 15, 1)), + (15, 1), + (15, 1), + ), + ( + lambda x, y: at.mul(at.alloc(x, 15, 2), at.alloc(y, 15, 2)), + (15, 2), + (15, 2), + ), + ( + lambda x, y: at.mul(at.alloc(x, 15, 2).dimshuffle(1, 0), y), + (15, 2), + (2, 15), + ), + (lambda x, y: at.mul(at.alloc(x, 1, 15, 2), y), (15, 2), (15, 2)), + ( + lambda x, y: at.mul(at.alloc(x, 1, 15, 2).dimshuffle(0, 2, 1), y), + (15, 2), + (2, 15), + ), + ], + ) + def test_basic(self, expr, x_shape, y_shape): + x = at.tensor("int64", (False,) * len(x_shape), name="x") + y = at.tensor("int64", (False,) * len(y_shape), name="y") + z = expr(x, y) + + z_opt = aesara.function( + [x, y], + z, + mode=get_default_mode().including("local_elemwise_alloc"), + on_unused_input="ignore", + ) + + assert not any( + isinstance(node.op, Alloc) for node in z_opt.maker.fgraph.toposort() + ) + + z_no_opt = aesara.function( + [x, y], + z, + mode=get_default_mode().excluding("local_elemwise_alloc"), + on_unused_input="ignore", + ) + + x_val = np.arange(np.prod(x_shape), dtype=np.int64).reshape(x_shape) + y_val = np.arange(np.prod(y_shape), dtype=np.int64).reshape(y_shape) + + res = z_opt(x_val, y_val) + exp_res = z_no_opt(x_val, y_val) + assert np.array_equal(res, exp_res) + + def test_single_input(self): + """Test that rewrite is not triggered when there is only one `Alloc` in an `Elemwise`.""" + x = at.matrix("x") + z = at.exp(at.alloc(x, 15, 1)) + + z_fg = FunctionGraph(outputs=[z], copy_inputs=False, features=[ShapeFeature()]) + + z_opt_fg = rewrite_graph(z_fg, clone=False, include=["local_elemwise_alloc"]) + assert any(isinstance(node.op, Alloc) for node in z_opt_fg.apply_nodes) + + def test_remove_alloc_wo_dimshuffle(self): + # Exclude `local_useless_alloc`, since it does not introduce + # `Assert` in all the same cases. + self.fast_run_mode = self.fast_run_mode.excluding( + "local_useless_alloc", "local_alloc_sink_dimshuffle" + ) + func = function( + [self.vec, self.mat], + self.alloc_wo_dep + self.mat, + mode=self.fast_compile_mode, + ) + self.verify_op_count(func, 1, Alloc) + self.verify_op_count(func, 0, Assert) + assert check_stack_trace(func, ops_to_check="all") + + func = function( + [self.vec, self.mat], self.alloc_wo_dep + self.mat, mode=self.fast_run_mode + ) + self.verify_op_count(func, 0, Alloc) + self.verify_op_count(func, 2, Assert) + + func = function( + [self.vec, self.mat], + self.alloc_wo_dep_broad + self.mat, + mode=self.fast_run_mode, + ) + self.verify_op_count(func, 0, Alloc) + self.verify_op_count(func, 1, Assert) + + # No optimization on alloc without assert + func = function( + [self.vec, self.mat], + self.alloc_w_dep + self.mat, + mode=self.fast_compile_mode, + ) + self.verify_op_count(func, 1, Alloc) + self.verify_op_count(func, 0, Assert) + + func = function( + [self.vec, self.mat], self.alloc_w_dep + self.mat, mode=self.fast_run_mode + ) + self.verify_op_count(func, 0, Alloc) + self.verify_op_count(func, 0, Assert) + + func = function( + [self.vec, self.mat], + self.alloc_w_dep_broad + self.mat, + mode=self.fast_run_mode, + ) + self.verify_op_count(func, 0, Alloc) + self.verify_op_count(func, 0, Assert) + + # This was previously not rewritten, but it is now that we + # have `BroadcastTo`. + func = function( + [self.vec, self.mat], + self.alloc_w_dep_broad2 + self.mat, + mode=self.fast_run_mode, + ) + self.verify_op_count(func, 0, Alloc) + self.verify_op_count(func, 1, Assert) + + def test_remove_alloc_w_dimshuffle(self): + func = function( + [self.vec, self.tens], + self.alloc_wo_dep.dimshuffle(0, 1, "x") + self.tens, + mode=self.fast_compile_mode, + ) + self.verify_op_count(func, 1, Alloc) + self.verify_op_count(func, 0, Assert) + + # TODO FIXME: The `BroadcastTo` shapes should use the constants + # provided by the first/`Alloc` term, and not the unknown values from + # the `tens` term. + func = function( + [self.vec, self.tens], + self.alloc_wo_dep.dimshuffle(0, 1, "x") + self.tens, + mode=self.fast_run_mode, + ) + self.verify_op_count(func, 0, Alloc) + self.verify_op_count(func, 2, Assert) + + func = function( + [self.vec, self.tens], + self.alloc_w_dep_tens.dimshuffle(0, 1, "x") + self.tens, + mode=self.fast_compile_mode, + ) + self.verify_op_count(func, 1, Alloc) + self.verify_op_count(func, 0, Assert) + + func = function( + [self.vec, self.tens], + self.alloc_w_dep_tens.dimshuffle(0, 1, "x") + self.tens, + mode=self.fast_run_mode, + ) + self.verify_op_count(func, 0, Alloc) + self.verify_op_count(func, 0, Assert) + + def test_multi_input_single_alloc(self): + # No optimization on dimshuffle with assert + func = function( + [self.vec, self.mat], + self.tv_wo_dep + self.tm_wo_dep, + mode=self.fast_compile_mode, + ) + self.verify_op_count(func, 2, Alloc) + self.verify_op_count(func, 0, Assert) + + # Optimization on dimshuffle with assert + # TODO: When we support static shape constraints like `shape[i] != 1`, + # reproduce this with such a constraint on `mat` and make sure the + # `BroadcastTo` is removed. + func = function( + [self.vec, self.mat], + self.tv_wo_dep + self.tm_wo_dep, + mode=self.fast_run_mode, + ) + self.verify_op_count(func, 0, Alloc) + self.verify_op_count(func, 0, Assert) + + # No optimization on dimshuffle without assert + func = function( + [self.vec, self.mat, self.s], + self.tv_w_dep + self.tm_w_dep, + mode=self.fast_compile_mode, + ) + self.verify_op_count(func, 2, Alloc) + self.verify_op_count(func, 0, Assert) + + # Optimization on dimshuffle without assert + func = function( + [self.vec, self.mat, self.s], + self.tv_w_dep + self.tm_w_dep, + mode=self.fast_run_mode, + ) + self.verify_op_count(func, 0, Alloc) + self.verify_op_count(func, 1, Assert) + + def test_misc(self): + x = row(dtype=self.dtype) + y = tensor(dtype=self.dtype, shape=(False, False, True)) + + out = at.alloc(x, 5, 5).dimshuffle(0, 1, "x") + y + func = function([y, x], out, mode=self.fast_run_mode) + + self.verify_op_count(func, 0, Alloc) + self.verify_op_count(func, 2, Assert) + + y_val = np.random.random((5, 5, 1)).astype(self.dtype) + x_val = np.random.random((1, 5)).astype(self.dtype) + exp_res = np.broadcast_to(x_val, (5, 5))[..., None] + y_val + assert np.array_equal(func(y_val, x_val), exp_res) + + +def test_deprecations(): + """Make sure we can import from deprecated modules.""" + with pytest.deprecated_call(): + from aesara.tensor.basic_opt import register_useless # noqa: F401 F811 + + with pytest.deprecated_call(): + from aesara.tensor.rewriting.basic import ShapeFeature # noqa: F401 diff --git a/tests/tensor/rewriting/test_elemwise.py b/tests/tensor/rewriting/test_elemwise.py new file mode 100644 index 0000000000..cfb9b6a61d --- /dev/null +++ b/tests/tensor/rewriting/test_elemwise.py @@ -0,0 +1,1204 @@ +import contextlib + +import numpy as np +import pytest + +import aesara +import aesara.scalar as aes +import aesara.tensor as at +from aesara import shared +from aesara.compile.function import function +from aesara.compile.mode import Mode, get_default_mode +from aesara.configdefaults import config +from aesara.graph.basic import Constant +from aesara.graph.fg import FunctionGraph +from aesara.graph.rewriting.basic import check_stack_trace, out2in +from aesara.graph.rewriting.db import RewriteDatabaseQuery +from aesara.graph.rewriting.utils import rewrite_graph +from aesara.misc.safe_asarray import _asarray +from aesara.scalar.basic import Composite +from aesara.tensor.basic import MakeVector +from aesara.tensor.elemwise import DimShuffle, Elemwise +from aesara.tensor.math import ( + add, + bitwise_and, + bitwise_or, + cos, + cosh, + dot, + eq, + exp, + int_div, + invert, + iround, + log, + log2, + log10, + mul, + neg, + neq, +) +from aesara.tensor.math import pow as at_pow +from aesara.tensor.math import reciprocal +from aesara.tensor.math import round as at_round +from aesara.tensor.math import sin, sinh, sqr, sqrt +from aesara.tensor.math import sum as at_sum +from aesara.tensor.math import tan, tanh, true_div, xor +from aesara.tensor.rewriting.elemwise import local_dimshuffle_lift +from aesara.tensor.rewriting.shape import local_useless_dimshuffle_in_reshape +from aesara.tensor.shape import reshape +from aesara.tensor.type import ( + TensorType, + dmatrices, + dscalar, + dvector, + fscalar, + fvector, + matrix, + scalar, + tensor, + vector, + vectors, +) +from tests import unittest_tools as utt + + +dimshuffle_lift = out2in(local_dimshuffle_lift) + + +def ds(x, y): + return DimShuffle(x.type.broadcastable, y)(x) + + +def inputs(xbc=(0, 0), ybc=(0, 0), zbc=(0, 0)): + x = TensorType(shape=xbc, dtype="float64")("x") + y = TensorType(shape=ybc, dtype="float64")("y") + z = TensorType(shape=zbc, dtype="float64")("z") + return x, y, z + + +class TestDimshuffleLift: + def test_double_transpose(self): + x, y, z = inputs() + e = ds(ds(x, (1, 0)), (1, 0)) + g = FunctionGraph([x], [e]) + assert ( + str(g) == "FunctionGraph(InplaceDimShuffle{1,0}(InplaceDimShuffle{1,0}(x)))" + ) + dimshuffle_lift.rewrite(g) + assert str(g) == "FunctionGraph(x)" + # no need to check_stack_trace as graph is supposed to be empty + + def test_merge2(self): + x, y, z = inputs() + e = ds(ds(x, (1, "x", 0)), (2, 0, "x", 1)) + g = FunctionGraph([x], [e]) + assert ( + str(g) + == "FunctionGraph(InplaceDimShuffle{2,0,x,1}(InplaceDimShuffle{1,x,0}(x)))" + ), str(g) + dimshuffle_lift.rewrite(g) + assert str(g) == "FunctionGraph(InplaceDimShuffle{0,1,x,x}(x))", str(g) + # Check stacktrace was copied over correctly after rewrite was applied + assert check_stack_trace(g, ops_to_check="all") + + def test_elim3(self): + x, y, z = inputs() + e = ds(ds(ds(x, (0, "x", 1)), (2, 0, "x", 1)), (1, 0)) + g = FunctionGraph([x], [e]) + assert str(g) == ( + "FunctionGraph(InplaceDimShuffle{1,0}(InplaceDimShuffle{2,0,x,1}" + "(InplaceDimShuffle{0,x,1}(x))))" + ), str(g) + dimshuffle_lift.rewrite(g) + assert str(g) == "FunctionGraph(x)", str(g) + # no need to check_stack_trace as graph is supposed to be empty + + def test_lift(self): + x, y, z = inputs([False] * 1, [False] * 2, [False] * 3) + e = x + y + z + g = FunctionGraph([x, y, z], [e]) + + # It does not really matter if the DimShuffles are inplace + # or not. + init_str_g_inplace = ( + "FunctionGraph(Elemwise{add,no_inplace}(InplaceDimShuffle{x,0,1}" + "(Elemwise{add,no_inplace}(InplaceDimShuffle{x,0}(x), y)), z))" + ) + init_str_g_noinplace = ( + "FunctionGraph(Elemwise{add,no_inplace}(DimShuffle{x,0,1}" + "(Elemwise{add,no_inplace}(DimShuffle{x,0}(x), y)), z))" + ) + assert str(g) in (init_str_g_inplace, init_str_g_noinplace), str(g) + + rewrite_str_g_inplace = ( + "FunctionGraph(Elemwise{add,no_inplace}(Elemwise{add,no_inplace}" + "(InplaceDimShuffle{x,x,0}(x), InplaceDimShuffle{x,0,1}(y)), z))" + ) + rewrite_str_g_noinplace = ( + "FunctionGraph(Elemwise{add,no_inplace}(Elemwise{add,no_inplace}" + "(DimShuffle{x,x,0}(x), DimShuffle{x,0,1}(y)), z))" + ) + dimshuffle_lift.rewrite(g) + assert str(g) in (rewrite_str_g_inplace, rewrite_str_g_noinplace), str(g) + # Check stacktrace was copied over correctly after rewrite was applied + assert check_stack_trace(g, ops_to_check="all") + + def test_recursive_lift(self): + v = vector(dtype="float64") + m = matrix(dtype="float64") + out = ((v + 42) * (m + 84)).T + g = FunctionGraph([v, m], [out]) + init_str_g = ( + "FunctionGraph(InplaceDimShuffle{1,0}(Elemwise{mul,no_inplace}" + "(InplaceDimShuffle{x,0}(Elemwise{add,no_inplace}" + "(, " + "InplaceDimShuffle{x}(TensorConstant{42}))), " + "Elemwise{add,no_inplace}" + "(, " + "InplaceDimShuffle{x,x}(TensorConstant{84})))))" + ) + assert str(g) == init_str_g + new_out = local_dimshuffle_lift.transform(g, g.outputs[0].owner)[0] + new_g = FunctionGraph(g.inputs, [new_out]) + rewrite_str_g = ( + "FunctionGraph(Elemwise{mul,no_inplace}(Elemwise{add,no_inplace}" + "(InplaceDimShuffle{0,x}(), " + "InplaceDimShuffle{x,x}(TensorConstant{42})), " + "Elemwise{add,no_inplace}(InplaceDimShuffle{1,0}" + "(), " + "InplaceDimShuffle{x,x}(TensorConstant{84}))))" + ) + assert str(new_g) == rewrite_str_g + # Check stacktrace was copied over correctly after rewrite was applied + assert check_stack_trace(new_g, ops_to_check="all") + + def test_useless_dimshuffle(self): + x, _, _ = inputs() + e = ds(x, (0, 1)) + g = FunctionGraph([x], [e]) + assert str(g) == "FunctionGraph(InplaceDimShuffle{0,1}(x))" + dimshuffle_lift.rewrite(g) + assert str(g) == "FunctionGraph(x)" + # Check stacktrace was copied over correctly after rewrite was applied + assert hasattr(g.outputs[0].tag, "trace") + + def test_dimshuffle_on_broadcastable(self): + x, y, z = inputs([False, True], [True, False, True], [False, False, True]) + u = at.constant(1) + ds_x = ds(x, (0, "x")) # useless + ds_y = ds(y, (2, 1, 0)) # useless + ds_z = ds(z, (2, 1, 0)) # useful + ds_u = ds(u, ("x")) # useful + g = FunctionGraph([x, y, z, u], [ds_x, ds_y, ds_z, ds_u]) + assert ( + str(g) + == "FunctionGraph(InplaceDimShuffle{0,x}(x), InplaceDimShuffle{2,1,0}(y), InplaceDimShuffle{2,1,0}(z), InplaceDimShuffle{x}(TensorConstant{1}))" + ) + dimshuffle_lift.rewrite(g) + assert ( + str(g) + == "FunctionGraph(x, y, InplaceDimShuffle{2,1,0}(z), InplaceDimShuffle{x}(TensorConstant{1}))" + ) + # Check stacktrace was copied over correctly after rewrite was applied + assert hasattr(g.outputs[0].tag, "trace") + + +def test_local_useless_dimshuffle_in_reshape(): + vec = TensorType(shape=(False,), dtype="float64")("vector") + mat = TensorType(shape=(False, False), dtype="float64")("mat") + row = TensorType(shape=(True, False), dtype="float64")("row") + col = TensorType(shape=(False, True), dtype="float64")("col") + + reshape_dimshuffle_vector = reshape(vec.dimshuffle("x", 0), vec.shape) + reshape_dimshuffle_mat = reshape(mat.dimshuffle("x", 0, "x", 1), mat.shape) + reshape_dimshuffle_row = reshape(row.dimshuffle(1, "x"), row.shape) + reshape_dimshuffle_col = reshape(col.dimshuffle(0), col.shape) + + g = FunctionGraph( + [vec, mat, row, col], + [ + reshape_dimshuffle_vector, + reshape_dimshuffle_mat, + reshape_dimshuffle_row, + reshape_dimshuffle_col, + ], + ) + + assert str(g) == ( + "FunctionGraph(Reshape{1}(InplaceDimShuffle{x,0}(vector), Shape(vector)), " + "Reshape{2}(InplaceDimShuffle{x,0,x,1}(mat), Shape(mat)), " + "Reshape{2}(InplaceDimShuffle{1,x}(row), Shape(row)), " + "Reshape{2}(InplaceDimShuffle{0}(col), Shape(col)))" + ) + useless_dimshuffle_in_reshape = out2in(local_useless_dimshuffle_in_reshape) + useless_dimshuffle_in_reshape.rewrite(g) + assert str(g) == ( + "FunctionGraph(Reshape{1}(vector, Shape(vector)), " + "Reshape{2}(mat, Shape(mat)), " + "Reshape{2}(row, Shape(row)), " + "Reshape{2}(col, Shape(col)))" + ) + + # Check stacktrace was copied over correctly after rewrite was applied + assert check_stack_trace(g, ops_to_check="all") + + # Check that the rewrite does not get applied when the order + # of dimensions has changed. + reshape_dimshuffle_mat2 = reshape(mat.dimshuffle("x", 1, "x", 0), mat.shape) + h = FunctionGraph([mat], [reshape_dimshuffle_mat2]) + str_h = str(h) + useless_dimshuffle_in_reshape.rewrite(h) + assert str(h) == str_h + + +class TestFusion: + rewrites = RewriteDatabaseQuery( + include=[ + "local_elemwise_fusion", + "composite_elemwise_fusion", + "canonicalize", + "inplace", + ], + exclude=["cxx_only", "BlasOpt"], + ) + mode = Mode(get_default_mode().linker, rewrites) + _shared = staticmethod(shared) + topo_exclude = () + + def my_init(dtype="float64", num=0): + return np.zeros((5, 5), dtype=dtype) + num + + fw, fx, fy, fz = [ + tensor(dtype="float32", shape=[False] * 2, name=n) for n in "wxyz" + ] + dw, dx, dy, dz = [ + tensor(dtype="float64", shape=[False] * 2, name=n) for n in "wxyz" + ] + ix, iy, iz = [tensor(dtype="int32", shape=[False] * 2, name=n) for n in "xyz"] + fv = fvector("v") + fs = fscalar("s") + fwv = my_init("float32", 1) + fxv = my_init("float32", 2) + fyv = my_init("float32", 3) + fzv = my_init("float32", 4) + fvv = _asarray(np.random.random(5), dtype="float32") + fsv = np.asarray(np.random.random(), dtype="float32") + dwv = my_init("float64", 5) + ixv = _asarray(my_init(num=60), dtype="int32") + iyv = _asarray(my_init(num=70), dtype="int32") + izv = _asarray(my_init(num=70), dtype="int32") + fwx = fw + fx + ftanx = tan(fx) + + @pytest.mark.parametrize( + "case", + [ + ( + fx + fy + fz, + (fx, fy, fz), + (fxv, fyv, fzv), + 1, + fxv + fyv + fzv, + "float32", + ), # 0 + ( + fx * fy * fz, + (fx, fy, fz), + (fxv, fyv, fzv), + 1, + fxv * fyv * fzv, + "float32", + ), # 1 + ( + fx + fy * fz, + (fx, fy, fz), + (fxv, fyv, fzv), + 1, + fxv + fyv * fzv, + "float32", + ), # 2 + ( + fx * fy + fz, + (fx, fy, fz), + (fxv, fyv, fzv), + 1, + fxv * fyv + fzv, + "float32", + ), # 3 + ( + fw + fx + fy + fz, + (fw, fx, fy, fz), + (fwv, fxv, fyv, fzv), + 1, + fwv + fxv + fyv + fzv, + "float32", + ), + ( + (fw + fx) + (fy + fz), + (fw, fx, fy, fz), + (fwv, fxv, fyv, fzv), + 1, + fwv + fxv + fyv + fzv, + "float32", + ), # 5 + ( + ((fw + fx) + fy) + fz, + (fw, fx, fy, fz), + (fwv, fxv, fyv, fzv), + 1, + fwv + fxv + fyv + fzv, + "float32", + ), + ( + (fw + (fx + fy)) + fz, + (fw, fx, fy, fz), + (fwv, fxv, fyv, fzv), + 1, + fwv + fxv + fyv + fzv, + "float32", + ), + ( + (fw + (fx + fy) + fz), + (fw, fx, fy, fz), + (fwv, fxv, fyv, fzv), + 1, + fwv + fxv + fyv + fzv, + "float32", + ), + ( + fw + (fx + (fy + fz)), + (fw, fx, fy, fz), + (fwv, fxv, fyv, fzv), + 1, + fwv + fxv + fyv + fzv, + "float32", + ), + ( + (fw + fx) + (fy + fz), + (fw, fx, fy, fz), + (fwv, fxv, fyv, fzv), + 1, + fwv + fxv + fyv + fzv, + "float32", + ), # 10 + ( + fw * fx * fy * fz, + (fw, fx, fy, fz), + (fwv, fxv, fyv, fzv), + 1, + fwv * fxv * fyv * fzv, + "float32", + ), + ( + fw + fx * fy * fz, + (fw, fx, fy, fz), + (fwv, fxv, fyv, fzv), + 1, + fwv + fxv * fyv * fzv, + "float32", + ), + ( + fx + fy * fz * fx, + (fx, fy, fz), + (fxv, fyv, fzv), + 1, + fxv + fyv * fzv * fxv, + "float32", + ), + ( + fx * fy + fz + fy, + (fx, fy, fz), + (fxv, fyv, fzv), + 1, + fxv * fyv + fzv + fyv, + "float32", + ), + ( + fx * fy * fz * fw + fx + fy + fz + fw, + (fw, fx, fy, fz), + (fwv, fxv, fyv, fzv), + 1, + fxv * fyv * fzv * fwv + fxv + fyv + fzv + fwv, + "float32", + ), # 15 + # test with constant + ( + (fw + fx) + (fy + fz) + 2.0, + (fw, fx, fy, fz), + (fwv, fxv, fyv, fzv), + 1, + fwv + fxv + fyv + fzv + 2, + "float32", + ), + ( + ((fw + fx) + 2.0 + fy) + fz, + (fw, fx, fy, fz), + (fwv, fxv, fyv, fzv), + 1, + fwv + fxv + fyv + fzv + 2, + "float32", + ), + ( + (fw + (fx + 2.0 + fy)) + fz, + (fw, fx, fy, fz), + (fwv, fxv, fyv, fzv), + 1, + fwv + fxv + fyv + fzv + 2, + "float32", + ), + ( + (fw + (fx + fy) + 2 + fz), + (fw, fx, fy, fz), + (fwv, fxv, fyv, fzv), + 1, + fwv + fxv + fyv + fzv + 2, + "float32", + ), + ( + fw + (fx + (fy + fz) + 2.0), + (fw, fx, fy, fz), + (fwv, fxv, fyv, fzv), + 1, + fwv + fxv + fyv + fzv + 2, + "float32", + ), # 20 + ( + 2 + (fw + fx) + (fy + fz), + (fw, fx, fy, fz), + (fwv, fxv, fyv, fzv), + 1, + fwv + fxv + fyv + fzv + 2, + "float32", + ), + # mix float32 and float64 + ( + 2 + (dw + fx) + (fy + fz), + (dw, fx, fy, fz), + (dwv, fxv, fyv, fzv), + 1, + dwv + fxv + fyv + fzv + 2, + "float64", + ), + ( + 2 + (fw + dw) + (fy + fz), + (fw, dw, fy, fz), + (fwv, dwv, fyv, fzv), + 1, + fwv + dwv + fyv + fzv + 2, + "float64", + ), + ( + 2 + (fw + fx) + (dw + fz), + (fw, fx, dw, fz), + (fwv, fxv, dwv, fzv), + 1, + fwv + fxv + dwv + fzv + 2, + "float64", + ), + ( + 2 + (fw + fx) + (fy + dw), + (fw, fx, fy, dw), + (fwv, fxv, fyv, dwv), + 1, + fwv + fxv + fyv + dwv + 2, + "float64", + ), # 25 + # test when their is other op then elemwise. + ( + (fwx.sum()) + (fwx) + (fy + fz), + (fw, fx, fy, fz), + (fwv, fxv, fyv, fzv), + 4, + (fwv + fxv).sum() + fwv + fxv + fyv + fzv, + "float32", + ), + # test other elemwise op + ( + fx + fy + cos(fz), + (fx, fy, fz), + (fxv, fyv, fzv), + 1, + fxv + fyv + np.cos(fzv), + "float32", + ), + ( + fx + fy + cosh(fz), + (fx, fy, fz), + (fxv, fyv, fzv), + 1, + fxv + fyv + np.cosh(fzv), + "float32", + ), + ( + fx + fy + abs(fz), + (fx, fy, fz), + (fxv, fyv, fzv), + 1, + fxv + fyv + np.absolute(fzv), + "float32", + ), + ( + ix + iy + abs(iz), + (ix, iy, iz), + (ixv, iyv, izv), + 1, + ixv + iyv + np.absolute(izv), + "int32", + ), # 30 + ( + fx + fy + log(fz), + (fx, fy, fz), + (fxv, fyv, fzv), + 1, + fxv + fyv + np.log(fzv), + "float32", + ), + ( + fx + fy + log2(fz), + (fx, fy, fz), + (fxv, fyv, fzv), + 1, + fxv + fyv + np.log2(fzv), + "float32", + ), + ( + fx + fy + log10(fz), + (fx, fy, fz), + (fxv, fyv, fzv), + 1, + fxv + fyv + np.log10(fzv), + "float32", + ), + ( + fx + fy**fz, + (fx, fy, fz), + (fxv, fyv, fzv), + 1, + fxv + fyv**fzv, + "float32", + ), # pow + ( + fx + fy + exp(fz), + (fx, fy, fz), + (fxv, fyv, fzv), + 1, + fxv + fyv + np.exp(fzv), + "float32", + ), # 35 + ( + fx - fy - fz, + (fx, fy, fz), + (fxv, fyv, fzv), + 1, + fxv - fyv - fzv, + "float32", + ), + ( + fx - (fy / fz), + (fx, fy, fz), + (fxv, fyv, fzv), + 1, + fxv - (fyv / fzv), + "float32", + ), + ( + fx - true_div(fy, 2), + (fx, fy), + (fxv, fyv), + 1, + fxv - (fyv / 2), + "float32", + ), + ( + fx - true_div(fy, fz), + (fx, fy, fz), + (fxv, fyv, fzv), + 1, + fxv - (fyv / fzv), + "float32", + ), + ( + fx - int_div(ix * 100, iy * 1000), + (fx, ix, iy), + (fxv, ixv, iyv), + 1, + fxv - ((ixv * 100) // (iyv * 1000)), + { + "custom": "float64", + "numpy + floatX": config.floatX, + "numpy": "float64", + }, + ), # 40 + (fx - (fy / 2), (fx, fy), (fxv, fyv), 1, fxv - (fyv / 2), "float32"), + ( + fx - (fy % fz), + (fx, fy, fz), + (fxv, fyv, fzv), + 1, + fxv - (fyv % fzv), + "float32", + ), + ( + fx - (fy > fz), + (fx, fy, fz), + (fxv, fyv, fzv), + 1, + fxv - (fyv > fzv), + "float32", + ), + ( + fx - (fy >= fz), + (fx, fy, fz), + (fxv, fyv, fzv), + 1, + fxv - (fyv >= fzv), + "float32", + ), + ( + fx - (fy < fz), + (fx, fy, fz), + (fxv, fyv, fzv), + 1, + fxv - (fyv < fzv), + "float32", + ), # 45 + ( + fx - (fy <= fz), + (fx, fy, fz), + (fxv, fyv, fzv), + 1, + fxv - (fyv <= fzv), + "float32", + ), + ( + fx - eq(fy, fz), + (fx, fy, fz), + (fxv, fyv, fzv), + 1, + fxv - (fyv == fzv), + "float32", + ), + ( + fx - neq(fy, fz), + (fx, fy, fz), + (fxv, fyv, fzv), + 1, + fxv - (fyv != fzv), + "float32", + ), + ( + fx - fy + tan(fz), + (fx, fy, fz), + (fxv, fyv, fzv), + 1, + fxv - fyv + np.tan(fzv), + "float32", + ), + ( + fx - fy + tanh(fz), + (fx, fy, fz), + (fxv, fyv, fzv), + 1, + fxv - fyv + np.tanh(fzv), + "float32", + ), # 50 + ( + fx - fy + sin(fz), + (fx, fy, fz), + (fxv, fyv, fzv), + 1, + fxv - fyv + np.sin(fzv), + "float32", + ), + ( + fx - fy + sinh(fz), + (fx, fy, fz), + (fxv, fyv, fzv), + 1, + fxv - fyv + np.sinh(fzv), + "float32", + ), + ( + fx - fy + sqr(fz), + (fx, fy, fz), + (fxv, fyv, fzv), + 1, + fxv - fyv + (fzv * fzv), + "float32", + ), + ( + fx - fy + sqrt(fz), + (fx, fy, fz), + (fxv, fyv, fzv), + 1, + fxv - fyv + np.sqrt(fzv), + "float32", + ), + ( + fx - fy + reciprocal(fz), + (fx, fy, fz), + (fxv, fyv, fzv), + 1, + fxv - fyv + (1 / fzv), + "float32", + ), # 55 + ( + fx - fy + neg(fz), + (fx, fy, fz), + (fxv, fyv, fzv), + 1, + fxv - fyv + (-fzv), + "float32", + ), + ( + fx - fy + at_round(fz), + (fx, fy, fz), + (fxv, fyv, fzv), + 1, + fxv - fyv + np.round(fzv), + "float32", + ), + ( + ix - iy + iround(fz), + (ix, iy, fz), + (ixv, iyv, fzv), + 1, + ixv - iyv + np.round(fzv), + "int64", + ), + # Bit op + ( + fx - bitwise_or(iy, iz), + (fx, iy, iz), + (fxv, iyv, izv), + 1, + fxv - (iyv | izv), + { + "custom": "float64", + "numpy + floatX": config.floatX, + "numpy": "float64", + }, + ), + ( + fx - xor(iy, iz), + (fx, iy, iz), + (fxv, iyv, izv), + 1, + fxv - (iyv ^ izv), + { + "custom": "float64", + "numpy + floatX": config.floatX, + "numpy": "float64", + }, + ), # 60 + ( + fx - bitwise_and(iy, iz), + (fx, iy, iz), + (fxv, iyv, izv), + 1, + fxv - (iyv & izv), + { + "custom": "float64", + "numpy + floatX": config.floatX, + "numpy": "float64", + }, + ), + ( + fx - invert(iy), + (fx, iy), + (fxv, iyv), + 1, + fxv - (~iyv), + { + "custom": "float64", + "numpy + floatX": config.floatX, + "numpy": "float64", + }, + ), + ( + fx - at.cast(fy, dtype="float64"), + (fx, fy), + (fxv, fyv), + 1, + fxv - np.asarray(fyv, "float64"), + "float64", + ), + ( + at_pow(fx * fy + fz, fx * fy), + (fx, fy, fz), + (fxv, fyv, fzv), + 1, + np.power(fxv * fyv + fzv, fxv * fyv), + "float32", + ), + ( + fv + fy**fz, + (fv, fy, fz), + (fvv, fyv, fzv), + 2, + fvv + fyv**fzv, + "float32", + ), # fused with a dimshuffle #65 + ( + fv - fy + tanh(fz), + (fv, fy, fz), + (fvv, fyv, fzv), + 2, + fvv - fyv + np.tanh(fzv), + "float32", + ), # fused with a dimshuffle + # Cases where the same input is reused many times. + ( + mul(fx, fx, fx, fx), + (fx,), + (fxv,), + 1, + fxv * fxv * fxv * fxv, + "float32", + ), + ( + mul(fx, ftanx, ftanx), + (fx,), + (fxv,), + 1, + fxv * np.tan(fxv) * np.tan(fxv), + "float32", + ), + ( + mul(fx, ftanx, ftanx, fx), + (fx,), + (fxv,), + 1, + fxv * np.tan(fxv) * np.tan(fxv) * fxv, + "float32", + ), + ( + mul(ftanx, ftanx, fx + fy), + (fx, fy), + (fxv, fyv), + 1, + np.tan(fxv) * np.tan(fxv) * (fxv + fyv), + "float32", + ), # 70 + # Cases with different broadcast pattern. They should not + # be merged as this would duplicate computation + # The graph should have 2 elemwise and 1 dimshuffle + ( + fx * sin(fs), + (fx, fs), + (fxv, fsv), + 3, + fxv * np.sin(fsv), + "float32", + ), + ], + ) + def test_elemwise_fusion(self, case, nb_repeat=1, assert_len_topo=True): + """Verify that `Elemwise` fusion works.""" + + g, sym_inputs, val_inputs, nb_elemwise, answer, out_dtype = case + + if isinstance(out_dtype, dict): + out_dtype = out_dtype[config.cast_policy] + + if self._shared is None: + f = function(list(sym_inputs), g, mode=self.mode) + for x in range(nb_repeat): + out = f(*val_inputs) + else: + out = self._shared(np.zeros((5, 5), dtype=out_dtype), "out") + assert out.dtype == g.dtype + f = function(sym_inputs, [], updates=[(out, g)], mode=self.mode) + for x in range(nb_repeat): + f(*val_inputs) + out = out.get_value() + + atol = 1e-8 + if out_dtype == "float32": + atol = 1e-6 + + assert np.allclose(out, answer * nb_repeat, atol=atol) + + topo = f.maker.fgraph.toposort() + topo_ = [n for n in topo if not isinstance(n.op, self.topo_exclude)] + if assert_len_topo: + + assert len(topo_) == nb_elemwise + + if nb_elemwise == 1: + # if no variable appears multiple times in the + # input of g, + # check that the number of input to the Composite + # Elemwise is ok + if len(set(g.owner.inputs)) == len(g.owner.inputs): + expected_len_sym_inputs = sum( + not isinstance(x, Constant) for x in topo_[0].inputs + ) + assert expected_len_sym_inputs == len(sym_inputs) + + assert out_dtype == out.dtype + + def test_fusion_35_inputs(self): + r"""Make sure we don't fuse too many `Op`\s and go past the 31 function arguments limit.""" + inpts = vectors(["i%i" % i for i in range(35)]) + + # Make an elemwise graph looking like: + # sin(i34 + sin(i33 + sin(... i1 + sin(i0) ...))) + out = sin(inpts[0]) + for idx in range(1, 35): + out = sin(inpts[idx] + out) + + with config.change_flags(cxx=""): + f = function(inpts, out, mode=self.mode) + + # Make sure they all weren't fused + composite_nodes = [ + node + for node in f.maker.fgraph.toposort() + if isinstance(getattr(node.op, "scalar_op", None), aes.basic.Composite) + ] + assert not any(len(node.inputs) > 31 for node in composite_nodes) + + @pytest.mark.skipif(not config.cxx, reason="No cxx compiler") + def test_big_fusion(self): + # In the past, pickle of Composite generated in that case + # crashed with max recursion limit. So we were not able to + # generate C code in that case. + factors = [] + sd = dscalar() + means = dvector() + + cst_05 = at.constant(0.5) + cst_m05 = at.constant(-0.5) + cst_2 = at.constant(2) + cst_m2 = at.constant(-2) + ones = at.constant(np.ones(10)) + n = 85 + if config.mode in ["DebugMode", "DEBUG_MODE"]: + n = 10 + + for i in range(n): + f = cst_m05 * sd**cst_m2 * (ones - means[i]) ** cst_2 + cst_05 * log( + cst_05 * (sd**cst_m2) / np.pi + ) + factors.append(at_sum(f)) + + logp = add(*factors) + + vars = [sd, means] + + # Make sure that C compilation is used + mode = Mode("cvm", self.rewrites) + dlogp = function(vars, [aesara.grad(logp, v) for v in vars], mode=mode) + + # Make sure something was fused + assert any( + isinstance(getattr(node.op, "scalar_op", None), aes.basic.Composite) + for node in dlogp.maker.fgraph.toposort() + ) + + def test_add_mul_fusion_inplace(self): + + rewrites = RewriteDatabaseQuery( + include=[ + "local_elemwise_fusion", + "composite_elemwise_fusion", + "canonicalize", + "inplace", + ], + exclude=["cxx_only", "BlasOpt"], + ) + + mode = Mode(self.mode.linker, rewrites) + + x, y, z = dmatrices("xyz") + out = dot(x, y) + x + y + z + f = function([x, y, z], out, mode=mode) + topo = [n for n in f.maker.fgraph.toposort()] + assert len(topo) == 2 + assert topo[-1].op.inplace_pattern + + new_out = f.maker.fgraph.outputs[0] + assert isinstance(new_out.owner.op, Elemwise) + assert isinstance(new_out.owner.op.scalar_op, aes.basic.Add) + assert len(new_out.owner.inputs) == 4 + + # TODO: Do we really need to do this? + _ = f( + np.random.random((5, 5)), np.random.random((5, 5)), np.random.random((5, 5)) + ) + + @pytest.mark.skipif(not config.cxx, reason="No cxx compiler") + def test_no_c_code(self): + r"""Make sure we avoid fusions for `Op`\s without C code implementations.""" + + # This custom `Op` has no `c_code` method + class NoCCodeOp(aes.basic.UnaryScalarOp): + def impl(self, x): + return x * 2 + + no_c_code_op = Elemwise(NoCCodeOp(aes.basic.upgrade_to_float)) + + mode = Mode(linker="cvm") + mode._optimizer = mode._optimizer.including( + "local_elemwise_fusion", + "composite_elemwise_fusion", + "canonicalize", + "inplace", + ) + + x = vector() + out = x * no_c_code_op(x + 1) + f = function([x], out, mode=mode) + + assert not any( + isinstance(getattr(n.op, "scalar_op"), aes.basic.Composite) + for n in f.maker.fgraph.toposort() + ) + + @pytest.mark.parametrize("test_value", [np.c_[[1.0]], np.c_[[]]]) + def test_test_values(self, test_value): + """Make sure that `local_elemwise_fusion_op` uses test values correctly when they have zero dimensions. + + The test values we're talking about are the ones used when C implementations + are checked. + + """ + + rewrites = RewriteDatabaseQuery( + include=[ + "local_elemwise_fusion", + "composite_elemwise_fusion", + "canonicalize", + ], + exclude=["cxx_only", "BlasOpt"], + ) + + mode = Mode(self.mode.linker, rewrites) + + x, y, z = dmatrices("xyz") + + x.tag.test_value = test_value + y.tag.test_value = test_value + z.tag.test_value = test_value + + if test_value.size == 0: + cm = pytest.raises(ValueError) + else: + cm = contextlib.suppress() + + with config.change_flags( + compute_test_value="raise", compute_test_value_opt="raise" + ): + out = x * y + z + with cm: + f = function([x, y, z], out, mode=mode) + + if test_value.size != 0: + # Confirm that the fusion happened + assert isinstance(f.maker.fgraph.outputs[0].owner.op.scalar_op, Composite) + assert len(f.maker.fgraph.toposort()) == 1 + + x_c, y_c, z_c = f.maker.fgraph.outputs[0].owner.inputs + assert np.array_equal( + f.maker.fgraph.outputs[0].tag.test_value, np.c_[[2.0]] + ) + + +class TimesN(aes.basic.UnaryScalarOp): + """ + Used in test TestCompositeCodegen + + Must be outside of the class, otherwise, the c cache code can't + pickle this class and this cause stuff printing during test. + """ + + def __eq__(self, other): + return super().__eq__(other) and self.n == other.n + + def __hash__(self): + return super().__hash__() ^ hash(self.n) + + def __init__(self, n, *args, **kwargs): + self.n = n + aes.basic.UnaryScalarOp.__init__(self, *args, **kwargs) + + def impl(self, x): + return x * self.n + + def c_support_code_apply(self, node, nodename): + n = str(self.n) + return ( + """ + float %(nodename)s_timesn(float x) { return x * %(n)s; } + """ + % locals() + ) + + def c_code(self, node, name, inputs, outputs, sub): + (x,) = inputs + (z,) = outputs + return f"{z} = {name}_timesn({x});" + + +class TestCompositeCodegen: + """ + Test The Composite Ops code generation in a case where there is multiple + scalar ops with support code. + """ + + def setup_method(self): + upgrade_to_float = aes.basic.upgrade_to_float + + self.scal_times_2 = TimesN(2, upgrade_to_float, name="times_2") + self.times_2 = Elemwise(self.scal_times_2, name="times_2") + + self.scal_times_3 = TimesN(3, upgrade_to_float, name="times_3") + self.times_3 = Elemwise(self.scal_times_3, name="times_3") + + self.x = fvector() + + def test_nested_composite(self): + y = self.times_2(self.x) + z = self.times_3(y) + f = function([self.x], z) + if config.mode != "FAST_COMPILE": + assert len(f.maker.fgraph.toposort()) == 1 + fval = f([1, 2, 3]) + assert np.all(fval == [6, 12, 18]) + + def test_local_useless_composite(self): + x = aes.float32() + c = aes.Composite([x], [x + 1, x - 1]) + X = matrix() + o = Elemwise(scalar_op=c)(X) + mode = get_default_mode().including("local_useless_composite") + + f = function([X], o[0], mode=mode) + topo = f.maker.fgraph.toposort() + assert len(topo) == 1 + assert len(topo[0].outputs) == 1 + utt.assert_allclose(f([[1.0]]), [[2.0]]) + + f = function([X], o[1], mode=mode) + topo = f.maker.fgraph.toposort() + assert len(topo) == 1 + assert len(topo[0].outputs) == 1 + utt.assert_allclose(f([[1.0]]), [[0.0]]) + + +def test_local_useless_dimshuffle_makevector(): + a = scalar() + x = MakeVector(config.floatX)(a) + y = x.dimshuffle(()) + + y_fg = FunctionGraph(outputs=[y], copy_inputs=False) + + y_rewritten_fg = rewrite_graph( + y_fg, + clone=False, + include=["canonicalize", "local_useless_dimshuffle_makevector"], + ) + + assert y_rewritten_fg.outputs[0] == a diff --git a/tests/tensor/rewriting/test_extra_ops.py b/tests/tensor/rewriting/test_extra_ops.py new file mode 100644 index 0000000000..fabb14939a --- /dev/null +++ b/tests/tensor/rewriting/test_extra_ops.py @@ -0,0 +1,302 @@ +import numpy as np +import pytest + +import aesara.scalar as aes +from aesara.compile.function import function +from aesara.compile.mode import OPT_NONE, Mode, get_default_mode +from aesara.graph.fg import FunctionGraph +from aesara.graph.rewriting.utils import rewrite_graph +from aesara.tensor.basic import Alloc, alloc, as_tensor_variable, second +from aesara.tensor.elemwise import DimShuffle, Elemwise +from aesara.tensor.extra_ops import BroadcastTo, Repeat, Unique, repeat, unique +from aesara.tensor.type import dscalar + + +@pytest.mark.parametrize("return_index", [False]) +@pytest.mark.parametrize("return_counts", [False]) +@pytest.mark.parametrize("return_inverse", [False]) +def test_local_Unique_scalar(return_index, return_counts, return_inverse): + x = dscalar() + y = unique( + x, + return_index=return_index, + return_counts=return_counts, + return_inverse=return_inverse, + axis=None, + ) + + y_fg = FunctionGraph(outputs=[y], copy_inputs=False) + y_rewritten_fg = rewrite_graph( + y_fg, clone=False, include=["canonicalize", "local_Unique_scalar"] + ) + y_rewritten = y_rewritten_fg.outputs[0] + y_rewritten_start = y_rewritten + + assert isinstance(y_rewritten_start.owner.op, DimShuffle) + assert y_rewritten_start.owner.inputs[0] == x + + default_mode = get_default_mode() + rewrite_mode = default_mode.excluding("local_Unique_scalar") + y_fn = function([x], [y, y_rewritten], mode=rewrite_mode) + + x_val = np.array(-10.0, dtype=np.float64) + y_exp_val, y_val = y_fn(x_val) + assert np.array_equal(y_exp_val, y_val) + + +@pytest.mark.parametrize( + "x_val, axis, new_shape", + [ + (np.array(-10, dtype=np.int64), None, ()), + (np.array(-10, dtype=np.int64), None, (2, 3)), + (np.array([[-10, -3], [-10, 2], [-10, 2]], dtype=np.int64), None, (2, 3, 2)), + ], +) +@pytest.mark.parametrize("return_index", [False]) +@pytest.mark.parametrize("return_counts", [False]) +@pytest.mark.parametrize("return_inverse", [False]) +def test_local_Unique_Alloc_lift( + x_val, axis, new_shape, return_index, return_counts, return_inverse +): + x = as_tensor_variable(x_val).type() + y = unique( + alloc(x, *new_shape), + return_index=return_index, + return_counts=return_counts, + return_inverse=return_inverse, + axis=axis, + ) + + if isinstance(y, list): + y, *_ = y + + # This approach allows us to directly confirm that `x` is in the result. + y_fg = FunctionGraph(outputs=[y], copy_inputs=False) + y_rewritten_fg = rewrite_graph( + y_fg, + clone=False, + include=["canonicalize", "local_Unique_Alloc_lift"], + exclude=["local_Unique_scalar"], + ) + y_rewritten = y_rewritten_fg.outputs[0] + y_rewritten_start = y_rewritten + + assert isinstance(y_rewritten_start.owner.op, Unique) + assert y_rewritten_start.owner.inputs[0] == x + assert not any(isinstance(node.op, Alloc) for node in y_rewritten_fg.apply_nodes) + + default_mode = get_default_mode() + # The rewrite has already been applied to `y_rewritten`, so we can--and + # should--exclude it from the compilation of both our reference, `y`, and + # the rewritten result, `y_rewritten`. + # The remaining exclusions simply allow us to perform the check below that + # makes sure the original `Alloc` is present in our reference (sub)graph. + rewrite_mode = default_mode.excluding( + "local_useless_alloc", "local_alloc_sink_dimshuffle", "local_Unique_Alloc_lift" + ) + y_fn = function([x], [y, y_rewritten], mode=rewrite_mode) + # Make sure that the original `Alloc` is used to compute the reference `y` + # result + assert any(isinstance(node.op, Alloc) for node in y_fn.maker.fgraph.apply_nodes) + + y_exp_val, y_val = y_fn(x_val) + assert np.array_equal(y_exp_val, y_val) + + +@pytest.mark.parametrize( + "x_val, axis, new_shape", + [ + (np.array(-10, dtype=np.int64), None, (2, 3)), + (np.array([[-10, -3], [-10, 2], [-10, 2]], dtype=np.int64), None, (2, 3, 2)), + ], +) +@pytest.mark.parametrize("return_index", [False]) +@pytest.mark.parametrize("return_counts", [False]) +@pytest.mark.parametrize("return_inverse", [False]) +def test_local_Unique_BroadcastTo( + x_val, axis, new_shape, return_index, return_counts, return_inverse +): + x = as_tensor_variable(x_val).type() + y = unique( + BroadcastTo()(x, tuple(new_shape)), + return_index=return_index, + return_counts=return_counts, + return_inverse=return_inverse, + axis=axis, + ) + + if isinstance(y, list): + y, *_ = y + + # This approach allows us to directly confirm that `x` is in the result. + y_fg = FunctionGraph(outputs=[y], copy_inputs=False) + y_rewritten_fg = rewrite_graph( + y_fg, + clone=False, + include=["canonicalize", "local_Unique_BroadcastTo_lift"], + exclude=["local_Unique_scalar"], + ) + y_rewritten = y_rewritten_fg.outputs[0] + y_rewritten_start = y_rewritten + + assert isinstance(y_rewritten_start.owner.op, Unique) + assert y_rewritten_start.owner.inputs[0] == x + assert not any( + isinstance(node.op, BroadcastTo) for node in y_rewritten_fg.apply_nodes + ) + + default_mode = get_default_mode() + # The rewrite has already been applied to `y_rewritten`, so we can--and + # should--exclude it from the compilation of both our reference, `y`, and + # the rewritten result, `y_rewritten`. + rewrite_mode = default_mode.excluding("local_Unique_BroadcastTo_lift") + y_fn = function([x], [y, y_rewritten], mode=rewrite_mode) + # Make sure that the original `BroadcastTo` is used to compute the + # reference `y` result + assert any( + isinstance(node.op, BroadcastTo) for node in y_fn.maker.fgraph.apply_nodes + ) + + y_exp_val, y_val = y_fn(x_val) + assert np.array_equal(y_exp_val, y_val) + + +@pytest.mark.parametrize( + "x_val, unique_axis, repeats, repeat_axis", + [ + (np.array([[-10, -3], [-10, 2]], dtype=np.int64), None, (1, 2), 0), + ], +) +@pytest.mark.parametrize("return_index", [False]) +@pytest.mark.parametrize("return_counts", [False]) +@pytest.mark.parametrize("return_inverse", [False]) +def test_local_Unique_Repeat( + x_val, + unique_axis, + repeats, + repeat_axis, + return_index, + return_counts, + return_inverse, +): + x = as_tensor_variable(x_val).type() + y = unique( + repeat(x, tuple(repeats), axis=repeat_axis), + return_index=return_index, + return_counts=return_counts, + return_inverse=return_inverse, + axis=unique_axis, + ) + + if isinstance(y, list): + y, *_ = y + + # This approach allows us to directly confirm that `x` is in the result. + y_fg = FunctionGraph(outputs=[y], copy_inputs=False) + y_rewritten_fg = rewrite_graph( + y_fg, + clone=False, + include=["canonicalize", "local_Unique_Repeat_lift"], + exclude=["local_Unique_scalar"], + ) + y_rewritten = y_rewritten_fg.outputs[0] + y_rewritten_start = y_rewritten + + assert isinstance(y_rewritten_start.owner.op, Unique) + assert y_rewritten_start.owner.inputs[0] == x + assert not any(isinstance(node.op, Repeat) for node in y_rewritten_fg.apply_nodes) + + default_mode = get_default_mode() + # The rewrite has already been applied to `y_rewritten`, so we can--and + # should--exclude it from the compilation of both our reference, `y`, and + # the rewritten result, `y_rewritten`. + rewrite_mode = default_mode.excluding("local_Unique_Repeat_lift") + y_fn = function([x], [y, y_rewritten], mode=rewrite_mode) + # Make sure that the original `BroadcastTo` is used to compute the + # reference `y` result + assert any(isinstance(node.op, Repeat) for node in y_fn.maker.fgraph.apply_nodes) + + y_exp_val, y_val = y_fn(x_val) + assert np.array_equal(y_exp_val, y_val) + + +@pytest.mark.parametrize( + "x_val, unique_axis, new_shape", + [ + (np.array(-10, dtype=np.int64), None, ()), + (np.array(-10, dtype=np.int64), None, (2, 3)), + (np.array([[-10, -3], [-10, 2], [-10, 2]], dtype=np.int64), None, (2, 3, 2)), + ], +) +@pytest.mark.parametrize("return_index", [False]) +@pytest.mark.parametrize("return_counts", [False]) +@pytest.mark.parametrize("return_inverse", [False]) +def test_local_Unique_second( + x_val, unique_axis, new_shape, return_index, return_counts, return_inverse +): + x = as_tensor_variable(x_val).type() + a = np.zeros(tuple(new_shape), dtype=x.dtype) + y = unique( + second(a, x), + return_index=return_index, + return_counts=return_counts, + return_inverse=return_inverse, + axis=unique_axis, + ) + + if isinstance(y, list): + y, *_ = y + + # This approach allows us to directly confirm that `x` is in the result. + y_fg = FunctionGraph(outputs=[y], copy_inputs=False) + y_rewritten_fg = rewrite_graph( + y_fg, + clone=False, + include=["canonicalize", "local_Unique_second_lift"], + exclude=["local_Unique_scalar", "topo_constant_folding"], + ) + y_rewritten = y_rewritten_fg.outputs[0] + y_rewritten_start = y_rewritten + + assert isinstance(y_rewritten_start.owner.op, Unique) + + y_rewritten_start = y_rewritten_start.owner.inputs[0] + + if y_rewritten_start.owner and isinstance(y_rewritten_start.owner.op, DimShuffle): + y_rewritten_start = y_rewritten_start.owner.inputs[0] + + assert y_rewritten_start == x + assert not any( + isinstance(node.op.scalar_op, aes.Second) + for node in y_rewritten_fg.apply_nodes + if isinstance(node.op, Elemwise) + ) + + # The rewrite has already been applied to `y_rewritten`, so we can--and + # should--exclude it from the compilation of both our reference, `y`, and + # the rewritten result, `y_rewritten`. + y_fn = function([x], [y, y_rewritten], mode=Mode(optimizer=OPT_NONE)) + + # Make sure that the original `BroadcastTo` is used to compute the + # reference `y` result + assert any( + isinstance(node.op.scalar_op, aes.Second) + for node in y_fn.maker.fgraph.apply_nodes + if isinstance(node.op, Elemwise) + ) + + y_exp_val, y_val = y_fn(x_val) + assert np.array_equal(y_exp_val, y_val) + + +def test_local_remove_scalar_BroadcastTo(): + x = dscalar() + y = BroadcastTo()(x, ()) + + assert isinstance(y.owner.op, BroadcastTo) + + res = rewrite_graph( + y, clone=False, include=["canonicalize", "local_remove_scalar_BroadcastTo"] + ) + + assert res is x diff --git a/tests/tensor/test_math_opt.py b/tests/tensor/rewriting/test_math.py similarity index 92% rename from tests/tensor/test_math_opt.py rename to tests/tensor/rewriting/test_math.py index e71238a8f2..a73632db9d 100644 --- a/tests/tensor/test_math_opt.py +++ b/tests/tensor/rewriting/test_math.py @@ -18,19 +18,18 @@ from aesara.configdefaults import config from aesara.graph.basic import Apply, Constant, equal_computations from aesara.graph.fg import FunctionGraph -from aesara.graph.opt import ( - LocalOptGroup, - TopoOptimizer, +from aesara.graph.rewriting.basic import ( + SequentialNodeRewriter, + WalkingGraphRewriter, check_stack_trace, in2out, out2in, ) -from aesara.graph.opt_utils import is_same_graph, optimize_graph -from aesara.graph.optdb import OptimizationQuery +from aesara.graph.rewriting.db import RewriteDatabaseQuery +from aesara.graph.rewriting.utils import is_same_graph, rewrite_graph from aesara.misc.safe_asarray import _asarray from aesara.tensor import inplace from aesara.tensor.basic import Alloc, join, switch -from aesara.tensor.basic_opt import local_dimshuffle_lift from aesara.tensor.blas import Dot22, Gemv from aesara.tensor.blas_c import CGemv from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise @@ -80,7 +79,8 @@ from aesara.tensor.math import sgn, sigmoid, sin, sinh, softplus, sqr, sqrt, sub from aesara.tensor.math import sum as at_sum from aesara.tensor.math import tan, tanh, true_div, xor -from aesara.tensor.math_opt import ( +from aesara.tensor.rewriting.elemwise import local_dimshuffle_lift +from aesara.tensor.rewriting.math import ( compute_mul, is_1pexp, local_grad_log_erfc_neg, @@ -125,36 +125,36 @@ from tests import unittest_tools as utt -mode_opt = config.mode -if mode_opt == "FAST_COMPILE": - mode_opt = "FAST_RUN" -mode_opt = get_mode(mode_opt) +rewrite_mode = config.mode +if rewrite_mode == "FAST_COMPILE": + rewrite_mode = "FAST_RUN" +rewrite_mode = get_mode(rewrite_mode) dimshuffle_lift = out2in(local_dimshuffle_lift) -_optimizer_stabilize = OptimizationQuery(include=["fast_run"]) -_optimizer_stabilize.position_cutoff = 1.51 -_optimizer_stabilize = optdb.query(_optimizer_stabilize) +_stablize_rewrites = RewriteDatabaseQuery(include=["fast_run"]) +_stablize_rewrites.position_cutoff = 1.51 +_stablize_rewrites = optdb.query(_stablize_rewrites) -_optimizer_specialize = OptimizationQuery(include=["fast_run"]) -_optimizer_specialize.position_cutoff = 2.01 -_optimizer_specialize = optdb.query(_optimizer_specialize) +_specialize_rewrites = RewriteDatabaseQuery(include=["fast_run"]) +_specialize_rewrites.position_cutoff = 2.01 +_specialize_rewrites = optdb.query(_specialize_rewrites) -_optimizer_fast_run = OptimizationQuery(include=["fast_run"]) -_optimizer_fast_run = optdb.query(_optimizer_fast_run) +_fast_run_rewrites = RewriteDatabaseQuery(include=["fast_run"]) +_fast_run_rewrites = optdb.query(_fast_run_rewrites) def ds(x, y): return DimShuffle(x.type.broadcastable, y)(x) -def optimize(g, level="fast_run"): +def rewrite(g, level="fast_run"): if level == "fast_run": - _optimizer_fast_run.optimize(g) + _fast_run_rewrites.rewrite(g) elif level == "specialize": - _optimizer_specialize.optimize(g) + _specialize_rewrites.rewrite(g) elif level == "stabilize": - _optimizer_stabilize.optimize(g) + _stablize_rewrites.rewrite(g) else: raise ValueError(level) return g @@ -189,19 +189,19 @@ def test_main(self): # 1. ((a/x + b/y) * x * y) --> a*y + b*x e = (a / z + b / x) * x * z g = FunctionGraph([a, b, c, d, x, y, z], [e]) - mul_canonizer.optimize(g) - TopoOptimizer( - LocalOptGroup(local_greedy_distributor), order="out_to_in" - ).optimize(g) + mul_canonizer.rewrite(g) + WalkingGraphRewriter( + SequentialNodeRewriter(local_greedy_distributor), order="out_to_in" + ).rewrite(g) assert str(pprint(g.outputs[0])) == "((a * x) + (b * z))" # 2. ((a/x + b) * x) --> a + b*x e = (a / x + b) * x g = FunctionGraph([a, b, x], [e]) - mul_canonizer.optimize(g) - TopoOptimizer( - LocalOptGroup(local_greedy_distributor), order="out_to_in" - ).optimize(g) + mul_canonizer.rewrite(g) + WalkingGraphRewriter( + SequentialNodeRewriter(local_greedy_distributor), order="out_to_in" + ).rewrite(g) assert str(pprint(g.outputs[0])) == "(a + (b * x))" def test_kording_bug(self): @@ -251,15 +251,11 @@ class TestAlgebraicCanonizer: ], ) def test_muldiv(self, e, exp_g): - g_opt = optimize_graph(e, custom_opt=mul_canonizer) - assert equal_computations([g_opt], [exp_g]) - - def test_elemwise_multiple_inputs_optimisation(self): - # verify that the AlgebraicCanonizer merge sequential Elemwise({mul,add}) part 1 - # - # This part are that case that is done, but don't include case - # that are not implemented but are supposed to be. - # + g_rewritten = rewrite_graph(e, custom_rewrite=mul_canonizer) + assert equal_computations([g_rewritten], [exp_g]) + + def test_elemwise_multiple_inputs_rewrites(self): + """Verify that the `AlgebraicCanonizer` merges sequential ``Elemwise({mul,add})``.""" # Test with and without DimShuffle shp = (5, 5) fx, fy, fz = fmatrices("xyz") @@ -363,19 +359,18 @@ def test_elemwise_multiple_inputs_optimisation(self): ] # [10:11] # print cases - # We must be sure that the AlgebraicCanonizer is working, but that we don't have other - # optimisation that could hide bug in the AlgebraicCanonizer as local_elemwise_fusion + # We must be sure that the `AlgebraicCanonizer` is working, but that we don't have other + # rewrites that could hide bug in the `AlgebraicCanonizer` as `local_elemwise_fusion` mode = get_default_mode() - opt = OptimizationQuery(["canonicalize"]) - opt = opt.excluding("local_elemwise_fusion") - mode = mode.__class__(linker=mode.linker, optimizer=opt) + rewrites = RewriteDatabaseQuery(["canonicalize"]) + rewrites = rewrites.excluding("local_elemwise_fusion") + mode = mode.__class__(linker=mode.linker, optimizer=rewrites) for id, [g, sym_inputs, val_inputs, nb_elemwise, out_dtype] in enumerate(cases): if isinstance(out_dtype, dict): out_dtype = out_dtype[config.cast_policy] f = function( list(sym_inputs), g, - # we need the optimisation enabled, debug do this. mode=mode, ) @@ -386,11 +381,13 @@ def test_elemwise_multiple_inputs_optimisation(self): @pytest.mark.skip( reason="Current implementation of AlgebraicCanonizer does not implement all cases." ) - def test_elemwise_multiple_inputs_optimisation2(self): - # verify that the AlgebraicCanonizer merge sequential Elemwise({mul,add}) part 2. - # This part are that case that should have been done, but that are not implemented. - # Test with and without DimShuffle + def test_elemwise_multiple_inputs_rewrites_2(self): + """Verify that the `AlgebraicCanonizer` merges sequential ``Elemwise({mul,add})``. + + This part are that case that should have been done, but that are not implemented. + """ + # Test with and without `DimShuffle` shp = (5, 5) fx, fy, fz = fmatrices("xyz") dx, dy, dz = dmatrices("xyz") @@ -498,15 +495,14 @@ def test_elemwise_multiple_inputs_optimisation2(self): # print cases # We must be sure that the AlgebraicCanonizer is working, but that we don't have other - # optimisation that could hide bug in the AlgebraicCanonizer as local_elemwise_fusion + # rewrites that could hide bugs in the `AlgebraicCanonizer` as `local_elemwise_fusion` mode = get_default_mode() - mode._optimizer = OptimizationQuery(["canonicalize"]) + mode._optimizer = RewriteDatabaseQuery(["canonicalize"]) mode._optimizer = mode._optimizer.excluding("local_elemwise_fusion") for id, [g, sym_inputs, val_inputs, nb_elemwise, out_dtype] in enumerate(cases): f = function( list(sym_inputs), g, - # we need the optimisation enabled, debug do this. mode=mode, ) @@ -514,16 +510,20 @@ def test_elemwise_multiple_inputs_optimisation2(self): assert len(f.maker.fgraph.toposort()) == nb_elemwise assert out_dtype == out.dtype - def test_multiple_case(self): - # test those case take from the comment in AlgebraicCanonizer - # x / x -> 1 - # (x * y) / x -> y - # x / y / x -> 1 / y - # x / y / z -> x / (y * z) - # x / (y / z) -> (x * z) / y - # (a / b) * (b / c) * (c / d) -> a / d - # (2.0 * x) / (4.0 * y) -> (0.5 * x) / y - # 2 * x / 2 -> x + def test_mul_div_cases(self): + """ + TODO + + x / x -> 1 + (x * y) / x -> y + x / y / x -> 1 / y + x / y / z -> x / (y * z) + x / (y / z) -> (x * z) / y + (a / b) * (b / c) * (c / d) -> a / d + (2.0 * x) / (4.0 * y) -> (0.5 * x) / y + 2 * x / 2 -> x + + """ # with and without DimShuffle # TODO: with DimShuffle @@ -543,14 +543,14 @@ def test_multiple_case(self): dwv = _asarray(np.random.random(shp), dtype="float64") dvv = _asarray(np.random.random((shp[0])), dtype="float64").reshape(1, shp[0]) - # We must be sure that the AlgebraicCanonizer is working, but that we don't have other - # optimisation that could hide bug in the AlgebraicCanonizer as local_elemwise_fusion + # We must be sure that the `AlgebraicCanonizer` is working, but that we don't have other + # rewrites that could hide bugs in the `AlgebraicCanonizer` as `local_elemwise_fusion` mode = get_default_mode() - opt = OptimizationQuery(["canonicalize"]) - opt = opt.including("ShapeOpt", "local_fill_to_alloc") - opt = opt.excluding("local_elemwise_fusion") - mode = mode.__class__(linker=mode.linker, optimizer=opt) + rewrite_query = RewriteDatabaseQuery(["canonicalize"]) + rewrite_query = rewrite_query.including("ShapeOpt", "local_fill_to_alloc") + rewrite_query = rewrite_query.excluding("local_elemwise_fusion") + mode = mode.__class__(linker=mode.linker, optimizer=rewrite_query) # test x / x -> 1 for id, (g, sym_inputs, val_inputs, out_dtype) in enumerate( [ @@ -855,8 +855,7 @@ def test_multiple_case(self): assert out_dtype == out.dtype def test_abs_mul_div(self): - # test that if we have - # 4 * x / abs(2*x) it get simplifier during canonicalisation. + """Test that ``4 * x / abs(2*x)`` gets "simplified" during canonicalization.""" x = dscalar() # a = at.at_abs(x) @@ -869,8 +868,8 @@ def test_abs_mul_div(self): f = function([x], [(4 * x) / abs(2 * x)], mode=mode) f(0.1) f(-1) - # some stabilization optimization make the output be finite instead of nan - # debug_mode will raise an error when he see nan + # Some stabilization rewrites make the output finite instead of NaN. + # `debug_mode` will raise an error when he see NaN if not isinstance(mode, DebugMode): assert np.isfinite(f(0)) @@ -880,8 +879,6 @@ def test_abs_mul_div(self): f = function([x], [(4 * x) / abs(x / 2)], mode=mode) f(0.1) f(-1) - # some stabilization optimization make the output be finite instead of nan - # debug_mode will raise an error when he see nan if not isinstance(mode, DebugMode): assert np.isfinite(f(0)) @@ -903,13 +900,12 @@ def test_multiple_case_that_fail(self): dyv = _asarray(np.random.random(shp), dtype="float32") dzv = _asarray(np.random.random(shp), dtype="float32") # fvv = _asarray(np.random.random((shp[0]), dtype='float32').reshape(1, shp[0]) - # We must be sure that the AlgebraicCanonizer is working, but that we don't have other - # optimisation that could hide bug in the AlgebraicCanonizer as local_elemwise_fusion + mode = get_default_mode() - opt = OptimizationQuery(["canonicalize"]) - opt = opt.excluding("local_elemwise_fusion") - mode = mode.__class__(linker=mode.linker, optimizer=opt) + rewrites = RewriteDatabaseQuery(["canonicalize"]) + rewrites = rewrites.excluding("local_elemwise_fusion") + mode = mode.__class__(linker=mode.linker, optimizer=rewrites) # test fail! # test x / y / z -> x / (y * z) for (g, sym_inputs, val_inputs, out_dtype) in [ @@ -944,19 +940,19 @@ def test_multiple_case_that_fail(self): def test_canonicalize_nan(self): # Regression test for bug in canonicalization of NaN values. # This bug caused an infinite loop which was caught by the equilibrium - # optimizer, resulting in an error log message. + # rewriter, resulting in an error log message. sio = StringIO() handler = logging.StreamHandler(sio) handler.setLevel(logging.ERROR) - logging.getLogger("aesara.graph.opt").addHandler(handler) + logging.getLogger("aesara.graph.rewriting.basic").addHandler(handler) try: x = vector() function([x], x + np.nan) finally: - logging.getLogger("aesara.graph.opt").removeHandler(handler) + logging.getLogger("aesara.graph.rewriting.basic").removeHandler(handler) # Ideally this test would only catch the maxed out equilibrium - # optimizer error message, but to be safe in case this message + # rewriter error message, but to be safe in case this message # is modified in the future, we assert that there is no error # at all. assert not sio.getvalue() @@ -970,9 +966,11 @@ def test_mismatching_types(self): z.owner.op, z.owner.inputs, [tensor("float64", (None, None))] ).outputs[0] - z_opt = optimize_graph(z, custom_opt=in2out(local_mul_canonizer, name="blah")) + z_rewritten = rewrite_graph( + z, custom_rewrite=in2out(local_mul_canonizer, name="blah") + ) # No rewrite was applied - assert z_opt is z + assert z_rewritten is z def test_local_merge_abs(): @@ -997,8 +995,9 @@ def test_local_merge_abs(): def test_merge_abs_bugfix(): - # Test crash in optimization reported by Jeremiah Lowin at - # https://groups.google.com/d/topic/theano-users/TaXfqXP2Mj0/discussion + """ + See https://groups.google.com/d/topic/theano-users/TaXfqXP2Mj0/discussion + """ input = matrix() # normalize on cols step1 = input / input.sum(0) @@ -1074,7 +1073,7 @@ def test_cast_in_mul_canonizer(): class TestFusion: - opts = OptimizationQuery( + rewrites = RewriteDatabaseQuery( include=[ "local_elemwise_fusion", "composite_elemwise_fusion", @@ -1083,7 +1082,7 @@ class TestFusion: ], exclude=["cxx_only", "BlasOpt"], ) - mode = Mode(get_default_mode().linker, opts) + mode = Mode(get_default_mode().linker, rewrites) _shared = staticmethod(shared) topo_exclude = () @@ -1782,7 +1781,7 @@ def my_init(shp, dtype="float64", num=0): def test_add_mul_fusion_inplace(self): - opts = OptimizationQuery( + rewrites_query = RewriteDatabaseQuery( include=[ "local_elemwise_fusion", "composite_elemwise_fusion", @@ -1792,7 +1791,7 @@ def test_add_mul_fusion_inplace(self): exclude=["cxx_only", "BlasOpt"], ) - mode = Mode(self.mode.linker, opts) + mode = Mode(self.mode.linker, rewrites_query) x, y, z = dmatrices("xyz") out = dot(x, y) + x + y + z @@ -1883,7 +1882,7 @@ def test_local_log_add_exp(): assert np.isfinite(f([10000], [10000])) # causes overflow if handled incorrectly utt.assert_allclose(f([10000], [10000]), 10000 + np.log1p(1)) - # test that when max = +-inf, optimized output still works correctly + # test that when max = +-inf, rewritten output still works correctly assert f([-np.inf], [-np.inf]) == -np.inf assert f([np.inf], [np.inf]) == np.inf assert f([np.inf], [-np.inf]) == np.inf @@ -1896,7 +1895,7 @@ def test_local_log_add_exp(): assert np.isfinite(f([10000], [10000])) # causes overflow if handled incorrectly utt.assert_allclose(f([10000], [10000]), 20000) - # TODO: test that the optimization works in the presence of broadcasting. + # TODO: test that the rewrite works in the presence of broadcasting. def test_local_subtensor_of_dot(): @@ -1942,8 +1941,6 @@ def test_equality(a, b): def test_local_elemwise_sub_zeros(): - # Test opt local_elemwise_sub_zeros - # We test separately for scalars, vectors and matrices scal = scalar() vect = vector() mat = matrix() @@ -1967,38 +1964,32 @@ def test_local_elemwise_sub_zeros(): # Test scalar minus scalar f = function([scal], scal - scal, mode=mode) - # Check optimized graph is correct assert isinstance(f.maker.fgraph.toposort()[0].op, Elemwise) assert isinstance(f.maker.fgraph.toposort()[0].op.scalar_op, aes.Second) assert isinstance( f.maker.fgraph.toposort()[0].inputs[1], TensorConstant ) or isinstance(f.maker.fgraph.toposort()[0].inputs[1], TensorConstant) utt.assert_allclose(f(scalar_val), 0.0) - # Check stack trace is copied over assert check_stack_trace(f, ops_to_check="all") # Test vector minus vector f = function([vect], vect - vect, mode=mode) - # Check optimized graph is correct assert isinstance(f.maker.fgraph.toposort()[0].op, Elemwise) assert isinstance(f.maker.fgraph.toposort()[0].op.scalar_op, aes.Second) assert isinstance( f.maker.fgraph.toposort()[0].inputs[1], TensorConstant ) or isinstance(f.maker.fgraph.toposort()[0].inputs[1], TensorConstant) utt.assert_allclose(f(vect_val), np.zeros(vect_val.shape)) - # Check stack trace is copied over assert check_stack_trace(f, ops_to_check="all") # Test vector minus vector f = function([mat], mat - mat, mode=mode) - # Check optimized graph is correct assert isinstance(f.maker.fgraph.toposort()[0].op, Elemwise) assert isinstance(f.maker.fgraph.toposort()[0].op.scalar_op, aes.Second) assert isinstance( f.maker.fgraph.toposort()[0].inputs[1], TensorConstant ) or isinstance(f.maker.fgraph.toposort()[0].inputs[1], TensorConstant) utt.assert_allclose(f(mat_val), np.zeros(mat_val.shape)) - # Check stack trace is copied over assert check_stack_trace(f, ops_to_check="all") @@ -2161,7 +2152,7 @@ def test_shape_inequality_with_self(self): self.assert_eqs_const(f, 0) assert f(x_val) == 0 f = function([x], minimum([0, 0], x.shape[0]), mode=mode) - # This case isn't optimized. + # This case isn't rewritten. # self.assert_eqs_const(f, 0) utt.assert_allclose(f(x_val), [0, 0]) @@ -2184,7 +2175,7 @@ def test_shape_add_inequality(self): @pytest.mark.skipif( config.mode == "FAST_COMPILE", - reason="Skip opt test as the opt is disabled", + reason="This rewrite is disabled.", ) def test_equality_shapes(self): # Test equality where one sides contain only shapes related @@ -2331,10 +2322,10 @@ def speed_local_pow_specialize_range(): val = np.random.random((1e7)) v = vector() mode = get_default_mode() - mode_without_pow_opt = mode.excluding("local_pow_specialize") + mode_without_pow_rewrite = mode.excluding("local_pow_specialize") for i in range(500, 513): f1 = function([v], v**i, mode=mode) - f2 = function([v], v**i, mode=mode_without_pow_opt) + f2 = function([v], v**i, mode=mode_without_pow_rewrite) assert len(f1.maker.fgraph.toposort()) == 1 t1 = time.time() f1(val) @@ -2346,7 +2337,7 @@ def speed_local_pow_specialize_range(): print("WARNING WE ARE SLOWER") for i in range(-3, -1500, -1): f1 = function([v], v**i, mode=mode) - f2 = function([v], v**i, mode=mode_without_pow_opt) + f2 = function([v], v**i, mode=mode_without_pow_rewrite) assert len(f1.maker.fgraph.toposort()) == 1 t1 = time.time() f1(val) @@ -2455,10 +2446,10 @@ def setup_method(self): mode = get_default_mode() self.mode = mode.including("local_func_inv") - def assert_func_pair_optimized( + def assert_func_pair_rewritten( self, func1, func2, data, should_copy=True, is_complex=False ): - # Check that a pair of funcs is optimized properly + """Check that a pair of functions are rewritten properly.""" x = cmatrix() if is_complex else fmatrix() o = func2(func1(x)) @@ -2484,30 +2475,30 @@ def assert_func_pair_optimized( ), "Inverse functions not removed!" def test(self): - # test optimization for consecutive functional inverses + """Test rewrites for consecutive functional inverses.""" dx = np.random.random((5, 4)).astype("float32") - self.assert_func_pair_optimized(deg2rad, rad2deg, dx) + self.assert_func_pair_rewritten(deg2rad, rad2deg, dx) dx = np.random.random((5, 4)).astype("float32") * 180 - self.assert_func_pair_optimized(rad2deg, deg2rad, dx) + self.assert_func_pair_rewritten(rad2deg, deg2rad, dx) # Test the other functional inverses dx = np.random.random((5, 4)).astype("float32") - self.assert_func_pair_optimized(cosh, arccosh, dx) - self.assert_func_pair_optimized(arcsinh, sinh, dx) - self.assert_func_pair_optimized(arctanh, tanh, dx) - self.assert_func_pair_optimized(reciprocal, reciprocal, dx) - self.assert_func_pair_optimized(neg, neg, dx) + self.assert_func_pair_rewritten(cosh, arccosh, dx) + self.assert_func_pair_rewritten(arcsinh, sinh, dx) + self.assert_func_pair_rewritten(arctanh, tanh, dx) + self.assert_func_pair_rewritten(reciprocal, reciprocal, dx) + self.assert_func_pair_rewritten(neg, neg, dx) cx = dx + complex(0, 1) * (dx + 0.01) - self.assert_func_pair_optimized(conj, conj, cx, is_complex=True) + self.assert_func_pair_rewritten(conj, conj, cx, is_complex=True) # Test that non-inverse functions are ran normally - self.assert_func_pair_optimized( + self.assert_func_pair_rewritten( conj, neg, cx, should_copy=False, is_complex=True ) dx = np.random.random((5, 4)).astype("float32") + 0.01 - self.assert_func_pair_optimized(rad2deg, rad2deg, dx, should_copy=False) - self.assert_func_pair_optimized(rad2deg, cosh, dx, should_copy=False) + self.assert_func_pair_rewritten(rad2deg, rad2deg, dx, should_copy=False) + self.assert_func_pair_rewritten(rad2deg, cosh, dx, should_copy=False) def test_integer_upcast(self): """ @@ -2732,10 +2723,9 @@ def setup_method(self): self.mode.check_isfinite = False def function_remove_nan(self, *args, **kwargs): - """ - Wrapper around function for this test. + """Wrapper around function for this test. - It disables checking for NaN removed by optimizations in DebugMode + It disables checking for NaNs removed by rewrites in `DebugMode` (it has false positives in that case). """ f = function(*args, **kwargs) @@ -2786,7 +2776,7 @@ def test_local_mul_switch_sink(self): ].size idx += 1 - # This case caused a missed optimization in the past. + # This case prevented a rewrite from being applied in the past x = dscalar("x") y = at.switch(x < 7, x, sqrt(x - 7)) f = self.function_remove_nan([x], aesara.gradient.grad(y, x), self.mode) @@ -2923,7 +2913,7 @@ def setup_method(self): self.mode = self.mode_fusion.excluding("fusion") def test_local_one_minus_erfc(self): - # test opt: 1-erfc(x) => erf(x) and -erfc(x)+1 => erf(x) + """Test the rewrites ``1 - erfc(x) -> erf(x)`` and ``-erfc(x) + 1 -> erf(x)``.""" val = np.asarray([-30, -3, -2, -1, 0, 1, 2, 3, 30], dtype=config.floatX) x = vector("x") @@ -2946,7 +2936,7 @@ def test_local_one_minus_erfc(self): assert isinstance(topo[1].op.scalar_op, aes.Sub) def test_local_erf_neg_minus_one(self): - # test opt: (-1)+erfc(-x)=>erf(x) + """Test the rewrite ``-1 + erfc(-x) -> erf(x)``.""" val = np.asarray([-30, -3, -2, -1, 0, 1, 2, 3, 30], dtype=config.floatX) x = vector("x") @@ -3052,9 +3042,9 @@ def test_local_grad_log_erfc_neg(self): for inputs, no_match in no_matches: fg = FunctionGraph(inputs, [no_match], clone=False) - TopoOptimizer( - LocalOptGroup(local_grad_log_erfc_neg), order="out_to_in" - ).optimize(fg) + WalkingGraphRewriter( + SequentialNodeRewriter(local_grad_log_erfc_neg), order="out_to_in" + ).rewrite(fg) # Make sure that the graph hasn't been changed assert fg.outputs[0] is no_match @@ -3155,7 +3145,7 @@ def test_elemwise(self): neq, at_pow, ): - g = optimize(FunctionGraph(mats, [op(s1, s2)])) + g = rewrite(FunctionGraph(mats, [op(s1, s2)])) assert str(g).count("Switch") == 1 # integer Ops mats = imatrices("cabxy") @@ -3167,26 +3157,24 @@ def test_elemwise(self): bitwise_or, bitwise_xor, ): - g = optimize(FunctionGraph(mats, [op(s1, s2)])) + g = rewrite(FunctionGraph(mats, [op(s1, s2)])) assert str(g).count("Switch") == 1 # add/mul with more than two inputs u, v = matrices("uv") s3 = at.switch(c, u, v) for op in (add, mul): - g = optimize(FunctionGraph(mats + [u, v], [op(s1, s2, s3)])) + g = rewrite(FunctionGraph(mats + [u, v], [op(s1, s2, s3)])) assert str(g).count("Switch") == 1 class TestLocalSumProd: - """ - Test sum/prod opts in opt.py - """ + """Test sum/prod rewrites.""" def setup_method(self): self.mode = get_default_mode().including("canonicalize", "specialize") def test_local_sum_prod_mul_by_scalar(self): - # Test the optimization local_sum_prod_mul_by_scalar for both Sum and + # Test the rewrite `local_sum_prod_mul_by_scalar` for both Sum and # Prod ops in six cases each : # 1-the inputs to the mul contain a scalar and no non-scalar # 2-the inputs to the mul contain a scalar and one non-scalar @@ -3205,7 +3193,7 @@ def test_local_sum_prod_mul_by_scalar(self): s1_val = np.random.random() s2_val = np.random.random() - def test_reduction_opt( + def test_reduction_rewrite( inputs, inputs_val, reduction_op, expected_output, nb_expected_sum_nodes ): mul_out = mul(*inputs) @@ -3213,8 +3201,8 @@ def test_reduction_opt( out = f(*inputs_val) utt.assert_allclose(out, expected_output) - # Ensure that the optimization has been applied properly by - # ensuring that the optimized graph contains the expected number + # Ensure that the rewrite has been applied properly by + # ensuring that the rewritten graph contains the expected number # of apply nodes for the sum op prod_nodes = [ n for n in f.maker.fgraph.toposort() if isinstance(n.op, reduction_op) @@ -3224,15 +3212,15 @@ def test_reduction_opt( # Test sum # Case 1 - test_reduction_opt([scalar1], [s1_val], Sum, s1_val, 0) + test_reduction_rewrite([scalar1], [s1_val], Sum, s1_val, 0) # Case 2 - test_reduction_opt( + test_reduction_rewrite( [vect, scalar1], [v_val, s1_val], Sum, s1_val * v_val.sum(), 1 ) # Case 3 - test_reduction_opt( + test_reduction_rewrite( [vect, mat, scalar1], [v_val, m_val, s1_val], Sum, @@ -3241,12 +3229,12 @@ def test_reduction_opt( ) # Case 4 - test_reduction_opt( + test_reduction_rewrite( [scalar1, scalar2], [s1_val, s2_val], Sum, s1_val * s2_val, 0 ) # Case 5 - test_reduction_opt( + test_reduction_rewrite( [vect, scalar1, scalar2], [v_val, s1_val, s2_val], Sum, @@ -3255,7 +3243,7 @@ def test_reduction_opt( ) # Case 6 - test_reduction_opt( + test_reduction_rewrite( [vect, mat, scalar1, scalar2], [v_val, m_val, s1_val, s2_val], Sum, @@ -3266,10 +3254,10 @@ def test_reduction_opt( # Test prod # Case 1 - test_reduction_opt([scalar1], [s1_val], Prod, s1_val, 0) + test_reduction_rewrite([scalar1], [s1_val], Prod, s1_val, 0) # Case 2 - test_reduction_opt( + test_reduction_rewrite( [vect, scalar1], [v_val, s1_val], Prod, @@ -3278,7 +3266,7 @@ def test_reduction_opt( ) # Case 3 - test_reduction_opt( + test_reduction_rewrite( [vect, mat, scalar1], [v_val, m_val, s1_val], Prod, @@ -3287,12 +3275,12 @@ def test_reduction_opt( ) # Case 4 - test_reduction_opt( + test_reduction_rewrite( [scalar1, scalar2], [s1_val, s2_val], Prod, s1_val * s2_val, 0 ) # Case 5 - test_reduction_opt( + test_reduction_rewrite( [vect, scalar1, scalar2], [v_val, s1_val, s2_val], Prod, @@ -3301,7 +3289,7 @@ def test_reduction_opt( ) # Case 6 - test_reduction_opt( + test_reduction_rewrite( [vect, mat, scalar1, scalar2], [v_val, m_val, s1_val, s2_val], Prod, @@ -3418,7 +3406,7 @@ def my_sum_prod(data, d, dd): utt.assert_allclose(f(input), input.prod()) assert len(f.maker.fgraph.apply_nodes) == 1 - # test sum prod don't get opt. + # Test that sum prod didn't get rewritten. for d, dd in dims: expected = my_sum_prod(input, d, dd) f = function([a], a.sum(d).prod(dd), mode=self.mode) @@ -3437,7 +3425,6 @@ def my_sum_prod(data, d, dd): assert len(f.maker.fgraph.apply_nodes) == 1 def test_local_sum_prod_alloc(self): - # test local_opt_alloc a = dtensor3() input = np.asarray(np.arange(2 * 3 * 4).reshape(2, 3, 4), dtype="float64") mode = self.mode.including("specialize").excluding("fusion") @@ -3503,8 +3490,10 @@ def test_local_sum_prod_alloc(self): assert not any(isinstance(node.op, Sum) for node in topo) def test_local_sum_sum_int8(self): - # Test that local_sum_sum works when combining two sums on an int8 array. - # This is a regression test for ticket gh-356. + """Test that `local_sum_sum` works when combining two sums on an int8 array. + + This is a regression test for ticket gh-356. + """ x = tensor3(dtype="int8") y = x.sum(axis=0).sum(axis=1) @@ -3514,7 +3503,7 @@ def test_local_sum_sum_int8(self): function([x], y) def test_local_sum_sum_dtype(self): - # Test that local_sum_sum works when specifying dtypes manually. + """Test that `local_sum_sum` works when specifying dtypes manually.""" x = tensor3(dtype="int8") y = x.sum(axis=0, dtype="int32").sum(axis=1, dtype="int64") @@ -3524,7 +3513,7 @@ def test_local_sum_sum_dtype(self): function([x], y) def test_local_sum_prod_mul_by_scalar_stack_trace(self): - # Test that stack trace is copied over correctly for local_sum_prod_mul_by_scalar. + """Test that stack trace is copied over correctly for `local_sum_prod_mul_by_scalar`.""" m0 = ( get_default_mode() .excluding("inplace_elemwise_opt") @@ -3609,9 +3598,9 @@ def test_local_reduce_broadcast_some_0(self): op = node.op assert isinstance(op, CAReduce) - # -- the leading broadcastable dimension has been dropped - # by the local_reduce_broadcastable optimization - # now summation is over the original x's dimension 1. + # The leading broadcastable dimension has been dropped by the + # `local_reduce_broadcastable` rewrite. Now, summation is over + # the original `x`'s dimension 1. assert node.inputs[0].ndim == 2, node assert op.axis == (0,), op.axis @@ -3668,7 +3657,7 @@ def test_local_reduce_join(self): topo = f.maker.fgraph.toposort() assert not isinstance(topo[-1].op, Elemwise) - # This case could be optimized + # This case could be rewritten A = shared(np.array([1, 2, 3, 4, 5]).reshape(5, 1)) f = function([], at_sum(at.concatenate((A, A), axis=1), axis=1), mode=self.mode) utt.assert_allclose(f(), [2, 4, 6, 8, 10]) @@ -3681,7 +3670,7 @@ def test_local_reduce_join(self): topo = f.maker.fgraph.toposort() assert not isinstance(topo[-1].op, Elemwise) - # Test that the optimization does not crash in one case where it + # Test that the rewrite does not crash in one case where it # is not applied. Reported at # https://groups.google.com/d/topic/theano-users/EDgyCU00fFA/discussion out = at_sum([vx, vy, vz], axis=None) @@ -3795,27 +3784,29 @@ def test_local_prod_div_dimshuffle(self): d_val = np.asarray(rng.standard_normal(), config.floatX) default_mode = get_default_mode() - # FusionOptimizer is included to make sure that expected_outer_operator - # remains the same for all optimization modes. - mode_with_opt = default_mode.including( + # `FusionOptimizer` is included to make sure that `expected_outer_operator` + # remains the same for all rewrite modes. + mode_with_rewrite = default_mode.including( "local_sum_prod_div_dimshuffle", "FusionOptimizer" ) - mode_without_opt = default_mode.excluding("local_sum_prod_div_dimshuffle") + mode_without_rewrite = default_mode.excluding("local_sum_prod_div_dimshuffle") # Numerical tests: tests whether the numerical values with and without - # optimizer are equal or not. + # rewrites are equal or not. for i, s in enumerate(prods): f = function( - [a, b, c, d], s, on_unused_input="ignore", mode=mode_without_opt + [a, b, c, d], s, on_unused_input="ignore", mode=mode_without_rewrite + ) + g = function( + [a, b, c, d], s, on_unused_input="ignore", mode=mode_with_rewrite ) - g = function([a, b, c, d], s, on_unused_input="ignore", mode=mode_with_opt) utt.assert_allclose( f(a_val, b_val, c_val, d_val), g(a_val, b_val, c_val, d_val) ) - # Logical tests: tests whether the optimizer has been appplied or not - # by checking graph structure. + # Logical tests: tests whether the rewrite has been appplied or not + # by checking graph structure. prods = [ prod(a / e), prod(a / d), @@ -3840,7 +3831,7 @@ def test_local_prod_div_dimshuffle(self): for i, s in enumerate(prods): g = function( - [a, b, c, d, e], s, on_unused_input="ignore", mode=mode_with_opt + [a, b, c, d, e], s, on_unused_input="ignore", mode=mode_with_rewrite ) assert isinstance( g.maker.fgraph.toposort()[-1].op.scalar_op, expected_outer_operator[i] @@ -3857,32 +3848,36 @@ def test_local_useless_adds(): # Test for all zeros a = scalar() s = add(at.zeros_like(a)) - mode_with_opt = default_mode.including("canonicalization", "local_useless_fill") - f = function([a], s, mode=mode_with_opt) + mode_with_rewrite = default_mode.including("canonicalization", "local_useless_fill") + f = function([a], s, mode=mode_with_rewrite) assert not any(node.op == add for node in f.maker.fgraph.apply_nodes) # test of non-zero dimension a = vector() s = add(at.zeros_like(a)) - mode_with_opt = default_mode.including("canonicalization", "local_useless_elemwise") - f = function([a], s, mode=mode_with_opt) + mode_with_rewrite = default_mode.including( + "canonicalization", "local_useless_elemwise" + ) + f = function([a], s, mode=mode_with_rewrite) assert not any(node.op == add for node in f.maker.fgraph.apply_nodes) # test of 0-d a = scalar() s = add(at.zeros_like(a)) - mode_with_opt = default_mode.including( + mode_with_rewrite = default_mode.including( "canonicalization", "local_useless_fill", "local_useless_elemwise" ) - f = function([a], s, mode=mode_with_opt) + f = function([a], s, mode=mode_with_rewrite) assert not any(node.op == add for node in f.maker.fgraph.apply_nodes) # Test when the 0 input is forcing upcasting a = at.constant(0, dtype="int64") b = at.constant(1, dtype="int32") s = a + b - mode_with_opt = default_mode.including("canonicalization", "local_add_canonizer") - f = function([], s, mode=mode_with_opt) + mode_with_rewrite = default_mode.including( + "canonicalization", "local_add_canonizer" + ) + f = function([], s, mode=mode_with_rewrite) transformed = f.maker.fgraph.outputs[0] assert not any(node.op == add for node in f.maker.fgraph.apply_nodes) assert transformed.type == s.type @@ -3910,9 +3905,8 @@ def setup_method(self): self.mode = get_default_mode() self.mode = self.mode.including("local_intdiv_by_one") - def test1(self): - # Tests removing the extra floor_div by 1 introduced by - # local_subtensor_merge optimization + def test_remove_floor(self): + """Tests removing the extra floor_div by 1 introduced by `local_subtensor_merge` rewrite.""" y = tensor4("y") self.mode = self.mode.excluding("fusion") @@ -3962,8 +3956,8 @@ def test_local_zero_div(t, op): """Test the canonicalization ``0/x -> 0``.""" x = t("x") y = op(0, x) - g = optimize(FunctionGraph([x], [y])) - # the division should be gone + g = rewrite(FunctionGraph([x], [y])) + # The division should be gone divs = [ node for node in g.toposort() @@ -3971,11 +3965,11 @@ def test_local_zero_div(t, op): and isinstance(node.op.scalar_op, type(op.scalar_op)) ] assert len(divs) == 0 - # the output type should match the unoptimized one + # The output type should match the un-rewritten one output = g.outputs[0] assert output.ndim == y.ndim assert output.type == y.type - # and the output should be zero + # The output should be zero if output.owner and isinstance(output.owner.op, Alloc): out_var = output.owner.inputs[0] else: @@ -4074,16 +4068,18 @@ def check_max_log_sum_exp(x, axis, dimshuffle_op=None): ): return - # in mode FAST_COMPILE, the optimisations don't replace the - # MaxAndArgmax op. + # In mode FAST_COMPILE, the rewrites don't replace the + # `MaxAndArgmax` `Op`. if isinstance(node.op, MaxAndArgmax): return - raise Exception("No maximum detected after log_sum_exp optimisation") + # TODO FIXME: Refactor this test so that it makes a direct assertion and + # nothing more. + raise AssertionError("No maximum detected after log_sum_exp rewrite") -def test_local_log_sum_exp1(): - # Tests if optimization is applied by checking the presence of the maximum +def test_local_log_sum_exp_maximum(): + """Test that the rewrite is applied by checking the presence of the maximum.""" x = tensor3("x") check_max_log_sum_exp(x, axis=(0,), dimshuffle_op=None) check_max_log_sum_exp(x, axis=(1,), dimshuffle_op=None) @@ -4101,39 +4097,38 @@ def test_local_log_sum_exp1(): check_max_log_sum_exp(x, axis=(0, 1), dimshuffle_op=sum_keepdims_op) -def test_local_log_sum_exp2(): - # Tests if the optimization works (result is correct) around 1.0 +def test_local_log_sum_exp_near_one(): + """Test that the rewritten result is correct around 1.0.""" x = tensor3("x") x_val = 1.0 + np.random.random((4, 3, 2)).astype(config.floatX) / 10.0 f = compile_graph_log_sum_exp(x, axis=(1,)) naive_ret = np.log(np.sum(np.exp(x_val), axis=1)) - optimised_ret = f(x_val) - assert np.allclose(naive_ret, optimised_ret) + rewritten_ret = f(x_val) + assert np.allclose(naive_ret, rewritten_ret) # If a transpose is applied transpose_op = DimShuffle((False, False), (1, 0)) f = compile_graph_log_sum_exp(x, axis=(1,), dimshuffle_op=transpose_op) naive_ret = np.log(np.sum(np.exp(x_val), axis=1).T) - optimised_ret = f(x_val) - assert np.allclose(naive_ret, optimised_ret) + rewritten_ret = f(x_val) + assert np.allclose(naive_ret, rewritten_ret) -def test_local_log_sum_exp3(): - # Tests if the optimization works (result is correct) for extreme value 100 +def test_local_log_sum_exp_large(): + """Test that the rewrite result is correct for extreme value 100.""" x = vector("x") f = compile_graph_log_sum_exp(x, axis=0) x_val = np.array([-100.0, 100.0]).astype(config.floatX) - optimised_ret = f(x_val) - - assert np.allclose(optimised_ret, 100.0) + rewritten_ret = f(x_val) + assert np.allclose(rewritten_ret, 100.0) def test_local_log_sum_exp_inf(): - # Test that when max = +-inf, optimized output still works correctly + """Test that when max = +-inf, the rewritten output still works correctly.""" x = vector("x") f = compile_graph_log_sum_exp(x, axis=0) @@ -4145,20 +4140,25 @@ def test_local_log_sum_exp_inf(): def test_local_reciprocal_1_plus_exp(): x = vector("x") y = at.reciprocal(1 + exp(x)) - z = optimize_graph(y, include=["canonicalization", "stabilize", "specialize"]) + z = rewrite_graph(y, include=["canonicalization", "stabilize", "specialize"]) assert z.owner.op == sigmoid -class TestSigmoidOpts: +class TestSigmoidRewrites: def get_mode(self, excluding=None): """ Return appropriate mode for the tests. - :param excluding: List of optimizations to exclude. + Parameters + ---------- + excluding + List of rewrites to exclude. - :return: The current default mode unless the `config.mode` option is + Returns + ------- + The current default mode unless the `config.mode` option is set to 'FAST_COMPILE' (in which case it is replaced by the 'FAST_RUN' - mode), without the optimizations specified in `excluding`. + mode), without the rewrites specified in `excluding`. """ if excluding is None: excluding = [] @@ -4310,28 +4310,29 @@ def test_local_1msigmoid(self): m = self.get_mode(excluding=["fusion", "inplace"]) x = fmatrix() - # tests exp_over_1_plus_exp + # Test `exp_over_1_plus_exp` f = aesara.function([x], 1 - exp(x) / (1 + exp(x)), mode=m) - # FIXME: PatternSub does not copy stack trace + # FIXME: PatternNodeRewriter does not copy stack trace # (see https://github.com/Theano/Theano/issues/4581) # assert check_stack_trace(f, ops_to_check=[neg, sigmoid]) assert [node.op for node in f.maker.fgraph.toposort()] == [neg, sigmoid] - # tests inv_1_plus_exp + # Test `inv_1_plus_exp` f = aesara.function([x], 1 - at.fill(x, 1.0) / (1 + exp(-x)), mode=m) # assert check_stack_trace(f, ops_to_check=[neg, sigmoid]) assert [node.op for node in f.maker.fgraph.toposort()] == [neg, sigmoid] - # Tests float constant + # Test float constant f = aesara.function( [x], np.array(1.000001, dtype="float32") - sigmoid(x), mode=m ) assert [node.op for node in f.maker.fgraph.toposort()] == [neg, sigmoid] def test_local_sigm_times_exp(self): - # Test the `local_sigm_times_exp` optimization. - # exp(x) * sigm(-x) -> sigm(x) - # exp(-x) * sigm(x) -> sigm(-x) + """ + exp(x) * sigm(-x) -> sigm(x) + exp(-x) * sigm(x) -> sigm(-x) + """ def match(func, ops): # print [node.op.scalar_op for node in func.maker.fgraph.toposort()] @@ -4364,15 +4365,16 @@ def match(func, ops): # exp]) def test_perform_sigm_times_exp(self): - # Test the core function doing the `sigm_times_exp` optimization. - # - # It is easier to test different graph scenarios this way than by - # compiling an Aesara function. + """Test the core function doing the `sigm_times_exp` rewrite. + + It is easier to test different graph scenarios this way than by + compiling an Aesara function. + """ x, y, z, t = vectors("x", "y", "z", "t") exp_op = exp - def ok(expr1, expr2): + def check(expr1, expr2): trees = [parse_mul_tree(e) for e in (expr1, expr2)] perform_sigm_times_exp(trees[0]) trees[0] = simplify_mul(trees[0]) @@ -4386,12 +4388,12 @@ def ok(expr1, expr2): aesara.printing.debugprint(compute_mul(trees[1])) assert good - ok(sigmoid(x) * exp_op(-x), sigmoid(-x)) - ok( + check(sigmoid(x) * exp_op(-x), sigmoid(-x)) + check( -x * sigmoid(x) * (y * (-1 * z) * exp_op(-x)), -x * sigmoid(-x) * (y * (-1 * z)), ) - ok( + check( -sigmoid(-x) * ( exp_op(y) @@ -4403,11 +4405,11 @@ def ok(expr1, expr2): * (-sigmoid(y) * (-sigmoid(-z) * 3) * (y * 2 * ((z + t) * exp_op(z)))) * (-sigmoid(x)), ) - ok( + check( exp_op(-x) * -exp_op(-x) * (-sigmoid(x) * -sigmoid(x)), -sigmoid(-x) * sigmoid(-x), ) - ok(-exp_op(x) * -sigmoid(-x) * -exp_op(-x), -sigmoid(-x)) + check(-exp_op(x) * -sigmoid(-x) * -exp_op(-x), -sigmoid(-x)) def test_grad_log1msigm(self): # At some point, this returned nan, because (1 - sigm(x)) was @@ -4421,7 +4423,7 @@ def test_grad_log1msigm(self): c = l.mean() ux = x - lr * aesara.grad(c, x) - # Before the optimization, inf and NaN will be produced in the graph, + # Before the rewriting, inf and NaN will be produced in the graph, # and DebugMode will complain. Everything is fine afterwards. mode = self.get_mode() if not isinstance(mode, aesara.compile.debugmode.DebugMode): @@ -4430,7 +4432,7 @@ def test_grad_log1msigm(self): assert not np.isnan(ux_v) -class TestSoftplusOpts: +class TestSoftplusRewrites: def setup_method(self): if aesara.config.mode == "FAST_COMPILE": m = aesara.compile.mode.get_mode("FAST_RUN").excluding( @@ -4538,10 +4540,7 @@ def test_log1p_neg_sigmoid_to_softpuls(self): class TestSigmoidUtils: - """ - Test utility functions found in 'math_opt.py' used in the optimization of - sigmoid / softplus expressions. - """ + """Test utility functions used in the rewrites for `sigmoid`/`softplus` expressions.""" def test_compute_mul(self): x, y, z = vectors("x", "y", "z") @@ -4588,12 +4587,12 @@ def test_log1mexp_stabilization(): nodes = [node.op for node in f.maker.fgraph.toposort()] assert nodes == [at.log1mexp] - # Check values that would under or overflow without optimization + # Check values that would under or overflow without rewriting assert f([-(2.0**-55)]) != -np.inf overflow_value = -500.0 if config.floatX == "float64" else -100.0 assert f([overflow_value]) < 0 - # Check values around the optimization switch point np.log(0.5) + # Check values around the switch point np.log(0.5) assert np.allclose( f(np.array([-0.8, -0.6], dtype=config.floatX)), np.log(1 - np.exp([-0.8, -0.6])), @@ -4601,10 +4600,7 @@ def test_log1mexp_stabilization(): def test_local_logit_sigmoid(): - """ - Test that graphs of the form logit(sigmoid(x)) and sigmoid(logit(x)) get - optimized to x (sigmoid is the inverse of the logit) - """ + """Test that graphs of the form ``logit(sigmoid(x))`` and ``sigmoid(logit(x))`` get rewritten to ``x``.""" def logit_fn(x): return log(x / (1 - x)) @@ -4612,12 +4608,12 @@ def logit_fn(x): x = fmatrix() out = sigmoid(logit_fn(x)) - fg = optimize(FunctionGraph([x], [out])) + fg = rewrite(FunctionGraph([x], [out])) assert not list(fg.toposort()) assert fg.inputs[0] is fg.outputs[0] out = logit_fn(sigmoid(x)) - fg = optimize(FunctionGraph([x], [out])) + fg = rewrite(FunctionGraph([x], [out])) assert not list(fg.toposort()) assert fg.inputs[0] is fg.outputs[0] @@ -4628,12 +4624,18 @@ def test_local_useless_conj(): # Test for all zeros x = scalar() s = _conj(x) - mode_with_opt = default_mode.including("canonicalization", "local_useless_conj") - f = function([x], s, mode=mode_with_opt) + mode_with_rewrite = default_mode.including("canonicalization", "local_useless_conj") + f = function([x], s, mode=mode_with_rewrite) assert not any(node.op == _conj for node in f.maker.fgraph.apply_nodes) x = zscalar() s = _conj(x) - mode_with_opt = default_mode.including("canonicalization", "local_useless_conj") - f = function([x], s, mode=mode_with_opt) + mode_with_rewrite = default_mode.including("canonicalization", "local_useless_conj") + f = function([x], s, mode=mode_with_rewrite) assert any(node.op == _conj for node in f.maker.fgraph.apply_nodes) + + +def test_deprecations(): + """Make sure we can import from deprecated modules.""" + with pytest.deprecated_call(): + from aesara.tensor.math_opt import AlgebraicCanonizer # noqa: F401 F811 diff --git a/tests/tensor/rewriting/test_shape.py b/tests/tensor/rewriting/test_shape.py new file mode 100644 index 0000000000..09dc0585d0 --- /dev/null +++ b/tests/tensor/rewriting/test_shape.py @@ -0,0 +1,553 @@ +import copy + +import numpy as np +import pytest + +import aesara.tensor as at +from aesara import shared +from aesara.compile.function import function +from aesara.compile.mode import get_default_mode, get_mode +from aesara.compile.ops import deep_copy_op +from aesara.configdefaults import config +from aesara.graph.basic import Apply, Variable +from aesara.graph.fg import FunctionGraph +from aesara.graph.op import Op +from aesara.graph.rewriting.basic import check_stack_trace, node_rewriter, out2in +from aesara.graph.rewriting.utils import rewrite_graph +from aesara.graph.type import Type +from aesara.tensor.basic import as_tensor_variable +from aesara.tensor.elemwise import DimShuffle, Elemwise +from aesara.tensor.math import add, exp, maximum +from aesara.tensor.rewriting.basic import register_specialize +from aesara.tensor.rewriting.shape import ( + ShapeFeature, + local_reshape_to_dimshuffle, + local_useless_reshape, +) +from aesara.tensor.shape import ( + Reshape, + Shape_i, + SpecifyShape, + reshape, + shape, + specify_shape, +) +from aesara.tensor.subtensor import set_subtensor +from aesara.tensor.type import ( + fmatrix, + iscalar, + lscalar, + matrix, + scalar, + tensor, + tensor3, + tensor4, + vector, +) +from tests import unittest_tools as utt + + +rewrite_mode = config.mode + +if rewrite_mode == "FAST_COMPILE": + rewrite_mode = "FAST_RUN" + +rewrite_mode = get_mode(rewrite_mode) + + +class TestShapeRewriter: + def test_basic(self): + mode = config.mode + if mode == "FAST_COMPILE": + mode = "FAST_RUN" + v = vector() + m = matrix() + f = function([v, m], (v + m).shape, mode=mode) + for node in f.maker.fgraph.toposort(): + assert node.op != add + + def test_constant(self): + mode = config.mode + if mode == "FAST_COMPILE": + mode = "FAST_RUN" + + v = vector() + f = function([v], v.dimshuffle("x", "x", 0).shape[1], mode=mode) + topo = f.maker.fgraph.toposort() + assert len(topo) == 1 + assert topo[0].op == deep_copy_op + + @staticmethod + def max_pool_c01b(c01b, pool_shp, pool_stride, img_shp): + """ + Like max_pool but with input using axes ('c', 0, 1, 'b') + (Alex Krizhevsky format) + + pool_shp, pool_stride and img_shp are int that represent + the same shp in x and y. + """ + mx = None + + # Compute index in pooled space of last needed pool + # (needed = each input pixel must appear in at least one pool) + def last_pool(im_shp, p_shp, p_strd): + rval = int(np.ceil(float(im_shp - p_shp) / p_strd)) + assert p_strd * rval + p_shp >= im_shp + assert p_strd * (rval - 1) + p_shp < im_shp + return rval + + # Compute starting row of the last pool + last_pool_r = last_pool(img_shp, pool_shp, pool_stride) * pool_stride + # Compute number of rows needed in img for all indexes to work out + required_r = last_pool_r + pool_shp + + last_pool_c = last_pool(img_shp, pool_shp, pool_stride) * pool_stride + required_c = last_pool_c + pool_shp + + wide_infinity = at.alloc( + -np.inf, c01b.shape[0], required_r, required_c, c01b.shape[3] + ) + + c01b = set_subtensor(wide_infinity[:, 0:img_shp, 0:img_shp, :], c01b) + + for row_within_pool in range(pool_shp): + row_stop = last_pool_r + row_within_pool + 1 + for col_within_pool in range(pool_shp): + col_stop = last_pool_c + col_within_pool + 1 + cur = c01b[ + :, + row_within_pool:row_stop:pool_stride, + col_within_pool:col_stop:pool_stride, + :, + ] + if mx is None: + mx = cur + else: + mx = maximum(mx, cur) + return mx + + def test_broadcasted_dims(self): + # This test a case that caused a crash during rewriting + shp = (1, 1, 1, 1) + rng = np.random.default_rng(utt.fetch_seed()) + a = shared(rng.random(shp).astype(config.floatX)) + out = self.max_pool_c01b(a, 1, 1, 1) + + # max_pool_c01b use -inf and this will trigger DebugMode error. + mode = copy.copy(get_default_mode()) + mode.check_isfinite = False + f = function([], out, mode=mode) + f() + + def test_constant_merge(self): + # This test the error in gh-1122 that is a caused by the + # combination of merge rewriter and ShapeFeature. + + x = at.constant([0, 0]) + y = x[1:] + x1 = x - at.join(0, y, y) + x1.eval() + + def test_local_track_shape_i(self): + class IdentityNoShape(Op): + """Op that does not infer the output shape from the input one""" + + def make_node(self, x): + x = as_tensor_variable(x) + return Apply(self, [x], [x.type()]) + + def perform(self, node, inp, out_): + (x,) = inp + (out,) = out_ + out[0] = x.copy() + + # def infer_shape(self, fgraph, node, (xshp,)): + # return [tuple([self.shape_i(i)(r) for i in range(r.ndim)])] + + identity_noshape = IdentityNoShape() + + class IdentityShape(Op): + """Op that does infer the output shape from the input one""" + + def make_node(self, x): + x = as_tensor_variable(x) + return Apply(self, [x], [x.type()]) + + def perform(self, node, inp, out_): + (x,) = inp + (out,) = out_ + out[0] = x.copy() + + def infer_shape(self, fgraph, node, xshp_): + # Could also just return. + (xshp,) = xshp_ + return (xshp,) + + identity_shape = IdentityShape() + + @node_rewriter([IdentityNoShape]) + def local_identity_noshape_to_identity_shape(fgraph, node): + """Transform the first `Op` into the second.""" + if isinstance(node.op, IdentityNoShape): + return [identity_shape(node.inputs[0])] + + mode = get_default_mode().including("ShapeOpt", "specialize") + rng = np.random.default_rng(utt.fetch_seed()) + x = tensor3("x") + ins_x = identity_noshape(x) + + # Without the rewrite + f = function([x], ins_x.shape, mode=mode) + xval = rng.standard_normal((3, 4, 7)).astype(config.floatX) + assert np.all(f(xval) == [3, 4, 7]) + f_ops = [node.op for node in f.maker.fgraph.toposort()] + assert len(f_ops) == 5 + assert identity_noshape in f_ops + assert identity_shape not in f_ops + + # Register the rewrite + register_specialize(local_identity_noshape_to_identity_shape) + + mode = get_default_mode().including("ShapeOpt", "specialize") + # The `identity_shape` hOph should not be needed anymore to compute + # the shape + g = function([x], ins_x.shape, mode=mode) + xval = rng.standard_normal((6, 1, 2)).astype(config.floatX) + assert np.all(g(xval) == [6, 1, 2]) + g_ops = [node.op for node in g.maker.fgraph.toposort()] + assert len(g_ops) == 4 + assert identity_noshape not in g_ops + assert identity_shape not in g_ops + + # Test multiple applications of an `Op` without an `Op.infer_shape` + ins_x3 = identity_noshape(identity_noshape(identity_noshape(x))) + h = function([x], ins_x3.shape, mode=mode) + xval = rng.standard_normal((6, 1, 2)).astype(config.floatX) + assert np.all(h(xval) == [6, 1, 2]) + h_ops = [node.op for node in h.maker.fgraph.toposort()] + assert len(h_ops) == 4 + assert identity_noshape not in h_ops + assert identity_shape not in h_ops + + def test_no_shapeopt(self): + """Test that a basic example works even when `ShapeOpt` is excluded.""" + X = matrix() + expr = X.shape[0] + + mode = get_default_mode().excluding("ShapeOpt") + f = function([X], expr, mode=mode) + # FIXME: This is not a good test. + f([[1, 2], [2, 3]]) + + +class TestReshape: + def setup_method(self): + self.mode = rewrite_mode + self.op = Reshape + + def test_local_reshape(self): + a = fmatrix() + b = self.op(3)(a, [2, 3, 4]) + c = self.op(1)(b, [24]) + f = function([a], c, mode=self.mode) + topo = f.maker.fgraph.toposort() + assert sum(isinstance(node.op, self.op) for node in topo) == 1 + + # Check stack trace + assert check_stack_trace(f, ops_to_check=[self.op]) + + +class TestLocalUselessReshape: + def setup_method(self): + self.rng = np.random.default_rng(utt.fetch_seed()) + + def test_0(self): + mode = get_default_mode().including("local_useless_reshape") + i = iscalar("i") + m = at.mgrid[ + 0:i, + ] + f = function([i], m, mode=mode) + topo = f.maker.fgraph.toposort() + assert not any(isinstance(n.op, Reshape) for n in topo) + + def test_1(self): + x = matrix("x") + r = x.reshape(x.shape) + + m0 = get_default_mode() + m1 = m0.including("local_useless_reshape") + f1 = function([x], r, mode=m1) + topo = f1.maker.fgraph.toposort() + assert not any(isinstance(n.op, Reshape) for n in topo) + + m2 = m1.excluding("ShapeOpt") + f2 = function([x], r, mode=m2) + topo = f2.maker.fgraph.toposort() + assert not any(isinstance(n.op, Reshape) for n in topo) + + # We do not need tests checking that stack traces are copied over, + # because local_useless_reshape only removes nodes from the graph + + def test_2(self): + x = matrix("x") + r = x.reshape([Shape_i(i)(x) for i in range(x.ndim)]) + + m0 = get_default_mode() + m1 = m0.including("local_useless_reshape") + f1 = function([x], r, mode=m1) + topo = f1.maker.fgraph.toposort() + assert not any(isinstance(n.op, Reshape) for n in topo) + + m2 = m1.excluding("ShapeOpt") + f2 = function([x], r, mode=m2) + topo = f2.maker.fgraph.toposort() + assert not any(isinstance(n.op, Reshape) for n in topo) + + def test_m1(self): + x = matrix("x") + r = x.reshape((x.shape[0], -1)) + + m0 = get_default_mode() + m1 = m0.including("local_useless_reshape") + f1 = function([x], r, mode=m1) + topo = f1.maker.fgraph.toposort() + assert not any(isinstance(n.op, Reshape) for n in topo) + + m2 = m1.excluding("ShapeOpt") + f2 = function([x], r, mode=m2) + topo = f2.maker.fgraph.toposort() + assert not any(isinstance(n.op, Reshape) for n in topo) + + +class TestLocalReshapeToDimshuffle: + def setup_method(self): + self.rng = np.random.default_rng(utt.fetch_seed()) + + def test_1(self): + reshape_lift = out2in(local_reshape_to_dimshuffle) + useless_reshape = out2in(local_useless_reshape) + x = shared(self.rng.standard_normal((4,))) + y = shared(self.rng.standard_normal((5, 6))) + reshape_x = reshape(x, (1, 4)) + reshape_y = reshape(y, (1, 5, 1, 6, 1, 1)) + + g = FunctionGraph([x, y], [reshape_x, reshape_y]) + assert str(g) == ( + "FunctionGraph(Reshape{2}" + "(, " + "TensorConstant{[1 4]}), " + "Reshape{6}" + "(, " + "TensorConstant{[1 5 1 6 1 1]}))" + ) + + reshape_lift.rewrite(g) + useless_reshape.rewrite(g) + assert str(g) == ( + "FunctionGraph(InplaceDimShuffle{x,0}" + "(), " + "InplaceDimShuffle{x,0,x,1,x,x}" + "(Reshape{2}(, " + "TensorConstant{[5 6]})))" + ) + + # Check stacktrace was copied over correctly after the rewrite was applied + assert check_stack_trace(g, ops_to_check=(DimShuffle, Reshape)) + + +def test_local_reshape_lift(): + x = tensor4() + out = exp(x).reshape([x.size]) + assert out.ndim == 1 + mode = get_default_mode() + mode = mode.including("local_reshape_lift") + f = function([x], out, mode=mode) + f(np.random.random((5, 4, 3, 2)).astype(config.floatX)) + topo = f.maker.fgraph.toposort() + assert isinstance(topo[-2].op, Reshape) + assert isinstance(topo[-1].op, Elemwise) + assert check_stack_trace(f, ops_to_check="last") + + +class TestShapeI(utt.InferShapeTester): + def setup_method(self): + super().setup_method() + + def test_perform(self): + rng = np.random.default_rng(utt.fetch_seed()) + + advec = vector() + advec_val = rng.random((3)).astype(config.floatX) + f = function([advec], Shape_i(0)(advec)) + out = f(advec_val) + utt.assert_allclose(out, advec_val.shape[0]) + + admat = matrix() + admat_val = rng.random((4, 3)).astype(config.floatX) + for i in range(2): + f = function([admat], Shape_i(i)(admat)) + out = f(admat_val) + utt.assert_allclose(out, admat_val.shape[i]) + + def test_infer_shape(self): + admat = matrix() + admat_val = np.random.random((3, 4)).astype(config.floatX) + self._compile_and_check([admat], [Shape_i(0)(admat)], [admat_val], Shape_i) + + self._compile_and_check([admat], [Shape_i(1)(admat)], [admat_val], Shape_i) + + +class TestSameShape: + def test_scalar(self): + x = scalar() + cst = at.constant(1) + o = x + cst + fgraph = FunctionGraph([x], [o], clone=False) + shape_feature = ShapeFeature() + fgraph.attach_feature(shape_feature) + assert shape_feature.same_shape(x, o) + + def test_vector(self): + x = vector() + cst = at.constant(1) + o = x + cst + fgraph = FunctionGraph([x], [o], clone=False) + shape_feature = ShapeFeature() + fgraph.attach_feature(shape_feature) + assert shape_feature.same_shape(x, o) + + def test_no_static_shapes(self): + x = vector() + y = vector() + o = x + y + fgraph = FunctionGraph([x, y], [o], clone=False) + shape_feature = ShapeFeature() + fgraph.attach_feature(shape_feature) + # We no longer assume that `x` has the same shape as `y` simply because + # neither has static shape information. Instead, when there is no + # static shape information is available, we assume that `x` and/or `y` + # could have shapes `(1,)` and/or `(n,)`, where `n != 1`, or any + # combination of the two. + assert not shape_feature.same_shape(x, o) + # The following case isn't implemented + assert not shape_feature.same_shape(y, o) + + @pytest.mark.parametrize( + "y_dim_0", + [2, pytest.param(None, marks=pytest.mark.xfail(reason="Not implemented"))], + ) + def test_vector_dim(self, y_dim_0): + x = at.tensor(dtype="floatX", shape=(2, None)) + y = at.tensor(dtype="floatX", shape=(y_dim_0, None)) + o = x + y + fgraph = FunctionGraph([x, y], [o], clone=False) + shape_feature = ShapeFeature() + fgraph.attach_feature(shape_feature) + assert shape_feature.same_shape(x, o, 0, 0) + assert not shape_feature.same_shape(x, o, 1, 1) + + def test_vector_dim_err(self): + x = vector() + y = vector() + o = x + y + fgraph = FunctionGraph([x, y], [o], clone=False) + shape_feature = ShapeFeature() + fgraph.attach_feature(shape_feature) + with pytest.raises(IndexError): + shape_feature.same_shape(x, o, 1, 0) + with pytest.raises(IndexError): + shape_feature.same_shape(x, o, 0, 1) + + +@pytest.mark.parametrize( + "shape", + [lscalar(), iscalar()], +) +def test_local_Shape_of_SpecifyShape(shape): + x = vector() + s = specify_shape(x, shape).shape + + fgraph = FunctionGraph(outputs=[s], clone=False) + _ = rewrite_graph(fgraph, clone=False) + + assert x not in fgraph.variables + assert shape in fgraph.variables + + +@pytest.mark.parametrize( + "s1", + [lscalar(), iscalar()], +) +def test_local_Shape_of_SpecifyShape_partial(s1): + x = matrix() + s = specify_shape(x, (s1, None)).shape + + fgraph = FunctionGraph(outputs=[s], clone=False) + assert any(isinstance(apply.op, SpecifyShape) for apply in fgraph.apply_nodes) + + _ = rewrite_graph(fgraph, clone=False) + + assert x in fgraph.variables + assert s1 in fgraph.variables + assert not any(isinstance(apply.op, SpecifyShape) for apply in fgraph.apply_nodes) + + +def test_local_Shape_i_of_broadcastable(): + x = tensor(np.float64, [False, True]) + s = Shape_i(1)(x) + + fgraph = FunctionGraph(outputs=[s], clone=False) + _ = rewrite_graph(fgraph, clone=False) + + assert x not in fgraph.variables + assert fgraph.outputs[0].data == 1 + + # A test for a non-`TensorType` + class MyType(Type): + ndim = 1 + + def filter(self, *args, **kwargs): + raise NotImplementedError() + + def __eq__(self, other): + return isinstance(other, MyType) and other.thingy == self.thingy + + class MyVariable(Variable): + pass + + x = MyVariable(MyType(), None, None) + s = Shape_i(0)(x) + fgraph = FunctionGraph(outputs=[s], clone=False) + _ = rewrite_graph(fgraph, clone=False) + + assert fgraph.outputs[0] == s + + +def test_Shape_i_canonicalize(): + """Make sure the canonicalizations work together to produce the correct graphs for shapes in a single dimension. + + In other words, ``shape(x)[i]`` should result in a simple ``Shape_i(0)(x)`` + and nothing else. The rewrites `local_shape_to_shape_i`, + `local_subtensor_remove_broadcastable_index`, and + `local_useless_dimshuffle_makevector` need to work together to accomplish + this, and we confirm that here. + """ + x = vector() + y = shape(x)[0] + + y_fg = FunctionGraph(outputs=[y], copy_inputs=False, features=[ShapeFeature()]) + + y_rewritten_fg = rewrite_graph( + y_fg, + clone=False, + include=[ + "canonicalize", + ], + ) + + y_rewritten = y_rewritten_fg.outputs[0] + + assert isinstance(y_rewritten.owner.op, Shape_i) + assert y_rewritten.owner.op.i == 0 + assert y_rewritten.owner.inputs[0] == x diff --git a/tests/tensor/test_subtensor_opt.py b/tests/tensor/rewriting/test_subtensor.py similarity index 86% rename from tests/tensor/test_subtensor_opt.py rename to tests/tensor/rewriting/test_subtensor.py index 0d7dd1ffe1..754dfc6995 100644 --- a/tests/tensor/test_subtensor_opt.py +++ b/tests/tensor/rewriting/test_subtensor.py @@ -10,15 +10,19 @@ from aesara.compile.ops import DeepCopyOp from aesara.configdefaults import config from aesara.graph.basic import Constant, Variable, ancestors -from aesara.graph.opt import check_stack_trace -from aesara.graph.opt_utils import optimize_graph -from aesara.graph.optdb import OptimizationQuery +from aesara.graph.rewriting.basic import check_stack_trace +from aesara.graph.rewriting.db import RewriteDatabaseQuery +from aesara.graph.rewriting.utils import rewrite_graph from aesara.graph.type import Type from aesara.raise_op import Assert from aesara.tensor import inplace from aesara.tensor.basic import Alloc, MakeVector, _convert_to_int8, make_vector from aesara.tensor.elemwise import DimShuffle, Elemwise from aesara.tensor.math import Dot, add, dot, exp, sqr +from aesara.tensor.rewriting.subtensor import ( + local_replace_AdvancedSubtensor, + local_subtensor_shape_constant, +) from aesara.tensor.shape import SpecifyShape, Unbroadcast, _shape, shape, specify_shape from aesara.tensor.subtensor import ( AdvancedIncSubtensor, @@ -27,13 +31,10 @@ AdvancedSubtensor1, IncSubtensor, Subtensor, + advanced_inc_subtensor1, inc_subtensor, set_subtensor, ) -from aesara.tensor.subtensor_opt import ( - local_replace_AdvancedSubtensor, - local_subtensor_shape_constant, -) from aesara.tensor.type import ( bmatrix, col, @@ -52,7 +53,7 @@ tensor4, vector, ) -from aesara.tensor.type_other import slicetype +from aesara.tensor.type_other import make_slice, slicetype from tests import unittest_tools as utt from tests.unittest_tools import create_aesara_param @@ -198,150 +199,219 @@ def test_local_useless_inc_subtensor_no_opt(): assert any(isinstance(n.op, IncSubtensor) for n in topo) -def test_local_useless_subtensor(): +class TestLocalUselessSubtensor: x = matrix("x") + s = aes.int32("s") + mode = mode_opt.including( + "local_useless_subtensor", "local_useless_AdvancedSubtensor1" + ) - # Test default - for dims in [ - (slice(0, None),), - (slice(0, None), slice(0, None)), - ]: - f = function([x], exp(x).__getitem__(dims), mode=mode_opt) + @pytest.mark.parametrize( + "idx", + [ + (slice(0, None),), + (slice(0, None), slice(0, None)), + ], + ) + def test_local_useless_subtensor_1(self, idx): + f = function([self.x], exp(self.x).__getitem__(idx), mode=self.mode) prog = f.maker.fgraph.toposort() assert prog[0].op == exp assert len(prog) == 1 - f([[0, 1, 2], [3, 4, 5]]) # let debugmode test something - - x_c = specify_shape(x, (2, 3)) - # Test constant - for dims, res in [ - ((slice(0, 2),), True), - ((slice(0, 2), slice(0, None)), True), - ((slice(0, 2), slice(0, 3)), True), - ((slice(0, None), slice(0, 3)), True), - ((slice(0, 3), slice(0, 13)), True), - ((slice(0, 3), slice(0, 2)), False), - ((slice(0, 1), slice(0, None)), False), - ((slice(0, 1), 1), False), - ]: - f = function([x], exp(x_c).__getitem__(dims), mode=mode_opt) + + x_val = np.array([[0, 1, 2], [3, 4, 5]], dtype=aesara.config.floatX) + idx_val = idx + exp_res = np.exp(x_val)[idx_val] + res = f(x_val) + assert np.allclose(res, exp_res) + + @pytest.mark.parametrize( + "idx, res", + [ + ((slice(0, 2),), True), + ((slice(0, 2), slice(0, None)), True), + ((slice(0, 2), slice(0, 3)), True), + ((slice(0, None), slice(0, 3)), True), + ((slice(0, 3), slice(0, 13)), True), + ((slice(0, 3), slice(0, 2)), False), + ((slice(0, 1), slice(0, None)), False), + ((slice(0, 1), 1), False), + ], + ) + def test_local_useless_subtensor_2(self, idx, res): + x_c = specify_shape(self.x, (2, 3)) + f = function([self.x], exp(x_c).__getitem__(idx), mode=self.mode) prog = f.maker.fgraph.toposort() if res: - assert isinstance(prog[0].op, SpecifyShape), dims - assert prog[1].op == exp, (dims, prog) - assert len(prog) == 2, dims + assert isinstance(prog[0].op, SpecifyShape) + assert prog[1].op == exp + assert len(prog) == 2 else: assert any(isinstance(node.op, Subtensor) for node in prog) - f([[0, 1, 2], [3, 4, 5]]) # let debugmode test something - # Test Variable - for idx, (dims, res) in enumerate( + x_val = np.array([[0, 1, 2], [3, 4, 5]], dtype=aesara.config.floatX) + idx_val = idx + exp_res = np.exp(x_val)[idx_val] + res = f(x_val) + assert np.allclose(res, exp_res) + + @pytest.mark.parametrize( + "idx_fn, res", [ - ((slice(0, x.shape[0]),), True), - ((slice(0, x.shape[1]),), False), + (lambda x: (slice(0, x.shape[0]),), True), + (lambda x: (slice(0, x.shape[1]),), False), ( - ( + lambda x: ( slice(0, x.shape[0]), slice(0, x.shape[1]), ), True, ), ( - ( + lambda x: ( slice(0, x.shape[0]), slice(0, x.shape[0]), ), False, ), ( - ( + lambda x: ( slice(0, x.shape[1]), slice(0, x.shape[0]), ), False, ), ( - ( + lambda x: ( slice(0, x.shape[1]), slice(0, x.shape[1]), ), False, ), - ((slice(0, x.shape[1]), 2), False), + (lambda x: (slice(0, x.shape[1]), 2), False), ( - ( + lambda x: ( slice(0, x.shape[1]), slice(x.shape[0] - x.shape[0], x.shape[1]), ), False, ), - ((slice(0, at.scalar_from_tensor(x.shape[0])),), True), - ] - ): - f = function([x], exp(x).__getitem__(dims), mode=mode_opt) + ( + lambda x: ( + slice( + 0, + at.scalar_from_tensor(x.shape[0]) + if isinstance(x, Variable) + else x.shape[0], + ), + ), + True, + ), + ], + ) + def test_local_useless_subtensor_3(self, idx_fn, res): + idx = idx_fn(self.x) + f = function([self.x], exp(self.x).__getitem__(idx), mode=self.mode) prog = f.maker.fgraph.toposort() if res: - assert prog[0].op == exp, dims - assert len(prog) == 1, dims + assert prog[0].op == exp + assert len(prog) == 1 else: assert any(isinstance(node.op, Subtensor) for node in prog) - f([[0, 1, 2], [3, 4, 5]]) # let debugmode test something - # Test mix Variable and Constant - # Currently not supported - for idx, (dims, res) in enumerate( + + x_val = np.array([[0, 1, 2], [3, 4, 5]], dtype=aesara.config.floatX) + idx_val = idx_fn(x_val) + exp_res = np.exp(x_val)[idx_val] + res = f(x_val) + assert np.allclose(res, exp_res) + + @pytest.mark.parametrize( + "idx_fn, res", [ - ((slice(0, x.shape[0]), slice(0, 3)), False), - ((slice(0, 3), slice(0, x.shape[1])), False), - ] - ): - f = function([x], exp(x_c).__getitem__(dims), mode=mode_opt) + (lambda x: (slice(0, x.shape[0]), slice(0, 3)), False), + (lambda x: (slice(0, 3), slice(0, x.shape[1])), False), + ], + ) + def test_local_useless_subtensor_4(self, idx_fn, res): + # Test mix Variable and Constant + # Currently not supported + x_c = specify_shape(self.x, (2, 3)) + idx = idx_fn(self.x) + f = function([self.x], exp(x_c).__getitem__(idx), mode=self.mode) prog = f.maker.fgraph.toposort() if res: - assert prog[0].op == exp, dims - assert len(prog) == 1, dims + assert prog[0].op == exp + assert len(prog) == 1 else: assert any(isinstance(node.op, Subtensor) for node in prog) - f([[0, 1, 2], [3, 4, 5]]) # let debugmode test something - # Test scalar variable - s = aes.int32("s") - for idx, (dims, res) in enumerate( + x_val = np.array([[0, 1, 2], [3, 4, 5]], dtype=aesara.config.floatX) + idx_val = idx_fn(x_val) + exp_res = np.exp(x_val)[idx_val] + res = f(x_val) + assert np.allclose(res, exp_res) + + @pytest.mark.parametrize( + "idx_fn, res", [ - ((slice(0, s),), False), - ] - ): - f = function([x, s], exp(x).__getitem__(dims), mode=mode_opt) + (lambda s: (slice(0, s),), False), + ], + ) + def test_local_useless_subtensor_5(self, idx_fn, res): + # Test scalar variable + idx = idx_fn(self.s) + f = function([self.x, self.s], exp(self.x).__getitem__(idx), mode=mode_opt) + prog = f.maker.fgraph.toposort() if res: - assert prog[0].op == exp, dims - assert len(prog) == 1, dims + assert prog[0].op == exp + assert len(prog) == 1 else: assert any(isinstance(node.op, Subtensor) for node in prog) - f([[1, 2, 3], [4, 5, 6]], 1) - f([[1, 2, 3], [4, 5, 6]], 3) - - # Test AdvancedSubtensor1 case when all rows are selected by a list/vector - # or ARange op - for dims, res in ( - ([0, 1], True), - ([1, 0], False), - ([0, 0], False), - ([0, 0, 1], False), - (at.arange(2), True), - (at.arange(0, 2), True), - (at.arange(0, 2, 2), False), - (at.arange(0, 2, -1), False), - (at.arange(1, 2), False), - ): - f = function([x], exp(x_c).__getitem__(dims), mode=mode_opt) + + x_val = np.array([[0, 1, 2], [3, 4, 5]], dtype=aesara.config.floatX) + idx_val = idx_fn(1) + exp_res = np.exp(x_val)[idx_val] + res = f(x_val, 1) + assert np.allclose(res, exp_res) + + idx_val = idx_fn(3) + exp_res = np.exp(x_val)[idx_val] + res = f(x_val, 3) + assert np.allclose(res, exp_res) + + @pytest.mark.parametrize( + "idx, res", + [ + ([0, 1], True), + ([1, 0], False), + ([0, 0], False), + ([0, 0, 1], False), + (at.arange(2), True), + (at.arange(0, 2), True), + (at.arange(0, 2, 2), False), + (at.arange(0, 2, -1), False), + (at.arange(1, 2), False), + ], + ) + def test_local_useless_subtensor_6(self, idx, res): + # Test AdvancedSubtensor1 case when all rows are selected by a list/vector + # or ARange op + x_c = specify_shape(self.x, (2, 3)) + f = function([self.x], exp(x_c).__getitem__(idx), mode=mode_opt) prog = f.maker.fgraph.toposort() if res: - assert isinstance(prog[0].op, SpecifyShape), dims - assert prog[1].op == exp, dims - assert len(prog) == 2, dims + assert isinstance(prog[0].op, SpecifyShape) + assert prog[1].op == exp + assert len(prog) == 2 else: assert any(isinstance(node.op, AdvancedSubtensor1) for node in prog) - f([[0, 1, 2], [3, 4, 5]]) # let debugmode test something + + x_val = np.array([[0, 1, 2], [3, 4, 5]], dtype=aesara.config.floatX) + idx_val = idx.eval() if isinstance(idx, Variable) else idx + exp_res = np.exp(x_val)[idx_val] + res = f(x_val) + assert np.allclose(res, exp_res) def test_local_subtensor_remove_broadcastable_index(): @@ -1489,7 +1559,7 @@ def test_stack_trace(self): assert check_stack_trace(f, ops_to_check=(Assert, aes.Cast)) -class TestAllocZero: +class TestSubtensorAllocRewrites: def setup_method(self): mode = get_default_mode() self.mode = mode.including( @@ -1783,207 +1853,6 @@ def test_local_set_to_inc_subtensor(): assert check_stack_trace(f2, ops_to_check="all") -class TestLocalElemwiseAlloc: - dtype = config.floatX - - def setup_method(self): - self.fast_compile_mode = get_mode("FAST_COMPILE") - self.fast_run_mode = get_mode("FAST_RUN") - - self.vec = vector("vec", dtype=self.dtype) - self.mat = matrix("mat", dtype=self.dtype) - self.tens = tensor3("tens", dtype=self.dtype) - - self.alloc_wo_dep = at.alloc(self.vec, 2, 2) - self.alloc_wo_dep_broad = at.alloc(self.vec, 1, 2) - self.alloc_w_dep = at.alloc(self.vec, *self.mat.shape) - self.alloc_w_dep_broad = at.alloc(self.vec, 1, *self.mat.shape) - self.alloc_w_dep_broad2 = at.alloc( - self.vec, self.mat.shape[0], self.mat.shape[1], 1 - ) - self.alloc_w_dep_tens = at.alloc( - self.vec, self.tens.shape[0], self.tens.shape[1] - ) - self.tv_wo_dep = at.alloc(self.vec, 5, 5) - self.tm_wo_dep = at.alloc(self.mat, 5, 5, 5) - self.s = iscalar("s") - self.tv_w_dep = at.alloc(self.vec, self.s, self.s) - self.tm_w_dep = at.alloc(self.mat, 5, 5, 5) - self.row = row(dtype=self.dtype) - self.o = at.alloc(self.row, 5, 5) - - def _verify_alloc_count(self, f, count): - assert ( - sum( - isinstance(elem.op, Alloc) - for elem in f.maker.fgraph.toposort() - if elem.op is not None - ) - == count - ) - - def _verify_assert_count(self, f, count): - assert ( - sum( - isinstance(elem.op, Assert) - for elem in f.maker.fgraph.toposort() - if elem.op is not None - ) - == count - ) - - def test_remove_alloc_wo_dimshuffle(self): - # Exclude local_useless_alloc, since it does not introduce - # assert in all the same cases. - self.fast_run_mode = self.fast_run_mode.excluding( - "local_useless_alloc", "local_alloc_sink_dimshuffle" - ) - # No optimization on alloc - func = function( - [self.vec, self.mat], - self.alloc_wo_dep + self.mat, - mode=self.fast_compile_mode, - ) - self._verify_alloc_count(func, 1) - self._verify_assert_count(func, 0) - # Check stacktrace was copied over correctly after opt was applied - assert check_stack_trace(func, ops_to_check="all") - - # Optimization on alloc with assert - func = function( - [self.vec, self.mat], self.alloc_wo_dep + self.mat, mode=self.fast_run_mode - ) - self._verify_alloc_count(func, 0) - self._verify_assert_count(func, 1) - - # Optimization on alloc with assert and broadcast - func = function( - [self.vec, self.mat], - self.alloc_wo_dep_broad + self.mat, - mode=self.fast_run_mode, - ) - self._verify_alloc_count(func, 0) - self._verify_assert_count(func, 1) - - # No optimization on alloc without assert - func = function( - [self.vec, self.mat], - self.alloc_w_dep + self.mat, - mode=self.fast_compile_mode, - ) - self._verify_alloc_count(func, 1) - self._verify_assert_count(func, 0) - - # Optimization on alloc without assert - func = function( - [self.vec, self.mat], self.alloc_w_dep + self.mat, mode=self.fast_run_mode - ) - self._verify_alloc_count(func, 0) - self._verify_assert_count(func, 0) - - # Optimization on alloc without assert and with broadcast - func = function( - [self.vec, self.mat], - self.alloc_w_dep_broad + self.mat, - mode=self.fast_run_mode, - ) - self._verify_alloc_count(func, 0) - self._verify_assert_count(func, 0) - - # Not optimized case on alloc and with broadcast - func = function( - [self.vec, self.mat], - self.alloc_w_dep_broad2 + self.mat, - mode=self.fast_run_mode, - ) - self._verify_alloc_count(func, 1) - self._verify_assert_count(func, 0) - - def test_remove_alloc_w_dimshuffle(self): - # No optimization on dimshuffle with assert - func = function( - [self.vec, self.tens], - self.alloc_wo_dep.dimshuffle(0, 1, "x") + self.tens, - mode=self.fast_compile_mode, - ) - self._verify_alloc_count(func, 1) - self._verify_assert_count(func, 0) - - # Optimization on dimshuffle with assert - func = function( - [self.vec, self.tens], - self.alloc_wo_dep.dimshuffle(0, 1, "x") + self.tens, - mode=self.fast_run_mode, - ) - self._verify_alloc_count(func, 0) - self._verify_assert_count(func, 1) - - # No optimization on dimshuffle without assert - func = function( - [self.vec, self.tens], - self.alloc_w_dep_tens.dimshuffle(0, 1, "x") + self.tens, - mode=self.fast_compile_mode, - ) - self._verify_alloc_count(func, 1) - self._verify_assert_count(func, 0) - - # Optimization on dimshuffle without assert - func = function( - [self.vec, self.tens], - self.alloc_w_dep_tens.dimshuffle(0, 1, "x") + self.tens, - mode=self.fast_run_mode, - ) - self._verify_alloc_count(func, 0) - self._verify_assert_count(func, 0) - - def test_multi_input_single_alloc(self): - # No optimization on dimshuffle with assert - func = function( - [self.vec, self.mat], - self.tv_wo_dep + self.tm_wo_dep, - mode=self.fast_compile_mode, - ) - self._verify_alloc_count(func, 2) - self._verify_assert_count(func, 0) - - # Optimization on dimshuffle with assert - func = function( - [self.vec, self.mat], - self.tv_wo_dep + self.tm_wo_dep, - mode=self.fast_run_mode, - ) - self._verify_alloc_count(func, 1) - self._verify_assert_count(func, 0) - - # No optimization on dimshuffle without assert - func = function( - [self.vec, self.mat, self.s], - self.tv_w_dep + self.tm_w_dep, - mode=self.fast_compile_mode, - ) - self._verify_alloc_count(func, 2) - self._verify_assert_count(func, 0) - - # Optimization on dimshuffle without assert - func = function( - [self.vec, self.mat, self.s], - self.tv_w_dep + self.tm_w_dep, - mode=self.fast_run_mode, - ) - self._verify_alloc_count(func, 1) - self._verify_assert_count(func, 1) - - def test_error(self): - t3fft = tensor(dtype=self.dtype, shape=(False, False, True)) - o = self.o.dimshuffle(0, 1, "x") + t3fft - func = function([t3fft, self.row], o, mode=self.fast_run_mode) - self._verify_alloc_count(func, 0) - self._verify_assert_count(func, 1) - d = np.random.random((5, 5, 1)).astype(self.dtype) - r = np.random.random((1, 5)).astype(self.dtype) - func(d, r) - - def test_local_subtensor_of_alloc(): # DebugMode should detect if something goes wrong. @@ -2039,7 +1908,7 @@ def test_local_subtensor_shape_constant(): assert res.data == 1 # Make sure it's part of the canonicalizations - res = optimize_graph(x) + res = rewrite_graph(x) assert isinstance(res, Constant) assert res.data == 1 @@ -2126,14 +1995,16 @@ def test_local_subtensor_SpecifyShape_lift(x, s, idx, x_val, s_val): y = specify_shape(x, s)[idx] assert isinstance(y.owner.inputs[0].owner.op, SpecifyShape) - opts = OptimizationQuery(include=[None]) - no_opt_mode = Mode(optimizer=opts) + rewrites = RewriteDatabaseQuery(include=[None]) + no_rewrites_mode = Mode(optimizer=rewrites) - y_val_fn = function([x] + list(s), y, on_unused_input="ignore", mode=no_opt_mode) + y_val_fn = function( + [x] + list(s), y, on_unused_input="ignore", mode=no_rewrites_mode + ) y_val = y_val_fn(*([x_val] + [s_ for s_ in s_val])) # This optimization should appear in the canonicalizations - y_opt = optimize_graph(y, clone=False) + y_opt = rewrite_graph(y, clone=False) if y.ndim == 0: # SpecifyShape should be removed altogether @@ -2172,7 +2043,7 @@ def test_local_subtensor_SpecifyShape_lift_fail(x, s, idx): y = specify_shape(x, s)[idx] # This optimization should appear in the canonicalizations - y_opt = optimize_graph(y, clone=False) + y_opt = rewrite_graph(y, clone=False) assert not isinstance(y_opt.owner.op, SpecifyShape) @@ -2284,3 +2155,152 @@ def test_local_join_subtensors(axis, slices_fn, expected_nodes): f_val = np.concatenate([x_val[slice_val] for slice_val in slices_val], axis=axis) np.testing.assert_array_equal(f(x_val, stop_val), f_val) + + +def test_deprecations(): + """Make sure we can import from deprecated modules.""" + with pytest.deprecated_call(): + from aesara.tensor.subtensor_opt import get_advsubtensor_axis # noqa: F401 F811 + + +def test_local_uint_constant_indices(): + mode = get_default_mode().including("specialize", "local_uint_constant_indices") + rng = np.random.default_rng(20900) + + # Subtensor, don't convert + x = at.vector("x") + idx = at.as_tensor_variable(np.array(-1, np.int64)) + z = x[idx] + + z_fn = aesara.function([x], z, mode=mode) + + deepcopy_node = z_fn.maker.fgraph.outputs[0].owner + subtensor_node = deepcopy_node.inputs[0].owner + assert isinstance(subtensor_node.op, Subtensor) + new_index = subtensor_node.inputs[1] + assert isinstance(new_index, Constant) + assert new_index.type.dtype == "int64" + + # `Subtensor`, one index, convert + x = at.vector("x") + idx = at.as_tensor_variable(np.array(1, np.int64)) + z = x[idx] + + z_fn = aesara.function([x], z, mode=mode) + + deepcopy_node = z_fn.maker.fgraph.outputs[0].owner + subtensor_node = deepcopy_node.inputs[0].owner + assert isinstance(subtensor_node.op, Subtensor) + new_index = subtensor_node.inputs[1] + assert isinstance(new_index, Constant) + assert new_index.type.dtype == "uint8" + + # `Subtensor`, two indices, one slice, convert + x = at.matrix("x") + indices = (at.as_tensor_variable(np.array(1, np.int64)), slice(None, 10)) + z = x[indices] + + z_fn = aesara.function([x], z, mode=mode) + + deepcopy_node = z_fn.maker.fgraph.outputs[0].owner + subtensor_node = deepcopy_node.inputs[0].owner + assert isinstance(subtensor_node.op, Subtensor) + new_index = subtensor_node.inputs[1] + assert isinstance(new_index, Constant) + assert new_index.type.dtype == "uint8" + + # `AdvancedSubtensor`, two indices, one symbolic slice, convert + x = at.matrix("x") + indices = ( + at.as_tensor_variable(np.array(1, np.int64)), + make_slice(slice(None, 10)), + ) + z = x[indices] + + z_fn = aesara.function([x], z, mode=mode) + + subtensor_node = z_fn.maker.fgraph.outputs[0].owner + assert isinstance(subtensor_node.op, AdvancedSubtensor) + new_index = subtensor_node.inputs[1] + assert isinstance(new_index, Constant) + assert new_index.type.dtype == "uint8" + + # `AdvancedSubtensor1`, convert + x = at.vector("x") + idx = at.as_tensor_variable(rng.integers(0, 10, size=10).astype(np.int64)) + z = x[idx] + + z_fn = aesara.function([x], z, mode=mode) + + subtensor_node = z_fn.maker.fgraph.outputs[0].owner + assert isinstance(subtensor_node.op, AdvancedSubtensor1) + new_index = subtensor_node.inputs[1] + assert isinstance(new_index, Constant) + assert new_index.type.dtype == "uint8" + + # AdvancedSubtensor, empty, convert + x = at.matrix("x") + idx = at.as_tensor_variable(1, dtype=np.int64) + z = x[idx, []] + + z_fn = aesara.function([x], z, mode=mode) + + subtensor_node = z_fn.maker.fgraph.outputs[0].owner + assert isinstance(subtensor_node.op, AdvancedSubtensor) + new_index = subtensor_node.inputs[1] + assert isinstance(new_index, Constant) + assert new_index.type.dtype == "uint8" + + # AdvancedSubtensor, bool, don't convert + x = at.matrix("x") + idx = at.as_tensor_variable(np.array([True]), dtype=bool) + z = x[idx, []] + + z_fn = aesara.function([x], z, mode=mode) + + subtensor_node = z_fn.maker.fgraph.outputs[0].owner + assert isinstance(subtensor_node.op, AdvancedSubtensor) + new_index = subtensor_node.inputs[1] + assert isinstance(new_index, Constant) + assert new_index.type.dtype == "bool" + + # `IncSubtensor`, convert + x = at.vector("x") + y = at.scalar("y") + idx = at.as_tensor_variable(1, dtype=np.int64) + z = inc_subtensor(x[idx], y) + + z_fn = aesara.function([x, y], z, mode=mode) + + subtensor_node = z_fn.maker.fgraph.outputs[0].owner + assert isinstance(subtensor_node.op, IncSubtensor) + new_index = subtensor_node.inputs[2] + assert isinstance(new_index, Constant) + assert new_index.type.dtype == "uint8" + + # `AdvancedIncSubtensor1`, convert + x = at.vector("x") + y = at.vector("y") + idx = at.as_tensor_variable(rng.integers(0, 10, size=10).astype(np.int64)) + z = advanced_inc_subtensor1(x, y, idx) + + z_fn = aesara.function([x, y], z, mode=mode) + + subtensor_node = z_fn.maker.fgraph.outputs[0].owner + assert isinstance(subtensor_node.op, AdvancedIncSubtensor1) + new_index = subtensor_node.inputs[2] + assert isinstance(new_index, Constant) + assert new_index.type.dtype == "uint8" + + # `AdvancedIncSubtensor1`, convert + x = at.vector("x") + idx = at.as_tensor_variable(rng.integers(0, 10, size=10).astype(np.int64)) + z = x[idx, None] + + z_fn = aesara.function([x], z, mode=mode) + + subtensor_node = z_fn.maker.fgraph.outputs[0].owner + assert isinstance(subtensor_node.op, AdvancedSubtensor) + new_index = subtensor_node.inputs[1] + assert isinstance(new_index, Constant) + assert new_index.type.dtype == "uint8" diff --git a/tests/tensor/test_opt_uncanonicalize.py b/tests/tensor/rewriting/test_uncanonicalize.py similarity index 94% rename from tests/tensor/test_opt_uncanonicalize.py rename to tests/tensor/rewriting/test_uncanonicalize.py index 784ddb1c87..0f0bcd8534 100644 --- a/tests/tensor/test_opt_uncanonicalize.py +++ b/tests/tensor/rewriting/test_uncanonicalize.py @@ -1,4 +1,5 @@ import numpy as np +import pytest import aesara import aesara.tensor as at @@ -6,14 +7,14 @@ from aesara import scalar as aes from aesara.configdefaults import config from aesara.graph.fg import FunctionGraph -from aesara.graph.opt import out2in +from aesara.graph.rewriting.basic import out2in from aesara.link.basic import PerformLinker from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise from aesara.tensor.math import MaxAndArgmax from aesara.tensor.math import max as at_max from aesara.tensor.math import max_and_argmax from aesara.tensor.math import min as at_min -from aesara.tensor.opt_uncanonicalize import ( +from aesara.tensor.rewriting.uncanonicalize import ( local_alloc_dimshuffle, local_dimshuffle_alloc, local_dimshuffle_subtensor, @@ -218,3 +219,11 @@ def test_local_dimshuffle_subtensor(): assert x[:, :, 0:3, ::-1].dimshuffle(0, 2, 3).eval( {x: np.ones((5, 1, 6, 7))} ).shape == (5, 3, 7) + + +def test_deprecations(): + """Make sure we can import from deprecated modules.""" + with pytest.deprecated_call(): + from aesara.tensor.opt_uncanonicalize import ( # noqa: F401 F811 + local_reshape_dimshuffle, + ) diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index f764c5aaf4..509b651085 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -1,12 +1,9 @@ import itertools -import warnings -from copy import copy, deepcopy from functools import partial from tempfile import mkstemp import numpy as np import pytest -from numpy.testing import assert_array_equal import aesara import aesara.scalar as aes @@ -63,6 +60,7 @@ join, make_vector, mgrid, + moveaxis, nonzero, nonzero_values, ogrid, @@ -90,7 +88,7 @@ ) from aesara.tensor.elemwise import DimShuffle from aesara.tensor.exceptions import NotScalarConstantError -from aesara.tensor.math import dense_dot, eq +from aesara.tensor.math import dense_dot from aesara.tensor.math import sum as at_sum from aesara.tensor.shape import Reshape, Shape, Shape_i, shape_padright, specify_shape from aesara.tensor.type import ( @@ -115,7 +113,6 @@ ivector, lscalar, lvector, - matrices, matrix, row, scalar, @@ -137,7 +134,6 @@ _good_broadcast_unary_normal, _grad_broadcast_unary_normal, eval_outputs, - get_numeric_types, inplace_func, integers, integers_ranged, @@ -150,6 +146,8 @@ ) +pytestmark = pytest.mark.filterwarnings("error") + if config.mode == "FAST_COMPILE": mode_opt = "FAST_RUN" else: @@ -868,7 +866,7 @@ def check(dtype, N, M_=None, k=0): assert np.allclose(result, np.tri(N, M_, k, dtype=dtype)) assert result.dtype == np.dtype(dtype) - for dtype in ALL_DTYPES: + for dtype in ["int32", "int64", "float32", "float64", "uint16", "complex64"]: check(dtype, 3) # M != N, k = 0 check(dtype, 3, 5) @@ -884,6 +882,10 @@ def check(dtype, N, M_=None, k=0): check(dtype, 5, 3, -1) def test_tril_triu(self): + """ + TODO FIXME: Parameterize this. + """ + def check_l(m, k=0): m_symb = matrix(dtype=m.dtype) k_symb = iscalar() @@ -936,7 +938,7 @@ def check_u_batch(m): assert np.allclose(result, np.triu(m, k)) assert result.dtype == np.dtype(dtype) - for dtype in ALL_DTYPES: + for dtype in ["int32", "int64", "float32", "float64", "uint16", "complex64"]: m = random_of_dtype((10, 10), dtype) check_l(m, 0) check_l(m, 1) @@ -1022,6 +1024,12 @@ def check(m): rand2d[:4] = 0 check(rand2d) + # Test passing a list + m = [1, 2, 0] + out = flatnonzero(m) + f = function([], out) + assert np.array_equal(f(), np.flatnonzero(m)) + @config.change_flags(compute_test_value="raise") def test_nonzero_values(self): def check(m): @@ -1069,8 +1077,9 @@ def test_can_use_numpy_types(self): f = function([x], y) assert f(np.array([1, 2], dtype=np.int32)).dtype == np.int64 - def test_good_between_real_types(self): - good = itertools.chain( + @pytest.mark.parametrize( + "test_name, obj_dtype", + itertools.chain( multi_dtype_cast_checks((2,), dtypes=REAL_DTYPES), # Casts from foo to foo [ @@ -1080,32 +1089,34 @@ def test_good_between_real_types(self): ) for dtype in ALL_DTYPES ], - ) - for testname, (obj, dtype) in good: - inp = vector(dtype=obj.dtype) - out = cast(inp, dtype=dtype) - f = function([inp], out) - assert f(obj).dtype == np.dtype(dtype) - - # Test astype too - out2 = inp.astype(dtype=dtype) - assert out2.type == out.type - - def test_cast_from_real_to_complex(self): - for real_dtype in REAL_DTYPES: - for complex_dtype in COMPLEX_DTYPES: - inp = vector(dtype=real_dtype) - out = cast(inp, dtype=complex_dtype) - f = function([inp], out) - obj = random_of_dtype((2,), real_dtype) - assert f(obj).dtype == np.dtype(complex_dtype) - - def test_cast_from_complex_to_real_raises_error(self): - for real_dtype in REAL_DTYPES: - for complex_dtype in COMPLEX_DTYPES: - inp = vector(dtype=real_dtype) - with pytest.raises(TypeError): - tensor(cast(inp, dtype=complex_dtype)) + ), + ) + def test_good_between_real_types(self, test_name, obj_dtype): + (obj, dtype) = obj_dtype + inp = vector(dtype=obj.dtype) + out = cast(inp, dtype=dtype) + f = function([inp], out) + assert f(obj).dtype == np.dtype(dtype) + + # Test astype too + out2 = inp.astype(dtype=dtype) + assert out2.type == out.type + + @pytest.mark.parametrize("real_dtype", REAL_DTYPES) + @pytest.mark.parametrize("complex_dtype", COMPLEX_DTYPES) + def test_cast_from_real_to_complex(self, real_dtype, complex_dtype): + inp = vector(dtype=real_dtype) + out = cast(inp, dtype=complex_dtype) + f = function([inp], out) + obj = random_of_dtype((2,), real_dtype) + assert f(obj).dtype == np.dtype(complex_dtype) + + @pytest.mark.parametrize("real_dtype", REAL_DTYPES) + @pytest.mark.parametrize("complex_dtype", COMPLEX_DTYPES) + def test_cast_from_complex_to_real_raises_error(self, real_dtype, complex_dtype): + inp = vector(dtype=complex_dtype) + with pytest.raises(TypeError): + cast(inp, dtype=real_dtype) # TODO: consider moving this function / functionality to gradient.py @@ -1113,46 +1124,6 @@ def test_cast_from_complex_to_real_raises_error(self): # gradient numerically -def test_nan_inf_constant_signature(): - # Test that the signature of a constant tensor containing NaN and Inf - # values is correct. - test_constants = [ - [np.nan, np.inf, 0, 1], - [np.nan, np.inf, -np.inf, 1], - [0, np.inf, -np.inf, 1], - [0, 3, -np.inf, 1], - [0, 3, np.inf, 1], - [np.nan, 3, 4, 1], - [0, 3, 4, 1], - np.nan, - np.inf, - -np.inf, - 0, - 1, - ] - n = len(test_constants) - # We verify that signatures of two rows i, j in the matrix above are - # equal if and only if i == j. - for i in range(n): - for j in range(n): - x = constant(test_constants[i]) - y = constant(test_constants[j]) - assert (x.signature() == y.signature()) == (i == j) - - # Also test that nan !=0 and nan != nan. - x = scalar() - mode = get_default_mode() - if isinstance(mode, aesara.compile.debugmode.DebugMode): - # Disable the check preventing usage of NaN / Inf values. - # We first do a copy of the mode to avoid side effects on other tests. - mode = copy(mode) - mode.check_isfinite = False - f = aesara.function([x], eq(x, np.nan), mode=mode) - - assert f(0) == 0 - assert f(np.nan) == 0 - - def test_basic_allclose(): # This was raised by a user in https://github.com/Theano/Theano/issues/2975 assert tm._allclose(-0.311023883434, -0.311022856884) @@ -1243,6 +1214,9 @@ def test_input_validation(self): with pytest.raises(TypeError, match=".*integer.*"): Split(2)(matrix(), dscalar(), [1, 1]) + with pytest.raises(TypeError, match=".*integer.*"): + Split(2)(matrix(), ivector(), [1, 1]) + with pytest.raises(TypeError, match=".*integer.*"): join(dscalar(), matrix(), matrix()) @@ -1321,16 +1295,9 @@ def test_stack_scalar_make_vector_constant(self): def test_stack_new_interface(self): # Test the new numpy-like interface: stack(tensors, axis=0). - # Testing against old interface - warnings.simplefilter("always", DeprecationWarning) a = imatrix("a") b = imatrix("b") - s1 = stack(a, b) - s2 = stack([a, b]) - f = function([a, b], [s1, s2], mode=self.mode) - v1, v2 = f([[1, 2]], [[3, 4]]) - assert v1.shape == v2.shape - assert np.all(v1 == v2) + # Testing axis parameter s3 = stack([a, b], 1) f = function([a, b], s3, mode=self.mode) @@ -1362,18 +1329,10 @@ def test_stack_new_interface(self): stack([a, b], 4) with pytest.raises(IndexError): stack([a, b], -4) + # Testing depreciation warning - with warnings.catch_warnings(record=True) as w: + with pytest.warns(DeprecationWarning): s = stack(a, b) - assert len(w) == 1 - assert issubclass(w[-1].category, DeprecationWarning) - with warnings.catch_warnings(record=True) as w: - s = stack([a, b]) - s = stack([a, b], 1) - s = stack([a, b], axis=1) - s = stack(tensors=[a, b]) - s = stack(tensors=[a, b], axis=1) - assert not w def test_stack_hessian(self): # Test the gradient of stack when used in hessian, see gh-1589 @@ -1500,6 +1459,13 @@ def test_roll(self): assert (out == want).all() + a = [1, 2, 3, 4, 5, 6] + b = roll(a, get_shift(2)) + want = np.array([5, 6, 1, 2, 3, 4]) + out = aesara.function([], b)() + + assert (out == want).all() + def test_stack_vector(self): a = self.shared(np.array([1, 2, 3], dtype=self.floatX)) b = as_tensor_variable(np.array([7, 8, 9], dtype=self.floatX)) @@ -1972,6 +1938,12 @@ def test_split_neg(self): with pytest.raises(ValueError): f() + def test_split_static_shape(self): + x = TensorType("floatX", shape=(5,))("x") + s = iscalar("s") + y = Split(2)(x, 0, [s, 5 - s])[0] + assert y.type.shape == (None,) + def test_join_inplace(): # Test join to work inplace. @@ -2071,17 +2043,17 @@ def test_ScalarFromTensor(cast_policy): scalar_from_tensor(vector()) -class TestOpCache: - def test_basic(self): - # trigger bug in ticket #162 - v = matrix() - v.name = "v" - gv = fill(v / v, 1.0) / v - (fill(v / v, 1.0) * v) / (v * v) - fn_py = inplace_func([v], gv) - fn_c_or_py = inplace_func([v], gv) +def test_op_cache(): + # TODO: What is this actually testing? + # trigger bug in ticket #162 + v = matrix() + v.name = "v" + gv = fill(v / v, 1.0) / v - (fill(v / v, 1.0) * v) / (v * v) + fn_py = inplace_func([v], gv) + fn_c_or_py = inplace_func([v], gv) - a = random(5, 2).astype(config.floatX) - assert np.all(fn_py(a) == fn_c_or_py(a)) + a = random(5, 2).astype(config.floatX) + assert np.all(fn_py(a) == fn_c_or_py(a)) def test_dimshuffle(): @@ -2221,6 +2193,11 @@ def test_is_flat(): def test_tile(): + """ + TODO FIXME: Split this apart and parameterize. Also, find out why it's + unreasonably slow. + """ + def run_tile(x, x_, reps, use_symbolic_reps): if use_symbolic_reps: rep_symbols = [iscalar() for _ in range(len(reps))] @@ -2258,6 +2235,20 @@ def run_tile(x, x_, reps, use_symbolic_reps): == np.tile(x_, (2, 3, 4, 6)) ) + # Test passing a float + x = scalar() + x_val = 1.0 + assert np.array_equal( + run_tile(x, x_val, (2,), use_symbolic_reps), np.tile(x_val, (2,)) + ) + + # Test when x is a list + x = matrix() + x_val = [[1.0, 2.0], [3.0, 4.0]] + assert np.array_equal( + run_tile(x, x_val, (2,), use_symbolic_reps), np.tile(x_val, (2,)) + ) + # Test when reps is integer, scalar or vector. # Test 1,2,3,4-dimensional cases. # Test input x has the shape [2], [2, 4], [2, 4, 3], [2, 4, 3, 5]. @@ -2602,96 +2593,69 @@ def test_default_start(self, cast_policy): fstop_v32 = np.float32(fstop_v) assert np.all(ff(fstop_v32) == np.arange(fstop_v)) - @pytest.mark.parametrize( - "cast_policy", - [ - "custom", - "numpy+floatX", - ], - ) - def test_upcast(self, cast_policy): + @config.change_flags(cast_policy="custom") + def test_upcast_custom(self): """Test that arange computes output type adequately.""" - with config.change_flags(cast_policy=cast_policy): - if config.cast_policy == "custom": - assert arange(iscalar()).dtype == "int64" - assert arange(fscalar()).dtype == fscalar().dtype - assert arange(dscalar()).dtype == dscalar().dtype + assert arange(iscalar()).dtype == "int64" + assert arange(fscalar()).dtype == fscalar().dtype + assert arange(dscalar()).dtype == dscalar().dtype - # int32 + float32 -> float64 - assert arange(iscalar(), fscalar()).dtype == dscalar().dtype - assert arange(iscalar(), dscalar()).dtype == dscalar().dtype - assert arange(fscalar(), dscalar()).dtype == dscalar().dtype + # int32 + float32 -> float64 + assert arange(iscalar(), fscalar()).dtype == dscalar().dtype + assert arange(iscalar(), dscalar()).dtype == dscalar().dtype + assert arange(fscalar(), dscalar()).dtype == dscalar().dtype - assert arange(iscalar(), fscalar(), dscalar()).dtype == dscalar().dtype - elif config.cast_policy == "numpy+floatX": - for dtype in get_numeric_types(): - # Test with a single argument. - arange_dtype = arange(scalar(dtype=str(dtype))).dtype - numpy_dtype = np.arange(np.array(1, dtype=dtype)).dtype - if ( - dtype != "float64" - and numpy_dtype == "float64" - and config.cast_policy == "numpy+floatX" - and config.floatX == "float32" - ): - # We want a float32 arange. - assert arange_dtype == "float32" - else: - # Follow numpy. - assert arange_dtype == numpy_dtype - - # Test with two arguments. - for stop_dtype in get_numeric_types(): - arange_dtype = arange( - start=scalar(dtype=str(dtype)), - stop=scalar(dtype=str(stop_dtype)), - ).dtype - numpy_dtype = np.arange( - start=np.array(0, dtype=dtype), - stop=np.array(1, dtype=stop_dtype), - ).dtype - if ( - dtype != "float64" - and stop_dtype != "float64" - and numpy_dtype == "float64" - and config.cast_policy == "numpy+floatX" - and config.floatX == "float32" - ): - # We want a float32 arange. - assert arange_dtype == "float32" - else: - # Follow numpy. - assert arange_dtype == numpy_dtype - - # Test with three arguments. - for step_dtype in get_numeric_types(): - arange_dtype = arange( - start=scalar(dtype=str(dtype)), - stop=scalar(dtype=str(stop_dtype)), - step=scalar(dtype=str(step_dtype)), - ).dtype - numpy_dtype = np.arange( - start=np.array(0, dtype=dtype), - stop=np.array(1, dtype=stop_dtype), - step=np.array(1, dtype=step_dtype), - ).dtype - if ( - dtype != "float64" - and stop_dtype != "float64" - and step_dtype != "float64" - and numpy_dtype == "float64" - and config.cast_policy == "numpy+floatX" - and config.floatX == "float32" - ): - # We want a float32 arange. - assert arange_dtype == "float32" - else: - # Follow numpy. - assert arange_dtype == numpy_dtype + assert arange(iscalar(), fscalar(), dscalar()).dtype == dscalar().dtype + + @pytest.mark.parametrize( + "dtype", [dtype for dtype in ALL_DTYPES if not dtype.startswith("complex")] + ) + @pytest.mark.parametrize( + "stop_dtype", [dtype for dtype in ALL_DTYPES if not dtype.startswith("complex")] + ) + @config.change_flags(cast_policy="numpy+floatX") + def test_upcast_numpy(self, dtype, stop_dtype): + """Make sure our `ARange` output dtypes match NumPy's under different casting policies.""" + # Test with a single argument. + arange_dtype = arange(scalar(dtype=str(dtype))).dtype + numpy_dtype = np.arange(np.array(1, dtype=dtype)).dtype + if ( + dtype != "float64" + and numpy_dtype == "float64" + and config.cast_policy == "numpy+floatX" + and config.floatX == "float32" + ): + # We want a float32 arange. + assert arange_dtype == "float32" + else: + # Follow numpy. + assert arange_dtype == numpy_dtype + + # Test with two arguments. + arange_dtype = arange( + start=scalar(dtype=str(dtype)), + stop=scalar(dtype=str(stop_dtype)), + ).dtype + numpy_dtype = np.arange( + start=np.array(0, dtype=dtype), + stop=np.array(1, dtype=stop_dtype), + ).dtype + + if ( + dtype != "float64" + and stop_dtype != "float64" + and numpy_dtype == "float64" + and config.cast_policy == "numpy+floatX" + and config.floatX == "float32" + ): + # We want a float32 arange. + assert arange_dtype == "float32" + else: + # Follow numpy. + assert arange_dtype == numpy_dtype def test_dtype_cache(self): - # Checks that the same Op is returned on repeated calls to arange - # using the same dtype, but not for different dtypes. + """Check that the same `Op` is returned on repeated calls to `ARange` using the same dtype.""" start, stop, step = iscalars("start", "stop", "step") out1 = arange(start, stop, step) @@ -2858,6 +2822,12 @@ def test_dim1(self): assert np.all(p_val[inv_val] == np.arange(10)) assert np.all(inv_val[p_val] == np.arange(10)) + # Test passing a list + p = [2, 4, 3, 0, 1] + inv = at.inverse_permutation(p) + f = aesara.function([], inv) + assert np.array_equal(f(), np.array([3, 4, 0, 2, 1])) + def test_dim2(self): # Test the inversion of several permutations at a time # Each row of p is a different permutation to inverse @@ -3048,18 +3018,8 @@ def test_default_state(): assert np.allclose(f(np.asarray(2.2, dtype=config.floatX)), 7) -def test_autocast(): - # Call test functions for all possible values of `config.cast_policy`. - for autocast_cfg in ( - "custom", - # 'numpy', # Commented out until it is implemented properly. - "numpy+floatX", - ): - with config.change_flags(cast_policy=autocast_cfg): - eval("_test_autocast_" + autocast_cfg.replace("+", "_"))() - - -def _test_autocast_custom(): +@config.change_flags(cast_policy="custom") +def test_autocast_custom(): # Called from `test_autocast`. assert config.cast_policy == "custom" orig_autocast = autocast_float.dtypes @@ -3109,10 +3069,10 @@ def _test_autocast_custom(): assert (fvector() + 1.0).dtype == "float32" assert (dvector() + np.float32(1.1)).dtype == "float64" assert (dvector() + np.float64(1.1)).dtype == "float64" - assert (dvector() + np.float(1.1)).dtype == "float64" + assert (dvector() + float(1.1)).dtype == "float64" assert (fvector() + np.float32(1.1)).dtype == "float32" assert (fvector() + np.float64(1.1)).dtype == "float64" - assert (fvector() + np.float(1.1)).dtype == config.floatX + assert (fvector() + float(1.1)).dtype == config.floatX assert (lvector() + np.int64(1)).dtype == "int64" assert (lvector() + np.int32(1)).dtype == "int64" assert (lvector() + np.int16(1)).dtype == "int64" @@ -3124,7 +3084,9 @@ def _test_autocast_custom(): assert (fvector() + 1.0).dtype == "float64" -def _test_autocast_numpy(): +@pytest.mark.skip(reason="Not implemented") +@config.change_flags(cast_policy="numpy") +def test_autocast_numpy(): # Called from `test_autocast`. assert config.cast_policy == "numpy" # Go through some typical scalar values. @@ -3144,7 +3106,8 @@ def ok(z): ok(n_x) -def _test_autocast_numpy_floatX(): +@config.change_flags(cast_policy="numpy+floatX") +def test_autocast_numpy_floatX(): # Called from `test_autocast`. assert config.cast_policy == "numpy+floatX" @@ -3520,6 +3483,12 @@ def test_diag(self): with pytest.raises(ValueError): diag(xx) + # Test passing a list + xx = [[1, 2], [3, 4]] + g = diag(xx) + f = function([], g) + assert np.array_equal(f(), np.diag(xx)) + def test_infer_shape(self): rng = np.random.default_rng(utt.fetch_seed()) @@ -3625,30 +3594,6 @@ def test_alloc_diag_values(self): assert np.all(true_grad_input == grad_input) -class TestNumpyAssumptions: - # Verify that some assumptions Aesara makes on Numpy's behavior still hold. - def test_ndarray_copy(self): - # A copy or deepcopy of the ndarray type should not create a new object. - # - # This is because Aesara makes some comparisons of the form: - # if type(x) is np.ndarray - assert copy(np.ndarray) is np.ndarray - assert deepcopy(np.ndarray) is np.ndarray - - def test_dtype_equality(self): - # Ensure dtype string comparisons are consistent. - # - # Aesara often uses string representations of dtypes (e.g. 'float32'). We - # need to make sure that comparing the string representations is the same - # as comparing the dtype objects themselves. - dtypes = get_numeric_types(with_complex=True) - # Perform all pairwise comparisons of dtypes, making sure comparing - # their string representation yields the same result. - for dtype1_idx, dtype1 in enumerate(dtypes): - for dtype2 in dtypes[dtype1_idx + 1 :]: - assert (dtype1 == dtype2) == (str(dtype1) == str(dtype2)) - - def test_transpose(): x1 = dvector("x1") x2 = dmatrix("x2") @@ -4004,75 +3949,6 @@ def test_ARange(self): ) -class TestTensorInstanceMethods: - def setup_method(self): - self.vars = matrices("X", "Y") - self.vals = [m.astype(config.floatX) for m in [random(2, 2), random(2, 2)]] - - def test_repeat(self): - X, _ = self.vars - x, _ = self.vals - assert_array_equal(X.repeat(2).eval({X: x}), x.repeat(2)) - - def test_trace(self): - X, _ = self.vars - x, _ = self.vals - assert_array_equal(X.trace().eval({X: x}), x.trace()) - - def test_ravel(self): - X, _ = self.vars - x, _ = self.vals - assert_array_equal(X.ravel().eval({X: x}), x.ravel()) - - def test_diagonal(self): - X, _ = self.vars - x, _ = self.vals - assert_array_equal(X.diagonal().eval({X: x}), x.diagonal()) - assert_array_equal(X.diagonal(1).eval({X: x}), x.diagonal(1)) - assert_array_equal(X.diagonal(-1).eval({X: x}), x.diagonal(-1)) - for offset, axis1, axis2 in [(1, 0, 1), (-1, 0, 1), (0, 1, 0), (-2, 1, 0)]: - assert_array_equal( - X.diagonal(offset, axis1, axis2).eval({X: x}), - x.diagonal(offset, axis1, axis2), - ) - - def test_take(self): - X, _ = self.vars - x, _ = self.vals - indices = [1, 0, 3] - assert_array_equal(X.take(indices).eval({X: x}), x.take(indices)) - indices = [1, 0, 1] - assert_array_equal(X.take(indices, 1).eval({X: x}), x.take(indices, 1)) - indices = np.array([-10, 5, 12], dtype="int32") - assert_array_equal( - X.take(indices, 1, mode="wrap").eval({X: x}), - x.take(indices, 1, mode="wrap"), - ) - assert_array_equal( - X.take(indices, -1, mode="wrap").eval({X: x}), - x.take(indices, -1, mode="wrap"), - ) - assert_array_equal( - X.take(indices, 1, mode="clip").eval({X: x}), - x.take(indices, 1, mode="clip"), - ) - assert_array_equal( - X.take(indices, -1, mode="clip").eval({X: x}), - x.take(indices, -1, mode="clip"), - ) - # Test error handling - with pytest.raises(IndexError): - X.take(indices).eval({X: x}) - with pytest.raises(IndexError): - (2 * X.take(indices)).eval({X: x}) - with pytest.raises(TypeError): - X.take([0.0]) - indices = [[1, 0, 1], [0, 1, 1]] - assert_array_equal(X.take(indices, 1).eval({X: x}), x.take(indices, 1)) - # Test equivalent advanced indexing - assert_array_equal(X[:, indices].eval({X: x}), x[:, indices]) - - class TestSwapaxes: def test_no_dimensional_input(self): with pytest.raises(IndexError): @@ -4109,6 +3985,23 @@ def test_numpy_compare(self): assert np.allclose(n_s, t_s) +def test_moveaxis(): + x = at.zeros((3, 4, 5)) + tuple(moveaxis(x, 0, -1).shape.eval()) == (4, 5, 3) + tuple(moveaxis(x, -1, 0).shape.eval()) == (5, 3, 4) + tuple(moveaxis(x, [0, 1], [-1, -2]).shape.eval()) == (5, 4, 3) + tuple(moveaxis(x, [0, 1, 2], [-1, -2, -3]).shape.eval()) == (5, 4, 3) + + +def test_moveaxis_error(): + x = at.zeros((3, 4, 5)) + with pytest.raises( + ValueError, + match="`source` and `destination` arguments must have the same number of elements", + ): + moveaxis(x, [0, 1], 0) + + class TestChoose(utt.InferShapeTester): op = staticmethod(choose) op_class = Choose @@ -4300,6 +4193,12 @@ def test_identity_like_dtype(): m_out_float = identity_like(m, dtype=np.float64) assert m_out_float.dtype == "float64" + # Test passing list + m = [[0, 1], [1, 3]] + out = at.identity_like(m) + f = aesara.function([], out) + assert np.array_equal(f(), np.eye(2)) + def test_atleast_Nd(): ary1 = dscalar() diff --git a/tests/tensor/test_basic_opt.py b/tests/tensor/test_basic_opt.py deleted file mode 100644 index d8e07d0a06..0000000000 --- a/tests/tensor/test_basic_opt.py +++ /dev/null @@ -1,3631 +0,0 @@ -import contextlib -import copy - -import numpy as np -import pytest - -import aesara -import aesara.scalar as aes -import aesara.tensor as at -from aesara import shared -from aesara.compile import optdb -from aesara.compile.function import function -from aesara.compile.mode import OPT_NONE, Mode, get_default_mode, get_mode -from aesara.compile.ops import DeepCopyOp, deep_copy_op -from aesara.configdefaults import config -from aesara.graph.basic import Apply, Constant, Variable -from aesara.graph.fg import FunctionGraph -from aesara.graph.op import Op -from aesara.graph.opt import check_stack_trace, local_optimizer, out2in -from aesara.graph.opt_utils import optimize_graph -from aesara.graph.optdb import OptimizationQuery -from aesara.graph.type import Type -from aesara.misc.safe_asarray import _asarray -from aesara.printing import pprint -from aesara.raise_op import Assert, CheckAndRaise -from aesara.scalar.basic import Composite -from aesara.tensor.basic import ( - Alloc, - Join, - MakeVector, - ScalarFromTensor, - Split, - TensorFromScalar, - alloc, - as_tensor_variable, - join, - second, - tile, -) -from aesara.tensor.basic_opt import ( - ShapeFeature, - assert_op, - local_alloc_sink_dimshuffle, - local_dimshuffle_lift, - local_merge_alloc, - local_reshape_to_dimshuffle, - local_useless_alloc, - local_useless_dimshuffle_in_reshape, - local_useless_elemwise, - local_useless_reshape, - register_specialize, -) -from aesara.tensor.elemwise import DimShuffle, Elemwise -from aesara.tensor.extra_ops import BroadcastTo, Repeat, Unique, repeat, unique -from aesara.tensor.math import ( - add, - bitwise_and, - bitwise_or, - bitwise_xor, - cos, - cosh, - dot, - eq, - exp, - floor_div, - ge, - gt, - int_div, - invert, - iround, - le, - log, - log2, - log10, - lt, - maximum, - minimum, - mul, - neg, - neq, -) -from aesara.tensor.math import pow as at_pow -from aesara.tensor.math import reciprocal -from aesara.tensor.math import round as at_round -from aesara.tensor.math import sin, sinh, softplus, sqr, sqrt, sub -from aesara.tensor.math import sum as at_sum -from aesara.tensor.math import tan, tanh, true_div, xor -from aesara.tensor.math_opt import local_lift_transpose_through_dot -from aesara.tensor.shape import ( - Reshape, - Shape_i, - SpecifyShape, - Unbroadcast, - reshape, - shape, - specify_shape, - unbroadcast, -) -from aesara.tensor.subtensor import ( - AdvancedIncSubtensor1, - Subtensor, - advanced_inc_subtensor, - advanced_inc_subtensor1, - inc_subtensor, - set_subtensor, -) -from aesara.tensor.type import ( - TensorType, - dmatrices, - dmatrix, - dscalar, - dvector, - fmatrix, - fscalar, - fvector, - imatrices, - iscalar, - iscalars, - ivector, - lscalar, - lvector, - matrices, - matrix, - scalar, - scalars, - tensor, - tensor3, - tensor4, - values_eq_approx_remove_nan, - vector, - vectors, -) -from tests import unittest_tools as utt - - -mode_opt = config.mode -if mode_opt == "FAST_COMPILE": - mode_opt = "FAST_RUN" -mode_opt = get_mode(mode_opt) - -dimshuffle_lift = out2in(local_dimshuffle_lift) - -_optimizer_stabilize = OptimizationQuery(include=["fast_run"]) -_optimizer_stabilize.position_cutoff = 1.51 -_optimizer_stabilize = optdb.query(_optimizer_stabilize) - -_optimizer_specialize = OptimizationQuery(include=["fast_run"]) -_optimizer_specialize.position_cutoff = 2.01 -_optimizer_specialize = optdb.query(_optimizer_specialize) - -_optimizer_fast_run = OptimizationQuery(include=["fast_run"]) -_optimizer_fast_run = optdb.query(_optimizer_fast_run) - - -def ds(x, y): - return DimShuffle(x.type.broadcastable, y)(x) - - -def optimize(g, level="fast_run"): - if level == "fast_run": - _optimizer_fast_run.optimize(g) - elif level == "specialize": - _optimizer_specialize.optimize(g) - elif level == "stabilize": - _optimizer_stabilize.optimize(g) - else: - raise ValueError(level) - return g - - -def inputs(xbc=(0, 0), ybc=(0, 0), zbc=(0, 0)): - x = TensorType(shape=xbc, dtype="float64")("x") - y = TensorType(shape=ybc, dtype="float64")("y") - z = TensorType(shape=zbc, dtype="float64")("z") - return x, y, z - - -class TestDimshuffleLift: - def test_double_transpose(self): - x, y, z = inputs() - e = ds(ds(x, (1, 0)), (1, 0)) - g = FunctionGraph([x], [e]) - assert ( - str(g) == "FunctionGraph(InplaceDimShuffle{1,0}(InplaceDimShuffle{1,0}(x)))" - ) - dimshuffle_lift.optimize(g) - assert str(g) == "FunctionGraph(x)" - # no need to check_stack_trace as graph is supposed to be empty - - def test_merge2(self): - x, y, z = inputs() - e = ds(ds(x, (1, "x", 0)), (2, 0, "x", 1)) - g = FunctionGraph([x], [e]) - assert ( - str(g) - == "FunctionGraph(InplaceDimShuffle{2,0,x,1}(InplaceDimShuffle{1,x,0}(x)))" - ), str(g) - dimshuffle_lift.optimize(g) - assert str(g) == "FunctionGraph(InplaceDimShuffle{0,1,x,x}(x))", str(g) - # Check stacktrace was copied over correctly after opt was applied - assert check_stack_trace(g, ops_to_check="all") - - def test_elim3(self): - x, y, z = inputs() - e = ds(ds(ds(x, (0, "x", 1)), (2, 0, "x", 1)), (1, 0)) - g = FunctionGraph([x], [e]) - assert str(g) == ( - "FunctionGraph(InplaceDimShuffle{1,0}(InplaceDimShuffle{2,0,x,1}" - "(InplaceDimShuffle{0,x,1}(x))))" - ), str(g) - dimshuffle_lift.optimize(g) - assert str(g) == "FunctionGraph(x)", str(g) - # no need to check_stack_trace as graph is supposed to be empty - - def test_lift(self): - x, y, z = inputs([False] * 1, [False] * 2, [False] * 3) - e = x + y + z - g = FunctionGraph([x, y, z], [e]) - - # It does not really matter if the DimShuffles are inplace - # or not. - init_str_g_inplace = ( - "FunctionGraph(Elemwise{add,no_inplace}(InplaceDimShuffle{x,0,1}" - "(Elemwise{add,no_inplace}(InplaceDimShuffle{x,0}(x), y)), z))" - ) - init_str_g_noinplace = ( - "FunctionGraph(Elemwise{add,no_inplace}(DimShuffle{x,0,1}" - "(Elemwise{add,no_inplace}(DimShuffle{x,0}(x), y)), z))" - ) - assert str(g) in (init_str_g_inplace, init_str_g_noinplace), str(g) - - opt_str_g_inplace = ( - "FunctionGraph(Elemwise{add,no_inplace}(Elemwise{add,no_inplace}" - "(InplaceDimShuffle{x,x,0}(x), InplaceDimShuffle{x,0,1}(y)), z))" - ) - opt_str_g_noinplace = ( - "FunctionGraph(Elemwise{add,no_inplace}(Elemwise{add,no_inplace}" - "(DimShuffle{x,x,0}(x), DimShuffle{x,0,1}(y)), z))" - ) - dimshuffle_lift.optimize(g) - assert str(g) in (opt_str_g_inplace, opt_str_g_noinplace), str(g) - # Check stacktrace was copied over correctly after opt was applied - assert check_stack_trace(g, ops_to_check="all") - - def test_recursive_lift(self): - v = vector(dtype="float64") - m = matrix(dtype="float64") - out = ((v + 42) * (m + 84)).T - g = FunctionGraph([v, m], [out]) - init_str_g = ( - "FunctionGraph(InplaceDimShuffle{1,0}(Elemwise{mul,no_inplace}" - "(InplaceDimShuffle{x,0}(Elemwise{add,no_inplace}" - "(, " - "InplaceDimShuffle{x}(TensorConstant{42}))), " - "Elemwise{add,no_inplace}" - "(, " - "InplaceDimShuffle{x,x}(TensorConstant{84})))))" - ) - assert str(g) == init_str_g - new_out = local_dimshuffle_lift.transform(g, g.outputs[0].owner)[0] - new_g = FunctionGraph(g.inputs, [new_out]) - opt_str_g = ( - "FunctionGraph(Elemwise{mul,no_inplace}(Elemwise{add,no_inplace}" - "(InplaceDimShuffle{0,x}(), " - "InplaceDimShuffle{x,x}(TensorConstant{42})), " - "Elemwise{add,no_inplace}(InplaceDimShuffle{1,0}" - "(), " - "InplaceDimShuffle{x,x}(TensorConstant{84}))))" - ) - assert str(new_g) == opt_str_g - # Check stacktrace was copied over correctly after opt was applied - assert check_stack_trace(new_g, ops_to_check="all") - - def test_useless_dimshuffle(self): - x, _, _ = inputs() - e = ds(x, (0, 1)) - g = FunctionGraph([x], [e]) - assert str(g) == "FunctionGraph(InplaceDimShuffle{0,1}(x))" - dimshuffle_lift.optimize(g) - assert str(g) == "FunctionGraph(x)" - # Check stacktrace was copied over correctly after opt was applied - assert hasattr(g.outputs[0].tag, "trace") - - def test_dimshuffle_on_broadcastable(self): - x, y, z = inputs([False, True], [True, False, True], [False, False, True]) - u = at.constant(1) - ds_x = ds(x, (0, "x")) # useless - ds_y = ds(y, (2, 1, 0)) # useless - ds_z = ds(z, (2, 1, 0)) # useful - ds_u = ds(u, ("x")) # useful - g = FunctionGraph([x, y, z, u], [ds_x, ds_y, ds_z, ds_u]) - assert ( - str(g) - == "FunctionGraph(InplaceDimShuffle{0,x}(x), InplaceDimShuffle{2,1,0}(y), InplaceDimShuffle{2,1,0}(z), InplaceDimShuffle{x}(TensorConstant{1}))" - ) - dimshuffle_lift.optimize(g) - assert ( - str(g) - == "FunctionGraph(x, y, InplaceDimShuffle{2,1,0}(z), InplaceDimShuffle{x}(TensorConstant{1}))" - ) - # Check stacktrace was copied over correctly after opt was applied - assert hasattr(g.outputs[0].tag, "trace") - - -def test_local_useless_dimshuffle_in_reshape(): - vec = TensorType(shape=(False,), dtype="float64")("vector") - mat = TensorType(shape=(False, False), dtype="float64")("mat") - row = TensorType(shape=(True, False), dtype="float64")("row") - col = TensorType(shape=(False, True), dtype="float64")("col") - - reshape_dimshuffle_vector = reshape(vec.dimshuffle("x", 0), vec.shape) - reshape_dimshuffle_mat = reshape(mat.dimshuffle("x", 0, "x", 1), mat.shape) - reshape_dimshuffle_row = reshape(row.dimshuffle(1, "x"), row.shape) - reshape_dimshuffle_col = reshape(col.dimshuffle(0), col.shape) - - g = FunctionGraph( - [vec, mat, row, col], - [ - reshape_dimshuffle_vector, - reshape_dimshuffle_mat, - reshape_dimshuffle_row, - reshape_dimshuffle_col, - ], - ) - - assert str(g) == ( - "FunctionGraph(Reshape{1}(InplaceDimShuffle{x,0}(vector), Shape(vector)), " - "Reshape{2}(InplaceDimShuffle{x,0,x,1}(mat), Shape(mat)), " - "Reshape{2}(InplaceDimShuffle{1,x}(row), Shape(row)), " - "Reshape{2}(InplaceDimShuffle{0}(col), Shape(col)))" - ) - useless_dimshuffle_in_reshape = out2in(local_useless_dimshuffle_in_reshape) - useless_dimshuffle_in_reshape.optimize(g) - assert str(g) == ( - "FunctionGraph(Reshape{1}(vector, Shape(vector)), " - "Reshape{2}(mat, Shape(mat)), " - "Reshape{2}(row, Shape(row)), " - "Reshape{2}(col, Shape(col)))" - ) - - # Check stacktrace was copied over correctly after opt was applied - assert check_stack_trace(g, ops_to_check="all") - - # Check that the optimization does not get applied when the order - # of dimensions has changed. - reshape_dimshuffle_mat2 = reshape(mat.dimshuffle("x", 1, "x", 0), mat.shape) - h = FunctionGraph([mat], [reshape_dimshuffle_mat2]) - str_h = str(h) - useless_dimshuffle_in_reshape.optimize(h) - assert str(h) == str_h - - -class TestFusion: - opts = OptimizationQuery( - include=[ - "local_elemwise_fusion", - "composite_elemwise_fusion", - "canonicalize", - "inplace", - ], - exclude=["cxx_only", "BlasOpt"], - ) - mode = Mode(get_default_mode().linker, opts) - _shared = staticmethod(shared) - topo_exclude = () - - def my_init(dtype="float64", num=0): - return np.zeros((5, 5), dtype=dtype) + num - - fw, fx, fy, fz = [ - tensor(dtype="float32", shape=[False] * 2, name=n) for n in "wxyz" - ] - dw, dx, dy, dz = [ - tensor(dtype="float64", shape=[False] * 2, name=n) for n in "wxyz" - ] - ix, iy, iz = [tensor(dtype="int32", shape=[False] * 2, name=n) for n in "xyz"] - fv = fvector("v") - fs = fscalar("s") - fwv = my_init("float32", 1) - fxv = my_init("float32", 2) - fyv = my_init("float32", 3) - fzv = my_init("float32", 4) - fvv = _asarray(np.random.random(5), dtype="float32") - fsv = np.asarray(np.random.random(), dtype="float32") - dwv = my_init("float64", 5) - ixv = _asarray(my_init(num=60), dtype="int32") - iyv = _asarray(my_init(num=70), dtype="int32") - izv = _asarray(my_init(num=70), dtype="int32") - fwx = fw + fx - ftanx = tan(fx) - - @pytest.mark.parametrize( - "case", - [ - ( - fx + fy + fz, - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv + fyv + fzv, - "float32", - ), # 0 - ( - fx * fy * fz, - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv * fyv * fzv, - "float32", - ), # 1 - ( - fx + fy * fz, - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv + fyv * fzv, - "float32", - ), # 2 - ( - fx * fy + fz, - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv * fyv + fzv, - "float32", - ), # 3 - ( - fw + fx + fy + fz, - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fwv + fxv + fyv + fzv, - "float32", - ), - ( - (fw + fx) + (fy + fz), - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fwv + fxv + fyv + fzv, - "float32", - ), # 5 - ( - ((fw + fx) + fy) + fz, - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fwv + fxv + fyv + fzv, - "float32", - ), - ( - (fw + (fx + fy)) + fz, - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fwv + fxv + fyv + fzv, - "float32", - ), - ( - (fw + (fx + fy) + fz), - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fwv + fxv + fyv + fzv, - "float32", - ), - ( - fw + (fx + (fy + fz)), - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fwv + fxv + fyv + fzv, - "float32", - ), - ( - (fw + fx) + (fy + fz), - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fwv + fxv + fyv + fzv, - "float32", - ), # 10 - ( - fw * fx * fy * fz, - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fwv * fxv * fyv * fzv, - "float32", - ), - ( - fw + fx * fy * fz, - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fwv + fxv * fyv * fzv, - "float32", - ), - ( - fx + fy * fz * fx, - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv + fyv * fzv * fxv, - "float32", - ), - ( - fx * fy + fz + fy, - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv * fyv + fzv + fyv, - "float32", - ), - ( - fx * fy * fz * fw + fx + fy + fz + fw, - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fxv * fyv * fzv * fwv + fxv + fyv + fzv + fwv, - "float32", - ), # 15 - # test with constant - ( - (fw + fx) + (fy + fz) + 2.0, - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fwv + fxv + fyv + fzv + 2, - "float32", - ), - ( - ((fw + fx) + 2.0 + fy) + fz, - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fwv + fxv + fyv + fzv + 2, - "float32", - ), - ( - (fw + (fx + 2.0 + fy)) + fz, - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fwv + fxv + fyv + fzv + 2, - "float32", - ), - ( - (fw + (fx + fy) + 2 + fz), - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fwv + fxv + fyv + fzv + 2, - "float32", - ), - ( - fw + (fx + (fy + fz) + 2.0), - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fwv + fxv + fyv + fzv + 2, - "float32", - ), # 20 - ( - 2 + (fw + fx) + (fy + fz), - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 1, - fwv + fxv + fyv + fzv + 2, - "float32", - ), - # mix float32 and float64 - ( - 2 + (dw + fx) + (fy + fz), - (dw, fx, fy, fz), - (dwv, fxv, fyv, fzv), - 1, - dwv + fxv + fyv + fzv + 2, - "float64", - ), - ( - 2 + (fw + dw) + (fy + fz), - (fw, dw, fy, fz), - (fwv, dwv, fyv, fzv), - 1, - fwv + dwv + fyv + fzv + 2, - "float64", - ), - ( - 2 + (fw + fx) + (dw + fz), - (fw, fx, dw, fz), - (fwv, fxv, dwv, fzv), - 1, - fwv + fxv + dwv + fzv + 2, - "float64", - ), - ( - 2 + (fw + fx) + (fy + dw), - (fw, fx, fy, dw), - (fwv, fxv, fyv, dwv), - 1, - fwv + fxv + fyv + dwv + 2, - "float64", - ), # 25 - # test when their is other op then elemwise. - ( - (fwx.sum()) + (fwx) + (fy + fz), - (fw, fx, fy, fz), - (fwv, fxv, fyv, fzv), - 4, - (fwv + fxv).sum() + fwv + fxv + fyv + fzv, - "float32", - ), - # test other elemwise op - ( - fx + fy + cos(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv + fyv + np.cos(fzv), - "float32", - ), - ( - fx + fy + cosh(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv + fyv + np.cosh(fzv), - "float32", - ), - ( - fx + fy + abs(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv + fyv + np.absolute(fzv), - "float32", - ), - ( - ix + iy + abs(iz), - (ix, iy, iz), - (ixv, iyv, izv), - 1, - ixv + iyv + np.absolute(izv), - "int32", - ), # 30 - ( - fx + fy + log(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv + fyv + np.log(fzv), - "float32", - ), - ( - fx + fy + log2(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv + fyv + np.log2(fzv), - "float32", - ), - ( - fx + fy + log10(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv + fyv + np.log10(fzv), - "float32", - ), - ( - fx + fy**fz, - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv + fyv**fzv, - "float32", - ), # pow - ( - fx + fy + exp(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv + fyv + np.exp(fzv), - "float32", - ), # 35 - ( - fx - fy - fz, - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - fyv - fzv, - "float32", - ), - ( - fx - (fy / fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - (fyv / fzv), - "float32", - ), - ( - fx - true_div(fy, 2), - (fx, fy), - (fxv, fyv), - 1, - fxv - (fyv / 2), - "float32", - ), - ( - fx - true_div(fy, fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - (fyv / fzv), - "float32", - ), - ( - fx - int_div(ix * 100, iy * 1000), - (fx, ix, iy), - (fxv, ixv, iyv), - 1, - fxv - ((ixv * 100) // (iyv * 1000)), - { - "custom": "float64", - "numpy + floatX": config.floatX, - "numpy": "float64", - }, - ), # 40 - (fx - (fy / 2), (fx, fy), (fxv, fyv), 1, fxv - (fyv / 2), "float32"), - ( - fx - (fy % fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - (fyv % fzv), - "float32", - ), - ( - fx - (fy > fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - (fyv > fzv), - "float32", - ), - ( - fx - (fy >= fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - (fyv >= fzv), - "float32", - ), - ( - fx - (fy < fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - (fyv < fzv), - "float32", - ), # 45 - ( - fx - (fy <= fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - (fyv <= fzv), - "float32", - ), - ( - fx - eq(fy, fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - (fyv == fzv), - "float32", - ), - ( - fx - neq(fy, fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - (fyv != fzv), - "float32", - ), - ( - fx - fy + tan(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - fyv + np.tan(fzv), - "float32", - ), - ( - fx - fy + tanh(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - fyv + np.tanh(fzv), - "float32", - ), # 50 - ( - fx - fy + sin(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - fyv + np.sin(fzv), - "float32", - ), - ( - fx - fy + sinh(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - fyv + np.sinh(fzv), - "float32", - ), - ( - fx - fy + sqr(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - fyv + (fzv * fzv), - "float32", - ), - ( - fx - fy + sqrt(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - fyv + np.sqrt(fzv), - "float32", - ), - ( - fx - fy + reciprocal(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - fyv + (1 / fzv), - "float32", - ), # 55 - ( - fx - fy + neg(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - fyv + (-fzv), - "float32", - ), - ( - fx - fy + at_round(fz), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - fxv - fyv + np.round(fzv), - "float32", - ), - ( - ix - iy + iround(fz), - (ix, iy, fz), - (ixv, iyv, fzv), - 1, - ixv - iyv + np.round(fzv), - "int64", - ), - # Bit op - ( - fx - bitwise_or(iy, iz), - (fx, iy, iz), - (fxv, iyv, izv), - 1, - fxv - (iyv | izv), - { - "custom": "float64", - "numpy + floatX": config.floatX, - "numpy": "float64", - }, - ), - ( - fx - xor(iy, iz), - (fx, iy, iz), - (fxv, iyv, izv), - 1, - fxv - (iyv ^ izv), - { - "custom": "float64", - "numpy + floatX": config.floatX, - "numpy": "float64", - }, - ), # 60 - ( - fx - bitwise_and(iy, iz), - (fx, iy, iz), - (fxv, iyv, izv), - 1, - fxv - (iyv & izv), - { - "custom": "float64", - "numpy + floatX": config.floatX, - "numpy": "float64", - }, - ), - ( - fx - invert(iy), - (fx, iy), - (fxv, iyv), - 1, - fxv - (~iyv), - { - "custom": "float64", - "numpy + floatX": config.floatX, - "numpy": "float64", - }, - ), - ( - fx - at.cast(fy, dtype="float64"), - (fx, fy), - (fxv, fyv), - 1, - fxv - np.asarray(fyv, "float64"), - "float64", - ), - ( - at_pow(fx * fy + fz, fx * fy), - (fx, fy, fz), - (fxv, fyv, fzv), - 1, - np.power(fxv * fyv + fzv, fxv * fyv), - "float32", - ), - ( - fv + fy**fz, - (fv, fy, fz), - (fvv, fyv, fzv), - 2, - fvv + fyv**fzv, - "float32", - ), # fused with a dimshuffle #65 - ( - fv - fy + tanh(fz), - (fv, fy, fz), - (fvv, fyv, fzv), - 2, - fvv - fyv + np.tanh(fzv), - "float32", - ), # fused with a dimshuffle - # Cases where the same input is reused many times. - ( - mul(fx, fx, fx, fx), - (fx,), - (fxv,), - 1, - fxv * fxv * fxv * fxv, - "float32", - ), - ( - mul(fx, ftanx, ftanx), - (fx,), - (fxv,), - 1, - fxv * np.tan(fxv) * np.tan(fxv), - "float32", - ), - ( - mul(fx, ftanx, ftanx, fx), - (fx,), - (fxv,), - 1, - fxv * np.tan(fxv) * np.tan(fxv) * fxv, - "float32", - ), - ( - mul(ftanx, ftanx, fx + fy), - (fx, fy), - (fxv, fyv), - 1, - np.tan(fxv) * np.tan(fxv) * (fxv + fyv), - "float32", - ), # 70 - # Cases with different broadcast pattern. They should not - # be merged as this would duplicate computation - # The graph should have 2 elemwise and 1 dimshuffle - ( - fx * sin(fs), - (fx, fs), - (fxv, fsv), - 3, - fxv * np.sin(fsv), - "float32", - ), - ], - ) - def test_elemwise_fusion(self, case, nb_repeat=1, assert_len_topo=True): - """Verify that `Elemwise` fusion works.""" - - g, sym_inputs, val_inputs, nb_elemwise, answer, out_dtype = case - - if isinstance(out_dtype, dict): - out_dtype = out_dtype[config.cast_policy] - - if self._shared is None: - f = function(list(sym_inputs), g, mode=self.mode) - for x in range(nb_repeat): - out = f(*val_inputs) - else: - out = self._shared(np.zeros((5, 5), dtype=out_dtype), "out") - assert out.dtype == g.dtype - f = function(sym_inputs, [], updates=[(out, g)], mode=self.mode) - for x in range(nb_repeat): - f(*val_inputs) - out = out.get_value() - - atol = 1e-8 - if out_dtype == "float32": - atol = 1e-6 - - assert np.allclose(out, answer * nb_repeat, atol=atol) - - topo = f.maker.fgraph.toposort() - topo_ = [n for n in topo if not isinstance(n.op, self.topo_exclude)] - if assert_len_topo: - - assert len(topo_) == nb_elemwise - - if nb_elemwise == 1: - # if no variable appears multiple times in the - # input of g, - # check that the number of input to the Composite - # Elemwise is ok - if len(set(g.owner.inputs)) == len(g.owner.inputs): - expected_len_sym_inputs = sum( - not isinstance(x, Constant) for x in topo_[0].inputs - ) - assert expected_len_sym_inputs == len(sym_inputs) - - assert out_dtype == out.dtype - - def test_fusion_35_inputs(self): - r"""Make sure we don't fuse too many `Op`\s and go past the 31 function arguments limit.""" - inpts = vectors(["i%i" % i for i in range(35)]) - - # Make an elemwise graph looking like: - # sin(i34 + sin(i33 + sin(... i1 + sin(i0) ...))) - out = sin(inpts[0]) - for idx in range(1, 35): - out = sin(inpts[idx] + out) - - with config.change_flags(cxx=""): - f = function(inpts, out, mode=self.mode) - - # Make sure they all weren't fused - composite_nodes = [ - node - for node in f.maker.fgraph.toposort() - if isinstance(getattr(node.op, "scalar_op", None), aes.basic.Composite) - ] - assert not any(len(node.inputs) > 31 for node in composite_nodes) - - @pytest.mark.skipif(not config.cxx, reason="No cxx compiler") - def test_big_fusion(self): - # In the past, pickle of Composite generated in that case - # crashed with max recursion limit. So we were not able to - # generate C code in that case. - factors = [] - sd = dscalar() - means = dvector() - - cst_05 = at.constant(0.5) - cst_m05 = at.constant(-0.5) - cst_2 = at.constant(2) - cst_m2 = at.constant(-2) - ones = at.constant(np.ones(10)) - n = 85 - if config.mode in ["DebugMode", "DEBUG_MODE"]: - n = 10 - - for i in range(n): - f = cst_m05 * sd**cst_m2 * (ones - means[i]) ** cst_2 + cst_05 * log( - cst_05 * (sd**cst_m2) / np.pi - ) - factors.append(at_sum(f)) - - logp = add(*factors) - - vars = [sd, means] - - # Make sure that C compilation is used - mode = Mode("cvm", self.opts) - dlogp = function(vars, [aesara.grad(logp, v) for v in vars], mode=mode) - - # Make sure something was fused - assert any( - isinstance(getattr(node.op, "scalar_op", None), aes.basic.Composite) - for node in dlogp.maker.fgraph.toposort() - ) - - def test_add_mul_fusion_inplace(self): - - opts = OptimizationQuery( - include=[ - "local_elemwise_fusion", - "composite_elemwise_fusion", - "canonicalize", - "inplace", - ], - exclude=["cxx_only", "BlasOpt"], - ) - - mode = Mode(self.mode.linker, opts) - - x, y, z = dmatrices("xyz") - out = dot(x, y) + x + y + z - f = function([x, y, z], out, mode=mode) - topo = [n for n in f.maker.fgraph.toposort()] - assert len(topo) == 2 - assert topo[-1].op.inplace_pattern - - new_out = f.maker.fgraph.outputs[0] - assert isinstance(new_out.owner.op, Elemwise) - assert isinstance(new_out.owner.op.scalar_op, aes.basic.Add) - assert len(new_out.owner.inputs) == 4 - - # TODO: Do we really need to do this? - _ = f( - np.random.random((5, 5)), np.random.random((5, 5)), np.random.random((5, 5)) - ) - - @pytest.mark.skipif(not config.cxx, reason="No cxx compiler") - def test_no_c_code(self): - r"""Make sure we avoid fusions for `Op`\s without C code implementations.""" - - # This custom `Op` has no `c_code` method - class NoCCodeOp(aes.basic.UnaryScalarOp): - def impl(self, x): - return x * 2 - - no_c_code_op = Elemwise(NoCCodeOp(aes.basic.upgrade_to_float)) - - mode = Mode(linker="cvm") - mode._optimizer = mode._optimizer.including( - "local_elemwise_fusion", - "composite_elemwise_fusion", - "canonicalize", - "inplace", - ) - - x = vector() - out = x * no_c_code_op(x + 1) - f = function([x], out, mode=mode) - - assert not any( - isinstance(getattr(n.op, "scalar_op"), aes.basic.Composite) - for n in f.maker.fgraph.toposort() - ) - - @pytest.mark.parametrize("test_value", [np.c_[[1.0]], np.c_[[]]]) - def test_test_values(self, test_value): - """Make sure that `local_elemwise_fusion_op` uses test values correctly when they have zero dimensions. - - The test values we're talking about are the ones used when C implementations - are checked. - - """ - - opts = OptimizationQuery( - include=[ - "local_elemwise_fusion", - "composite_elemwise_fusion", - "canonicalize", - ], - exclude=["cxx_only", "BlasOpt"], - ) - - mode = Mode(self.mode.linker, opts) - - x, y, z = dmatrices("xyz") - - x.tag.test_value = test_value - y.tag.test_value = test_value - z.tag.test_value = test_value - - if test_value.size == 0: - cm = pytest.raises(ValueError) - else: - cm = contextlib.suppress() - - with config.change_flags( - compute_test_value="raise", compute_test_value_opt="raise" - ): - out = x * y + z - with cm: - f = function([x, y, z], out, mode=mode) - - if test_value.size != 0: - # Confirm that the fusion happened - assert isinstance(f.maker.fgraph.outputs[0].owner.op.scalar_op, Composite) - assert len(f.maker.fgraph.toposort()) == 1 - - x_c, y_c, z_c = f.maker.fgraph.outputs[0].owner.inputs - assert np.array_equal( - f.maker.fgraph.outputs[0].tag.test_value, np.c_[[2.0]] - ) - - -class TimesN(aes.basic.UnaryScalarOp): - """ - Used in test TestCompositeCodegen - - Must be outside of the class, otherwise, the c cache code can't - pickle this class and this cause stuff printing during test. - """ - - def __eq__(self, other): - return super().__eq__(other) and self.n == other.n - - def __hash__(self): - return super().__hash__() ^ hash(self.n) - - def __init__(self, n, *args, **kwargs): - self.n = n - aes.basic.UnaryScalarOp.__init__(self, *args, **kwargs) - - def impl(self, x): - return x * self.n - - def c_support_code_apply(self, node, nodename): - n = str(self.n) - return ( - """ - float %(nodename)s_timesn(float x) { return x * %(n)s; } - """ - % locals() - ) - - def c_code(self, node, name, inputs, outputs, sub): - (x,) = inputs - (z,) = outputs - return f"{z} = {name}_timesn({x});" - - -class TestCompositeCodegen: - """ - Test The Composite Ops code generation in a case where there is multiple - scalar ops with support code. - """ - - def setup_method(self): - upgrade_to_float = aes.basic.upgrade_to_float - - self.scal_times_2 = TimesN(2, upgrade_to_float, name="times_2") - self.times_2 = Elemwise(self.scal_times_2, name="times_2") - - self.scal_times_3 = TimesN(3, upgrade_to_float, name="times_3") - self.times_3 = Elemwise(self.scal_times_3, name="times_3") - - self.x = fvector() - - def test_nested_composite(self): - y = self.times_2(self.x) - z = self.times_3(y) - f = function([self.x], z) - if config.mode != "FAST_COMPILE": - assert len(f.maker.fgraph.toposort()) == 1 - fval = f([1, 2, 3]) - assert np.all(fval == [6, 12, 18]) - - def test_local_useless_composite(self): - x = aes.float32() - c = aes.Composite([x], [x + 1, x - 1]) - X = matrix() - o = Elemwise(scalar_op=c)(X) - mode = get_default_mode().including("local_useless_composite") - - f = function([X], o[0], mode=mode) - topo = f.maker.fgraph.toposort() - assert len(topo) == 1 - assert len(topo[0].outputs) == 1 - utt.assert_allclose(f([[1.0]]), [[2.0]]) - - f = function([X], o[1], mode=mode) - topo = f.maker.fgraph.toposort() - assert len(topo) == 1 - assert len(topo[0].outputs) == 1 - utt.assert_allclose(f([[1.0]]), [[0.0]]) - - -def test_local_useless_slice(): - # test a simple matrix - x = matrix("x") - mode_unopt = get_default_mode().excluding( - "local_useless_slice", "local_mul_canonizer" - ) - mode_opt = ( - get_default_mode() - .including("local_useless_slice") - .excluding("local_mul_canonizer") - ) - - # test with and without the useless slice - o = 2 * x[0, :] - f_unopt = function([x], o, mode=mode_unopt) - f_opt = function([x], o, mode=mode_opt) - rng = np.random.default_rng(utt.fetch_seed()) - test_inp = rng.integers(-10, 10, (4, 4)).astype("float32") - assert all( - f_opt(test_inp) == f_unopt(test_inp) - ), "The optimization caused a mismatch in the result" - # test to see if the slice is truly gone - apply_node = f_opt.maker.fgraph.toposort()[0] - subtens = apply_node.op - assert not any( - isinstance(idx, slice) for idx in subtens.idx_list - ), "Slice should be gone" - - # Now test that the stack trace is copied over properly, - # before before and after optimization. - assert check_stack_trace(f_unopt, ops_to_check="all") - assert check_stack_trace(f_opt, ops_to_check="all") - - # test a 4d tensor - z = tensor4("z") - o2 = z[1, :, :, 1] - o3 = z[0, :, :, :] - f_opt_check = function([z], o2, mode=mode_opt) - f_opt_check_apply = function([z], o3, mode=mode_opt) - - # The optimization shouldn't apply here - apply_node = f_opt_check.maker.fgraph.toposort()[0] - subtens = apply_node.op - assert [isinstance(idx, slice) for idx in subtens.idx_list].count(True) == 2 - # But it should here - apply_node = f_opt_check_apply.maker.fgraph.toposort()[0] - subtens = apply_node.op - assert not any(isinstance(idx, slice) for idx in subtens.idx_list) - - # Finally, test that the stack trace is copied over properly, - # before before and after optimization. - assert check_stack_trace(f_opt_check, ops_to_check=Subtensor) - assert check_stack_trace(f_opt_check_apply, ops_to_check=Subtensor) - - -def test_local_useless_fill(): - x = dvector() - y = dvector() - z = lvector() - - x_ = np.random.random((5,)) - y_ = np.random.random((5,)) - z_ = (np.random.random((5,)) * 5).astype("int64") - - # basic case - f = function([x], at.fill(x, x) * 2, mode=mode_opt) - assert [node.op for node in f.maker.fgraph.toposort()] == [mul] - res = f(x_) - exp_res = np.broadcast_to(x_, x_.shape) * 2 - assert np.array_equal(res, exp_res) - - # basic case - f = function([x, y], at.second(y, x) * 2, mode=mode_opt) - assert [node.op for node in f.maker.fgraph.toposort()] == [mul] - res = f(x_, y_) - exp_res = np.broadcast_to(x_, y_.shape) * 2 - assert np.array_equal(res, exp_res) - - # basic case - f = function([x, y], at.fill(x, y) * 2, mode=mode_opt) - assert [node.op for node in f.maker.fgraph.toposort()] == [mul] - res = f(x_, y_) - exp_res = np.broadcast_to(y_, x_.shape) * 2 - assert np.array_equal(res, exp_res) - - # now with different type(cast) - f = function([x, z], at.fill(z, x) * 2, mode=mode_opt) - assert [node.op for node in f.maker.fgraph.toposort()] == [mul] - res = f(x_, z_) - exp_res = np.broadcast_to(x_, z_.shape) * 2 - assert np.array_equal(res, exp_res) - - # now with different type(cast) - f = function([x, z], at.fill(x, z) * 2, mode=mode_opt) - assert [node.op for node in f.maker.fgraph.toposort()] == [mul] - res = f(x_, z_) - exp_res = np.broadcast_to(z_, x_.shape) * 2 - assert np.array_equal(res, exp_res) - - # now cutting out the input ?? - f = function([x, y], at.fill(x, y) * 2, mode=mode_opt) - assert [node.op for node in f.maker.fgraph.toposort()] == [mul] - res = f(x_, y_) - exp_res = np.broadcast_to(y_, x_.shape) * 2 - assert np.array_equal(res, exp_res) - - -def test_local_fill_to_alloc(): - x = dvector() - m = dmatrix() - - x_ = np.random.random((5,)) - m_ = np.random.random((5, 5)) - - y = at.fill(m, x) - - mode = mode_opt.including("stabilize", "local_fill_to_alloc").excluding( - "useless", "local_useless_fill" - ) - - f = function([m, x], y, mode=mode) - assert Alloc in [node.op.__class__ for node in f.maker.fgraph.toposort()] - - res = f(m_, x_) - exp_res = np.broadcast_to(x_, m_.shape) - assert np.array_equal(res, exp_res) - - y = at.fill(x, m) - - f = function([m, x], y, mode=mode) - - assert Alloc not in [node.op.__class__ for node in f.maker.fgraph.toposort()] - - res = f(m_, x_) - assert np.array_equal(res, m_) - - -class TestLocalCanonicalizeAlloc: - def setup_method(self): - self.rng = np.random.default_rng(utt.fetch_seed()) - - def test_inconsistent_constant(self): - x = at.as_tensor(self.rng.standard_normal((3, 7))) - a = at.alloc(x, 6, 7) - - assert a.owner and isinstance(a.owner.op, Alloc) - - # `local_useless_alloc` should replace the `Alloc` with an `Assert` - with pytest.raises(AssertionError): - f = function([], a, mode=mode_opt) - - x = at.as_tensor(self.rng.standard_normal((6, 7))) - a = at.alloc(x, 6, 7) - - f = function([], a, mode=mode_opt) - - # The optimization should then be applied, and remove Alloc - assert not any( - isinstance(node.op, (Alloc, Assert)) for node in f.maker.fgraph.toposort() - ) - - def test_inconsistent_shared(self): - # These shapes don't match! - x = shared(self.rng.standard_normal((3, 7))) - a = at.alloc(x, 6, 7) - - assert a.owner and isinstance(a.owner.op, Alloc) - - f = function([], a, mode=mode_opt) - - # The optimization should then be applied, and remove Alloc - assert not any(isinstance(node.op, Alloc) for node in f.maker.fgraph.toposort()) - assert any(isinstance(node.op, Assert) for node in f.maker.fgraph.toposort()) - - with pytest.raises(AssertionError): - f() - - good_x_val = self.rng.standard_normal((6, 7)) - x.set_value(good_x_val) - - assert np.array_equal(f(), good_x_val) - - def test_basic_fill(self): - x = matrix("x") - y = at.fill(x, x) - - # The optimization 'locall_fill_to_alloc' should call at.alloc, - # which should return x and not alloc(x, ...) - f = function([x], [y], mode=mode_opt.including("local_fill_to_alloc")) - assert not any(isinstance(node.op, Alloc) for node in f.maker.fgraph.toposort()) - - def test_basic_tile(self): - x = matrix("x") - y = at.tile(x, (1,) * 2) - - mode = mode_opt.including( - "local_dimshuffle_lift", - "local_useless_dimshuffle_in_reshape", - "local_alloc_sink_dimshuffle", - ) - f = function([x], [y], mode=mode) - - assert not any(isinstance(node.op, Alloc) for node in f.maker.fgraph.toposort()) - - @pytest.mark.parametrize( - "x, has_alloc", - [ - (at.alloc(np.ones((2,)), 1, 3, 2), True), - (at.alloc(np.array(1.0), 1, 1), False), - (at.alloc(np.ones((1, 1)), 1, 1, 2), True), - (at.alloc(np.ones((1, 1)), 1, 2), True), - ], - ) - def test_useless_alloc_with_shape_one(self, x, has_alloc): - g = FunctionGraph(outputs=[x]) - assert any(isinstance(node.op, Alloc) for node in g.toposort()) - - alloc_lift = out2in(local_alloc_sink_dimshuffle) - alloc_lift.optimize(g) - - if has_alloc: - assert any(isinstance(node.op, Alloc) for node in g.toposort()) - else: - assert not any(isinstance(node.op, Alloc) for node in g.toposort()) - - -class TestLocalUselessIncSubtensorAlloc: - opt_name = "local_useless_inc_subtensor_alloc" - - def setup_method(self): - # The optimization requires the shape feature so we need to compile in - # FAST_RUN mode. - mode = config.mode - if mode == "FAST_COMPILE": - mode = "FAST_RUN" - self.mode = get_mode(mode) - self.rng = np.random.default_rng(utt.fetch_seed()) - - def test_advanced_inc_subtensor(self): - x = vector("x") - y = scalar("y") - i = matrix("i", dtype="int64") - z = advanced_inc_subtensor(x, at.alloc(y, *i.shape), i) - mode1 = self.mode.excluding(self.opt_name) - mode2 = self.mode.including(self.opt_name) - f1 = function([x, i, y], z, mode=mode1) - f2 = function([x, i, y], z, mode=mode2) - - # the alloc op should still be there - assert ( - len([n for n in f1.maker.fgraph.toposort() if isinstance(n.op, Alloc)]) == 1 - ) - # the alloc op should have been removed - assert ( - len([n for n in f2.maker.fgraph.toposort() if isinstance(n.op, Alloc)]) == 0 - ) - - x_value = np.random.standard_normal((5)).astype(config.floatX) - y_value = np.random.standard_normal() - i_value = self.rng.integers(0, 3, size=(2, 3)) - - r1 = f1(x_value, i_value, y_value) - r2 = f2(x_value, i_value, y_value) - - utt.assert_allclose(r1, r2) - - # Check stacktrace was copied over correctly after opt was applied - assert check_stack_trace(f1, ops_to_check=AdvancedIncSubtensor1) - assert check_stack_trace(f2, ops_to_check=AdvancedIncSubtensor1) - - def test_advanced_inc_subtensor1(self): - x = vector("x") - y = scalar("y") - i = vector("i", dtype="int64") - z = advanced_inc_subtensor1(x, at.alloc(y, *i.shape), i) - mode1 = self.mode.excluding(self.opt_name) - mode2 = self.mode.including(self.opt_name) - f1 = function([x, i, y], z, mode=mode1) - f2 = function([x, i, y], z, mode=mode2) - - # the alloc op should still be there - assert ( - len([n for n in f1.maker.fgraph.toposort() if isinstance(n.op, Alloc)]) == 1 - ) - # the alloc op should have been removed - assert ( - len([n for n in f2.maker.fgraph.toposort() if isinstance(n.op, Alloc)]) == 0 - ) - - x_value = np.random.standard_normal((5)).astype(config.floatX) - y_value = np.random.standard_normal() - i_value = self.rng.integers(0, 3, size=2) - - r1 = f1(x_value, i_value, y_value) - r2 = f2(x_value, i_value, y_value) - - utt.assert_allclose(r1, r2) - - # Check stacktrace was copied over correctly after opt was applied - assert check_stack_trace(f1, ops_to_check=AdvancedIncSubtensor1) - assert check_stack_trace(f2, ops_to_check="all") - - def test_incsubtensor(self): - x = vector("x") - y = scalar("y") - i = scalar("i", dtype="int64") - z = inc_subtensor(x[:i], at.alloc(y, i)) - mode1 = self.mode.excluding(self.opt_name) - mode2 = self.mode.including(self.opt_name) - f1 = function([x, i, y], z, mode=mode1) - f2 = function([x, i, y], z, mode=mode2) - - # the alloc op should still be there - assert ( - len([n for n in f1.maker.fgraph.toposort() if isinstance(n.op, Alloc)]) == 1 - ) - # the alloc op should have been removed - assert ( - len([n for n in f2.maker.fgraph.toposort() if isinstance(n.op, Alloc)]) == 0 - ) - - x_value = np.random.standard_normal((5)).astype(config.floatX) - y_value = np.random.standard_normal() - i_value = 3 - - r1 = f1(x_value, i_value, y_value) - r2 = f2(x_value, i_value, y_value) - - utt.assert_allclose(r1, r2) - - # Check stacktrace was copied over correctly after opt was applied - assert check_stack_trace(f1, ops_to_check="last") - assert check_stack_trace(f2, ops_to_check="last") - - -class TestShapeOptimizer: - def test_basic(self): - mode = config.mode - if mode == "FAST_COMPILE": - mode = "FAST_RUN" - v = vector() - m = matrix() - f = function([v, m], (v + m).shape, mode=mode) - for node in f.maker.fgraph.toposort(): - assert node.op != add - - def test_constant(self): - mode = config.mode - if mode == "FAST_COMPILE": - mode = "FAST_RUN" - - v = vector() - f = function([v], v.dimshuffle("x", "x", 0).shape[1], mode=mode) - topo = f.maker.fgraph.toposort() - assert len(topo) == 1 - assert topo[0].op == deep_copy_op - - @staticmethod - def max_pool_c01b(c01b, pool_shp, pool_stride, img_shp): - """ - Like max_pool but with input using axes ('c', 0, 1, 'b') - (Alex Krizhevsky format) - - pool_shp, pool_stride and img_shp are int that represent - the same shp in x and y. - """ - mx = None - - # Compute index in pooled space of last needed pool - # (needed = each input pixel must appear in at least one pool) - def last_pool(im_shp, p_shp, p_strd): - rval = int(np.ceil(float(im_shp - p_shp) / p_strd)) - assert p_strd * rval + p_shp >= im_shp - assert p_strd * (rval - 1) + p_shp < im_shp - return rval - - # Compute starting row of the last pool - last_pool_r = last_pool(img_shp, pool_shp, pool_stride) * pool_stride - # Compute number of rows needed in img for all indexes to work out - required_r = last_pool_r + pool_shp - - last_pool_c = last_pool(img_shp, pool_shp, pool_stride) * pool_stride - required_c = last_pool_c + pool_shp - - wide_infinity = at.alloc( - -np.inf, c01b.shape[0], required_r, required_c, c01b.shape[3] - ) - - c01b = set_subtensor(wide_infinity[:, 0:img_shp, 0:img_shp, :], c01b) - - for row_within_pool in range(pool_shp): - row_stop = last_pool_r + row_within_pool + 1 - for col_within_pool in range(pool_shp): - col_stop = last_pool_c + col_within_pool + 1 - cur = c01b[ - :, - row_within_pool:row_stop:pool_stride, - col_within_pool:col_stop:pool_stride, - :, - ] - if mx is None: - mx = cur - else: - mx = maximum(mx, cur) - return mx - - def test_broadcasted_dims(self): - # This test a case that caused a crash during optimization - shp = (1, 1, 1, 1) - rng = np.random.default_rng(utt.fetch_seed()) - a = shared(rng.random(shp).astype(config.floatX)) - out = self.max_pool_c01b(a, 1, 1, 1) - - # max_pool_c01b use -inf and this will trigger DebugMode error. - mode = copy.copy(get_default_mode()) - mode.check_isfinite = False - f = function([], out, mode=mode) - f() - - def test_constant_merge(self): - # This test the error in gh-1122 that is a caused by the - # combination of merge optimizer and ShapeFeature. - - x = at.constant([0, 0]) - y = x[1:] - x1 = x - at.join(0, y, y) - x1.eval() - - def test_local_track_shape_i(self): - class IdentityNoShape(Op): - """Op that does not infer the output shape from the input one""" - - def make_node(self, x): - x = as_tensor_variable(x) - return Apply(self, [x], [x.type()]) - - def perform(self, node, inp, out_): - (x,) = inp - (out,) = out_ - out[0] = x.copy() - - # def infer_shape(self, fgraph, node, (xshp,)): - # return [tuple([self.shape_i(i)(r) for i in range(r.ndim)])] - - identity_noshape = IdentityNoShape() - - class IdentityShape(Op): - """Op that does infer the output shape from the input one""" - - def make_node(self, x): - x = as_tensor_variable(x) - return Apply(self, [x], [x.type()]) - - def perform(self, node, inp, out_): - (x,) = inp - (out,) = out_ - out[0] = x.copy() - - def infer_shape(self, fgraph, node, xshp_): - # Could also just return. - (xshp,) = xshp_ - return (xshp,) - - identity_shape = IdentityShape() - - @local_optimizer([IdentityNoShape]) - def local_identity_noshape_to_identity_shape(fgraph, node): - """Optimization transforming the first Op into the second""" - if isinstance(node.op, IdentityNoShape): - return [identity_shape(node.inputs[0])] - - mode = get_default_mode().including("ShapeOpt", "specialize") - rng = np.random.default_rng(utt.fetch_seed()) - x = tensor3("x") - ins_x = identity_noshape(x) - - # Without the optimization - f = function([x], ins_x.shape, mode=mode) - xval = rng.standard_normal((3, 4, 7)).astype(config.floatX) - assert np.all(f(xval) == [3, 4, 7]) - f_ops = [node.op for node in f.maker.fgraph.toposort()] - assert len(f_ops) == 5 - assert identity_noshape in f_ops - assert identity_shape not in f_ops - - # Register the optimization - register_specialize(local_identity_noshape_to_identity_shape) - - mode = get_default_mode().including("ShapeOpt", "specialize") - # With the optimization - # The identity_shape op should not be needed anymore to compute - # the shape - g = function([x], ins_x.shape, mode=mode) - xval = rng.standard_normal((6, 1, 2)).astype(config.floatX) - assert np.all(g(xval) == [6, 1, 2]) - g_ops = [node.op for node in g.maker.fgraph.toposort()] - assert len(g_ops) == 4 - assert identity_noshape not in g_ops - assert identity_shape not in g_ops - - # test multiple level of op without infer_shape - ins_x3 = identity_noshape(identity_noshape(identity_noshape(x))) - h = function([x], ins_x3.shape, mode=mode) - xval = rng.standard_normal((6, 1, 2)).astype(config.floatX) - assert np.all(h(xval) == [6, 1, 2]) - h_ops = [node.op for node in h.maker.fgraph.toposort()] - assert len(h_ops) == 4 - assert identity_noshape not in h_ops - assert identity_shape not in h_ops - - def test_no_shapeopt(self): - # Test that a basic example works even when ShapeOpt is excluded - X = matrix() - expr = X.shape[0] - - mode = get_default_mode().excluding("ShapeOpt") - f = function([X], expr, mode=mode) - # FIXME: This is not a good test. - f([[1, 2], [2, 3]]) - - -class TestUselessCheckAndRaise: - def test_basic(self): - mode = get_default_mode().including( - "canonicalize", "local_remove_useless_assert" - ) - x = scalar() - y = scalar() - f = function([x, y], assert_op(x, eq(x, y)), mode=mode) - assert f(1, 1) == 1 - with pytest.raises(AssertionError): - f(1, 0) - - def test_local_remove_useless_1(self): - """Remove `CheckAndRaise`s when all the conditions are always true.""" - x = scalar() - fg = FunctionGraph(outputs=[assert_op(x, 1)], clone=False) - fg_res = optimize_graph(fg, include=["canonicalize", "specialize"]) - topo = fg_res.toposort() - assert not any(isinstance(node.op, CheckAndRaise) for node in topo) - - def test_local_remove_useless_2(self): - """Remove `CheckAndRaise` conditions that are always true.""" - x = scalar() - y = scalar() - fg = FunctionGraph(outputs=[assert_op(x, y, 1)], clone=False) - fg_res = optimize_graph(fg, include=["canonicalize", "specialize"]) - topo = fg_res.toposort() - (assert_node,) = [node for node in topo if isinstance(node.op, CheckAndRaise)] - assert assert_node.inputs == [x, y] - - def test_local_remove_useless_3(self): - """Don't remove `CheckAndRaise` conditions that are always false.""" - x = scalar() - y = scalar() - fg = FunctionGraph(outputs=[assert_op(x, y, 0)], clone=False) - fg_res = optimize_graph(fg, include=["canonicalize", "specialize"]) - topo = fg_res.toposort() - (assert_node,) = [node for node in topo if isinstance(node.op, CheckAndRaise)] - assert assert_node.inputs[:2] == [x, y] - assert assert_node.inputs[-1].data == 0 - - -def test_local_remove_all_assert(): - r"""Remove all `Assert`\s.""" - mode = get_default_mode().including("canonicalize", "local_remove_all_assert") - - x = scalar() - y = scalar() - f = function([x, y], assert_op(x, y), mode=mode) - # Without the optimization, this would fail - assert f(1, 0) == 1 - topo = f.maker.fgraph.toposort() - assert not any(isinstance(node.op, CheckAndRaise) for node in topo) - - mode = get_default_mode() - a = assert_op(x, eq(x, 0).any()) - f = function([x], a, mode=mode.excluding("unsafe")) - topo = f.maker.fgraph.toposort() - a_op = [n for n in topo if isinstance(n.op, Assert)] - assert len(a_op) == 1 - - -class TestTile: - def test_local_useless_tile(self): - v = vector() - m = matrix() - mode = None - if config.mode == "FAST_COMPILE": - mode = "FAST_RUN" - for var, data in [(v, [1, 2, 3]), (m, [[1, 2], [3, 4]])]: - # When len(repeat pattern) <= var.ndim, everything is removed - # for ndim in range(1, var.ndim): - for ndim in range(var.ndim + 1): - f = function([var], tile(var, (1,) * ndim), mode=mode) - topo = f.maker.fgraph.toposort() - assert len(topo) == 1 - assert isinstance(topo[0].op, DeepCopyOp) - f(data) - # In this case the opt only removes nodes, - # no need to check_stack_trace - # When len(repeat pattern) > var.ndim, only a dimshuffle should be - # left, but there can be a DeepCopy as well - for ndim in range(var.ndim + 1, var.ndim + 3): - f = function([var], tile(var, (1,) * ndim), mode=mode) - topo = f.maker.fgraph.toposort() - assert len(topo) <= 2 - assert isinstance(topo[0].op, DimShuffle) - assert check_stack_trace(f, ops_to_check=[DimShuffle]) - f(data) - - -class TestUnbroadcast: - def setup_method(self): - self.mode = get_default_mode().including("canonicalize") - - def test_local_useless_unbroadcast(self): - x1 = tensor("float64", shape=(1, 2)) - x2 = tensor("float64", shape=(2, 1)) - unbroadcast_op = Unbroadcast(0) - - f = function([x1], unbroadcast_op(x1), mode=self.mode) - assert ( - sum(isinstance(node.op, Unbroadcast) for node in f.maker.fgraph.toposort()) - == 1 - ) - - f = function([x2], unbroadcast_op(x2), mode=self.mode) - assert ( - sum(isinstance(node.op, Unbroadcast) for node in f.maker.fgraph.toposort()) - == 0 - ) - - def test_local_unbroadcast_lift(self): - x = tensor("float64", shape=(1, 1)) - y = unbroadcast(at.exp(unbroadcast(x, 0)), 1) - - assert ( - sum( - isinstance(node.op, Unbroadcast) - for node in FunctionGraph([x], [y], copy_inputs=False).toposort() - ) - == 2 - ) - - f = function([x], y, mode=self.mode) - assert ( - sum(isinstance(node.op, Unbroadcast) for node in f.maker.fgraph.toposort()) - == 1 - ) - - np.testing.assert_almost_equal(f([[1]]), np.exp([[1]])) - - -class TestUselessElemwise: - def setup_method(self): - self.mode = get_default_mode().including("canonicalize", "local_fill_to_alloc") - - def test_eq(self): - x = dmatrix() - y = dmatrix() - f = function([x, y], eq(x, y), mode=self.mode) - vx = np.random.random((5, 4)) - vy = np.random.random((5, 4)) - f(vx, vy) - topo = f.maker.fgraph.toposort() - assert len(topo) == 1 - assert isinstance(topo[0].op, Elemwise) - assert isinstance(topo[0].op.scalar_op, aes.EQ) - f2 = function([x], eq(x, x), mode=self.mode) - assert np.all(f2(vx) == np.ones((5, 4))) - topo2 = f2.maker.fgraph.toposort() - # Shape_i{1}(), - # Shape_i{0}(), Alloc([[1]], Shape_i{0}.0, - # Shape_i{1}.0 - assert len(topo2) == 3 - assert isinstance(topo2[-1].op, Alloc) - - def test_neq(self): - x = dmatrix() - y = dmatrix() - f = function([x, y], neq(x, y), mode=self.mode) - vx = np.random.random((5, 4)) - vy = np.random.random((5, 4)) - f(vx, vy) - topo = f.maker.fgraph.toposort() - assert len(topo) == 1 - assert isinstance(topo[0].op, Elemwise) - assert isinstance(topo[0].op.scalar_op, aes.NEQ) - f2 = function([x], neq(x, x), mode=self.mode) - assert np.all(f2(vx) == np.zeros((5, 4))) - topo2 = f2.maker.fgraph.toposort() - assert len(topo2) == 3 - assert isinstance(topo2[-1].op, Alloc) - - def test_mul(self): - x = dmatrix() - y = dmatrix() - f = function([x], mul(x), mode=self.mode) - vx = np.random.random((5, 4)) - vy = np.random.random((5, 4)) - f(vx) - topo = f.maker.fgraph.toposort() - assert len(topo) == 1 - assert topo[0].op == deep_copy_op - f2 = function([x, y], mul(x, y), mode=self.mode) - assert np.all(f2(vx, vy) == vx * vy) - topo2 = f2.maker.fgraph.toposort() - assert len(topo2) == 1 - assert isinstance(topo2[0].op, Elemwise) - assert isinstance(topo2[0].op.scalar_op, aes.Mul) - - def test_add(self): - x = dmatrix() - y = dmatrix() - f = function([x], add(x), mode=self.mode) - vx = np.random.random((5, 4)) - vy = np.random.random((5, 4)) - f(vx) - topo = f.maker.fgraph.toposort() - assert len(topo) == 1 - assert topo[0].op == deep_copy_op - f2 = function([x, y], add(x, y), mode=self.mode) - assert np.all(f2(vx, vy) == vx + vy) - topo2 = f2.maker.fgraph.toposort() - assert len(topo2) == 1 - assert isinstance(topo2[0].op, Elemwise) - assert isinstance(topo2[0].op.scalar_op, aes.Add) - - def test_identity(self): - # aes.identity is used in 2 Elemwise functions: - # tensor_copy, and view - x = matrix() - f = function([x], at.tensor_copy(x), mode=self.mode) - vx = np.random.random((5, 4)).astype(config.floatX) - f(vx) - topo = f.maker.fgraph.toposort() - assert len(topo) == 1 - assert topo[0].op == deep_copy_op - - -class TestCastCast: - def setup_method(self): - mode = get_default_mode() - self.mode = mode.including("local_cast_cast") - - def test_consecutive(self): - x = fmatrix() - o = Elemwise(aes.Cast(aes.ScalarType("float64")))(x.astype("float64")) - f = function([x], o, mode=self.mode) - dx = np.random.random((5, 4)).astype("float32") - f(dx) - topo = f.maker.fgraph.toposort() - assert len(topo) == 1 - assert isinstance(topo[0].op.scalar_op, aes.basic.Cast) - - x = dmatrix() - o = Elemwise(aes.Cast(aes.ScalarType("float32")))(x.astype("float32")) - f = function([x], o, mode=self.mode) - dx = np.random.random((5, 4)) - f(dx) - topo = f.maker.fgraph.toposort() - assert len(topo) == 1 - assert isinstance(topo[0].op.scalar_op, aes.basic.Cast) - - def test_upcast(self): - # Upcast followed by any other cast - x = fmatrix() - o = Elemwise(aes.Cast(aes.ScalarType("complex128")))(x.astype("complex64")) - f = function([x], o, mode=self.mode) - dx = np.random.random((5, 4)).astype("float32") - f(dx) - topo = f.maker.fgraph.toposort() - assert len(topo) == 1 - assert isinstance(topo[0].op.scalar_op, aes.basic.Cast) - - # Upcast followed by a downcast back to the base type - x = fmatrix() - o = Elemwise(aes.Cast(aes.ScalarType("float32")))(x.astype("float64")) - f = function([x], o, mode=self.mode) - dx = np.random.random((5, 4)).astype("float32") - f(dx) - topo = f.maker.fgraph.toposort() - assert len(topo) == 1 - assert isinstance(topo[0].op, DeepCopyOp) - - # Downcast followed by an upcast back to the base type - # Optimization shouldn't be applied - x = dmatrix() - o = Elemwise(aes.Cast(aes.ScalarType("float64")))(x.astype("float32")) - f = function([x], o, mode=self.mode) - dx = np.random.random((5, 4)) - f(dx) - topo = f.maker.fgraph.toposort() - assert ( - len(topo) == 1 and isinstance(topo[0].op.scalar_op, aes.basic.Composite) - ) or (len(topo) > 1) - - -def test_constant_folding(): - # Test that constant folding get registered at fast_compile - # An error removed that registration during the registration. - x = dvector() - mode = get_mode("FAST_COMPILE").excluding("fusion") - f = function([x], [x * 2, x + x], mode=mode) - topo = f.maker.fgraph.toposort() - assert len(topo) == 2 - - # Test that we do not crash when constant folding elemwise scalar - # as they should not generate c code. - - x = at.constant(3) - assert x.ndim == 0 - mode = get_mode("FAST_COMPILE").excluding("fusion") - f = function([], [x * 2, x + x], mode=mode) - topo = f.maker.fgraph.toposort() - assert len(topo) == 2 - assert all(isinstance(n.op, DeepCopyOp) for n in topo) - - -@pytest.mark.xfail( - reason="Aesara optimizes constant before stabilization. " - "This breaks stabilization optimizations in some " - "cases. See #504.", - raises=AssertionError, -) -def test_constant_get_stabilized(): - # Currently Aesara enables the `constant_folding` optimization before stabilization optimization. - # This caused some stabilization optimizations to not be activated and that - # caused inf values to appear when they should not. - - # We can't simply move the `constant_folding` optimization to - # specialize since this will break other optimizations. We will need to - # partially duplicate some canonicalize optimizations to fix this issue. - - x2 = scalar() - y2 = log(1 + exp(x2)) - mode = get_default_mode() - mode.check_isfinite = False - f2 = function([x2], y2, mode=mode) - - assert len(f2.maker.fgraph.toposort()) == 1 - assert f2.maker.fgraph.toposort()[0].op == softplus - assert f2(800) == 800 - - x = at.as_tensor_variable(800) - y = log(1 + exp(x)) - f = function([], y, mode=mode) - # When this error is fixed, the following line should be ok. - assert f() == 800, f() - - -class TestLocalSwitchSink: - def setup_method(self): - # condition values - self.condm = np.asarray([[0.1, 0, 1, -1], [0.0, 0.0, 0.0, 0.0], [1, 1, 1, 1]]) - self.condv = np.asarray([0.1, 0, 1, -1]) - self.conds = [0.1, 0, 1, -1] - - # x values - self.xm = np.ones((3, 4)) - self.xv = np.ones((4,)) - self.xs = 1.0 - - # expected results - self.resm = ( - [np.asarray([[1, 0, 1, 0], [0, 0, 0, 0], [1, 1, 1, 1]])] * 3 - + [np.asarray([[1, 0, 1, 0], [1, 0, 1, 0], [1, 0, 1, 0]])] - + 2 * [np.asarray([[1, 0, 1, 0]])] - + [[np.ones((3, 4)), np.zeros((3, 4)), np.ones((3, 4)), np.zeros((3, 4))]] - + [[np.ones((4,)), np.zeros((4,)), np.ones((4,)), np.zeros((4,))]] - + [[np.asarray(1.0), np.asarray(0.0), np.asarray(1.0), np.asarray(0.0)]] - ) - - self.mode = ( - get_default_mode() - .including("canonicalize", "fast_run") - .excluding("gpu", "fusion") - ) - self.mode = copy.copy(self.mode) - self.mode.check_isfinite = False - - def function_remove_nan(self, *args, **kwargs): - """ - Wrapper around function for this test. - - It disables checking for NaN removed by optimizations in DebugMode - (it has false positives in that case). - """ - f = function(*args, **kwargs) - - def wrapped_f(*args, **kwargs): - # This is a bit ugly since it changes the global value of - # TensorType.values_eq_approx. - old_values_eq_approx = staticmethod(TensorType.values_eq_approx) - TensorType.values_eq_approx = staticmethod(values_eq_approx_remove_nan) - try: - out = f(*args, **kwargs) - finally: - TensorType.values_eq_approx = old_values_eq_approx - return out - - return wrapped_f - - def test_local_mul_switch_sink(self): - c = dscalar() - idx = 0 - for condition in [ - (dmatrix("cond"), self.condm), - (dvector("cond"), self.condv), - (dscalar("cond"), self.conds), - ]: - for x in [ - (dmatrix("x"), self.xm), - (dvector("x"), self.xv), - (dscalar("x"), self.xs), - ]: - y = mul( - at.switch(condition[0] > 0, 1.0 * x[0], 0.0 * x[0]), - at.switch(condition[0] > 0, 1.0 * x[0], log(c) * x[0]), - ) - f = self.function_remove_nan( - [condition[0], x[0], c], [y], mode=self.mode - ) - if type(condition[1]) is list: - for i in range(len(condition[1])): - res = f(condition[1][i], x[1], -1) - assert ( - res == np.asarray(self.resm[idx][i]) - ).sum() == self.resm[idx][i].size - else: - res = f(condition[1], x[1], -1) - assert (res == np.asarray(self.resm[idx])).sum() == self.resm[ - idx - ].size - idx += 1 - - # This case caused a missed optimization in the past. - x = dscalar("x") - y = at.switch(x < 7, x, sqrt(x - 7)) - f = self.function_remove_nan([x], aesara.gradient.grad(y, x), self.mode) - assert f(5) == 1, f(5) - - @pytest.mark.slow - def test_local_div_switch_sink(self): - c = dscalar() - idx = 0 - for condition in [ - (dmatrix("cond"), self.condm), - (dvector("cond"), self.condv), - (dscalar("cond"), self.conds), - ]: - for x in [ - (dmatrix("x"), self.xm), - (dvector("x"), self.xv), - (dscalar("x"), self.xs), - ]: - y = true_div( - at.switch(condition[0] > 0, 1.0 * x[0], 0.0 * x[0]), - at.switch(condition[0] > 0, 1.0 * x[0], log(c) * x[0]), - ) - f = self.function_remove_nan( - [condition[0], x[0], c], [y], mode=self.mode - ) - if type(condition[1]) is list: - for i in range(len(condition[1])): - res = f(condition[1][i], x[1], -1) - assert ( - res == np.asarray(self.resm[idx][i]) - ).sum() == self.resm[idx][i].size - else: - res = f(condition[1], x[1], -1) - assert (res == np.asarray(self.resm[idx])).sum() == self.resm[ - idx - ].size - idx += 1 - - -class TestLocalUselessSwitch: - def setup_method(self): - self.mode = mode_opt.excluding("constant_folding") - - @pytest.mark.parametrize( - "dtype1", - ["int32", "int64"], - ) - @pytest.mark.parametrize( - "dtype2", - ["int32", "int64"], - ) - @pytest.mark.parametrize( - "cond", - [0, 1, np.array([True])], - ) - def test_const(self, dtype1, dtype2, cond): - x = matrix("x", dtype=dtype1) - y = matrix("y", dtype=dtype2) - z = at.switch(cond, x, y) - f = function([x, y], z, mode=self.mode) - assert not any( - node.op - for node in f.maker.fgraph.toposort() - if ( - isinstance(node.op, Elemwise) - and isinstance(node.op.scalar_op, aes.basic.Switch) - ) - ) - vx = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype1) - vy = np.array([[7, 8, 9], [10, 11, 12]], dtype=dtype2) - np_res = np.where(cond, vx, vy) - assert np.array_equal(f(vx, vy), np_res) - - @pytest.mark.parametrize( - "dtype1", - ["int32", "int64"], - ) - def test_left_is_right(self, dtype1): - x = matrix("x", dtype=dtype1) - varc = matrix("varc", dtype=dtype1) - z1 = at.switch(1, x, x) - z0 = at.switch(0, x, x) - z2 = at.switch(varc, x, x) - f1 = function([x], z1, mode=self.mode) - f0 = function([x], z0, mode=self.mode) - f2 = function([x, varc], z2, mode=self.mode) - - topo = f1.maker.fgraph.toposort() - assert len(topo) == 1 - assert topo[0].op == deep_copy_op - - topo = f0.maker.fgraph.toposort() - assert len(topo) == 1 - assert topo[0].op == deep_copy_op - - topo = f2.maker.fgraph.toposort() - assert len(topo) == 1 - assert topo[0].op == deep_copy_op - - vx = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype1) - vc = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype1) - assert np.array_equal(f1(vx), vx) - assert np.array_equal(f0(vx), vx) - assert np.array_equal(f2(vx, vc), vx) - - @pytest.mark.parametrize( - "dtype1", - ["float32", "float64"], - ) - def test_shape_le_0(self, dtype1): - x = matrix("x", dtype=dtype1) - z0 = at.switch(le(x.shape[0], 0), 0, x.shape[0]) - f0 = function([x], z0, mode=self.mode) - assert isinstance(f0.maker.fgraph.toposort()[0].op, Shape_i) - - z1 = at.switch(le(x.shape[1], 0), 0, x.shape[1]) - f1 = function([x], z1, mode=self.mode) - assert isinstance(f1.maker.fgraph.toposort()[0].op, Shape_i) - - vx = np.random.standard_normal((0, 5)).astype(dtype1) - assert f0(vx) == 0 - assert f1(vx) == 5 - - def test_broadcasting_1(self): - # test switch(cst, matrix, row) - x = matrix("x", dtype="int32") - y = vector("y", dtype="int64") - - z = at.switch(1, x, y) - f = function([x, y], z, mode=self.mode) - - start_var = f.maker.fgraph.outputs[0].owner.inputs[0] - assert isinstance(start_var.owner.op, Elemwise) - assert isinstance(start_var.owner.op.scalar_op, aes.basic.Cast) - assert not any(node.op == at.switch for node in f.maker.fgraph.toposort()) - - vx = np.array([[1, 2, 3], [4, 5, 6]], dtype="int32") - vy = np.array([10, 11, 12], dtype="int64") - np_res = np.where(1, vx, vy) - assert np.array_equal(f(vx, vy), np_res) - - z = at.switch(0, x, y) - f = function([x, y], z, mode=self.mode) - - assert isinstance(f.maker.fgraph.outputs[0].owner.op, Alloc) - assert f.maker.fgraph.inputs[1] == f.maker.fgraph.outputs[0].owner.inputs[0] - assert not any(node.op == at.switch for node in f.maker.fgraph.toposort()) - - vx = np.array([[1, 2, 3], [4, 5, 6]], dtype="int32") - vy = np.array([10, 11, 12], dtype="int64") - np_res = np.where(0, vx, vy) - assert np.array_equal(f(vx, vy), np_res) - - def test_broadcasting_2(self): - # test switch(cst, vector, matrix) - - x = vector("x", dtype="int32") - y = matrix("y", dtype="int64") - - z = at.switch(1, x, y) - f = function([x, y], z, mode=self.mode) - - assert isinstance(f.maker.fgraph.outputs[0].owner.op, Alloc) - assert not any(node.op == at.switch for node in f.maker.fgraph.toposort()) - - vx = np.array([4, 5, 6], dtype="int32") - vy = np.array([[7, 8, 9], [10, 11, 12]], dtype="int64") - np_res = np.where(1, vx, vy) - assert np.array_equal(f(vx, vy), np_res) - - z = at.switch(0, x, y) - f = function([x, y], z, mode=self.mode) - - assert isinstance(f.maker.fgraph.outputs[0].owner.op, DeepCopyOp) - assert not any(node.op == at.switch for node in f.maker.fgraph.toposort()) - - vx = np.array([4, 5, 6], dtype="int32") - vy = np.array([[7, 8, 9], [10, 11, 12]], dtype="int64") - np_res = np.where(0, vx, vy) - assert np.array_equal(f(vx, vy), np_res) - - def test_broadcasting_3(self): - # test switch(matrix, same_vector, same_vector) - - x = matrix("x", dtype="int32") - y = vector("y", dtype="int64") - z = at.switch(x, y, y) - f = function([x, y], z, mode=self.mode) - vx = np.array([[0, 1], [1, 0]], dtype="int32") - vy = np.array([7, 8], dtype="int64") - utt.assert_allclose(f(vx, vy), np.where(vx, vy, vy)) - - assert isinstance(f.maker.fgraph.outputs[0].owner.op, Alloc) - assert not any(node.op == at.switch for node in f.maker.fgraph.toposort()) - - -class TestLocalMergeSwitchSameCond: - @pytest.mark.parametrize( - "op", - [ - add, - sub, - mul, - true_div, - int_div, - floor_div, - minimum, - maximum, - gt, - lt, - ge, - le, - eq, - neq, - at_pow, - ], - ) - def test_elemwise_float_ops(self, op): - # float Ops - mats = matrices("cabxy") - c, a, b, x, y = mats - s1 = at.switch(c, a, b) - s2 = at.switch(c, x, y) - - g = optimize(FunctionGraph(mats, [op(s1, s2)])) - assert str(g).count("Switch") == 1 - - @pytest.mark.parametrize( - "op", - [ - bitwise_and, - bitwise_or, - bitwise_xor, - ], - ) - def test_elemwise_int_ops(self, op): - # integer Ops - mats = imatrices("cabxy") - c, a, b, x, y = mats - s1 = at.switch(c, a, b) - s2 = at.switch(c, x, y) - g = optimize(FunctionGraph(mats, [op(s1, s2)])) - assert str(g).count("Switch") == 1 - - @pytest.mark.parametrize("op", [add, mul]) - def test_elemwise_multi_inputs(self, op): - # add/mul with more than two inputs - mats = imatrices("cabxy") - c, a, b, x, y = mats - s1 = at.switch(c, a, b) - s2 = at.switch(c, x, y) - u, v = matrices("uv") - s3 = at.switch(c, u, v) - g = optimize(FunctionGraph(mats + [u, v], [op(s1, s2, s3)])) - assert str(g).count("Switch") == 1 - - -class TestLocalOptAlloc: - """ - TODO FIXME: These tests are incomplete; they need to `assert` something. - """ - - dtype = "float32" - - def test_sum_upcast(self): - s = lscalar() - a = at.alloc(np.asarray(5, dtype=self.dtype), s, s) - with config.change_flags(warn_float64="raise"): - f = function([s], a.sum()) - f(5) - - def test_prod_upcast(self): - s = lscalar() - a = at.alloc(np.asarray(5, dtype=self.dtype), s, s) - - with config.change_flags(warn_float64="raise"): - f = function([s], a.prod()) - f(5) - - @config.change_flags(on_opt_error="raise") - def test_sum_bool_upcast(self): - s = lscalar() - a = at.alloc(np.asarray(True, dtype="bool"), s, s) - f = function([s], a.sum()) - f(5) - # test with user specified dtype - f = function([s], a.sum(dtype=self.dtype)) - f(5) - # test only 1 axis summed - f = function([s], a.sum(axis=0, dtype=self.dtype)) - f(5) - - -class TestLocalOptAllocF16(TestLocalOptAlloc): - dtype = "float16" - - -def test_local_join_1(): - # test for vector - a = vector("a") - s = at.stack([a]) - f = function([a], s, mode=mode_opt) - val = f([1]) - assert np.all(val == [1]) - e = f.maker.fgraph.toposort() - assert len([n for n in e if isinstance(n.op, Join)]) == 0 - assert f.maker.fgraph.outputs[0].dtype == config.floatX - - # test for matrix join(0,a) - a = matrix("a") - s = join(0, a) - f = function([a], s, mode=mode_opt) - val = f([[1]]) - assert np.all(val == [[1]]) - e = f.maker.fgraph.toposort() - assert len([n for n in e if isinstance(n.op, Join)]) == 0 - assert f.maker.fgraph.outputs[0].dtype == config.floatX - - # test for matrix join(1,a) - s = join(1, a) - f = function([a], s, mode=mode_opt) - val = f([[1]]) - assert np.all(val == [[1]]) - e = f.maker.fgraph.toposort() - assert len([n for n in e if isinstance(n.op, Join)]) == 0 - assert f.maker.fgraph.outputs[0].dtype == config.floatX - - # test we don't apply when their is 2 inputs - s = join(1, a, a) - f = function([a], s, mode=mode_opt) - val = f([[1]]) - assert np.all(val == [[1]]) - e = f.maker.fgraph.toposort() - assert len([n for n in e if isinstance(n.op, Join)]) == 1 - assert f.maker.fgraph.outputs[0].dtype == config.floatX - - -def test_local_join_empty(): - # test for vector, vector, empty to vector - empty_vec = np.asarray([], dtype=config.floatX) - a = vector("a") - s = at.join(0, a, a, empty_vec) - f = function([a], s, mode=mode_opt) - val = f([1]) - assert np.all(val == [1]) - e = f.maker.fgraph.toposort() - assert len([n for n in e if isinstance(n.op, Join)]) == 1 - assert all( - not isinstance(n.op, Join) or len(n.inputs) == 3 - for n in e - if isinstance(n.op, Join) - ) - assert f.maker.fgraph.outputs[0].dtype == config.floatX - - # test for matrix join(1,a) - empty_mat = np.asarray([[]], dtype=config.floatX) - m = matrix("m") - s = join(1, empty_mat, m, m, m) - f = function([m], s, mode=mode_opt) - val = f([[1]]) - assert np.all(val == [[1]]) - e = f.maker.fgraph.toposort() - assert len([n for n in e if isinstance(n.op, Join)]) == 1 - assert all( - not isinstance(n.op, Join) or len(n.inputs) == 4 - for n in e - if isinstance(n.op, Join) - ) - assert f.maker.fgraph.outputs[0].dtype == config.floatX - # test for vector, vector, empty to matrix - # We can't optimize this case. - s = at.stack([a, a, empty_vec]) - f = function([a], s, mode=mode_opt) - val = f([]) - assert np.all(val == [1]) - e = f.maker.fgraph.toposort() - assert len([n for n in e if isinstance(n.op, Join)]) == 1 - assert all( - not isinstance(n.op, Join) or len(n.inputs) == 4 - for n in e - if isinstance(n.op, Join) - ) - assert f.maker.fgraph.outputs[0].dtype == config.floatX - # test for matrix join(0,a) - # We can't optimize this case. - s = join(0, m, np.asarray([[2.0]], dtype=config.floatX), m) - f = function([m], s, mode=mode_opt) - val = f([[1]]) - assert np.all(val == [[1], [2], [1]]) - e = f.maker.fgraph.toposort() - assert len([n for n in e if isinstance(n.op, Join)]) == 1 - assert all( - not isinstance(n.op, Join) or len(n.inputs) == 4 - for n in e - if isinstance(n.op, Join) - ) - assert f.maker.fgraph.outputs[0].dtype == config.floatX - - -def test_local_join_make_vector(): - a, b, c, d, e = scalars("abcde") - v = vector("v") - mv = MakeVector(config.floatX) - s = at.join(0, mv(a), v, mv(b, c), mv(d, e)) - f = function([a, b, c, d, e, v], s, mode=mode_opt) - val = f(1, 2, 3, 4, 6, [7, 8]) - assert np.all(val == [1, 7, 8, 2, 3, 4, 6]) - e = f.maker.fgraph.toposort() - assert len([n for n in e if isinstance(n.op, Join)]) == 1 - assert all( - not isinstance(n.op, Join) or len(n.inputs) == 4 - for n in e - if isinstance(n.op, Join) - ) - assert f.maker.fgraph.outputs[0].dtype == config.floatX - - assert check_stack_trace(f, ops_to_check="all") - - -@pytest.mark.parametrize( - "dtype", - [ - "int8", - "int16", - "int32", - "int64", - "uint8", - "uint16", - "uint32", - "uint64", - "float32", - "float64", - "complex64", - "complex128", - ], -) -def test_local_tensor_scalar_tensor(dtype): - t_type = TensorType(dtype=dtype, shape=()) - t = t_type() - s = at.scalar_from_tensor(t) - t2 = at.tensor_from_scalar(s) - - f = function([t], t2, mode=mode_opt) - e = f.maker.fgraph.toposort() - assert not any( - n for n in e if isinstance(n.op, (TensorFromScalar, ScalarFromTensor)) - ) - - -@pytest.mark.parametrize( - "dtype", - [ - "int8", - "int16", - "int32", - "int64", - "uint8", - "uint16", - "uint32", - "uint64", - "float32", - "float64", - "complex64", - "complex128", - ], -) -def test_local_scalar_tensor_scalar(dtype): - s_type = aes.ScalarType(dtype=dtype) - s = s_type() - t = at.tensor_from_scalar(s) - s2 = at.scalar_from_tensor(t) - - f = function([s], s2, mode=mode_opt) - e = f.maker.fgraph.toposort() - assert not any( - n for n in e if isinstance(n.op, (TensorFromScalar, ScalarFromTensor)) - ) - - -def test_local_useless_split(): - x = matrix("x") - splits = ivector("splits") - opt = at.split(x, splits, n_splits=1) - nonopt = at.split(x, splits, n_splits=3) - - mode = get_default_mode().including("local_useless_split") - f_opt = function([x, splits], opt, mode=mode) - f_nonopt = function([x, splits], nonopt, mode=mode) - - f_opt(np.random.random((4, 4)).astype(config.floatX), [4]) - f_nonopt(np.random.random((4, 4)).astype(config.floatX), [1, 2, 1]) - graph_opt = f_opt.maker.fgraph.toposort() - graph_nonopt = f_nonopt.maker.fgraph.toposort() - - assert isinstance(graph_opt[-1].op, DeepCopyOp) - assert len(graph_nonopt) == 1 - assert isinstance(graph_nonopt[0].op, Split) - - assert check_stack_trace(f_opt, ops_to_check=[Assert]) - assert check_stack_trace(f_nonopt, ops_to_check="all") - - -@pytest.mark.parametrize("i", list(range(1, 4))) -def test_local_flatten_lift(i): - x = tensor4() - out = at.flatten(exp(x), i) - assert out.ndim == i - mode = get_default_mode() - mode = mode.including("local_reshape_lift") - f = function([x], out, mode=mode) - x_np = np.random.random((5, 4, 3, 2)).astype(config.floatX) - out_np = f(x_np) - topo = f.maker.fgraph.toposort() - shape_out_np = tuple(x_np.shape[: i - 1]) + (np.prod(x_np.shape[i - 1 :]),) - assert shape_out_np == out_np.shape - - reshape_nodes = [n for n in topo if isinstance(n.op, Reshape)] - assert len(reshape_nodes) == 1 and at.is_flat(reshape_nodes[0].outputs[0], ndim=i) - assert isinstance(topo[-1].op, Elemwise) - - -class TestReshape: - def setup_method(self): - self.mode = mode_opt - self.op = Reshape - - def test_local_reshape(self): - a = fmatrix() - b = self.op(3)(a, [2, 3, 4]) - c = self.op(1)(b, [24]) - f = function([a], c, mode=self.mode) - topo = f.maker.fgraph.toposort() - assert sum(isinstance(node.op, self.op) for node in topo) == 1 - - # Check stack trace - assert check_stack_trace(f, ops_to_check=[self.op]) - - -class TestLocalUselessReshape: - def setup_method(self): - self.rng = np.random.default_rng(utt.fetch_seed()) - - def test_0(self): - mode = get_default_mode().including("local_useless_reshape") - i = iscalar("i") - m = at.mgrid[ - 0:i, - ] - f = function([i], m, mode=mode) - topo = f.maker.fgraph.toposort() - assert not any(isinstance(n.op, Reshape) for n in topo) - - def test_1(self): - x = matrix("x") - r = x.reshape(x.shape) - - m0 = get_default_mode() - m1 = m0.including("local_useless_reshape") - f1 = function([x], r, mode=m1) - topo = f1.maker.fgraph.toposort() - assert not any(isinstance(n.op, Reshape) for n in topo) - - m2 = m1.excluding("ShapeOpt") - f2 = function([x], r, mode=m2) - topo = f2.maker.fgraph.toposort() - assert not any(isinstance(n.op, Reshape) for n in topo) - - # We do not need tests checking that stack traces are copied over, - # because local_useless_reshape only removes nodes from the graph - - def test_2(self): - x = matrix("x") - r = x.reshape([Shape_i(i)(x) for i in range(x.ndim)]) - - m0 = get_default_mode() - m1 = m0.including("local_useless_reshape") - f1 = function([x], r, mode=m1) - topo = f1.maker.fgraph.toposort() - assert not any(isinstance(n.op, Reshape) for n in topo) - - m2 = m1.excluding("ShapeOpt") - f2 = function([x], r, mode=m2) - topo = f2.maker.fgraph.toposort() - assert not any(isinstance(n.op, Reshape) for n in topo) - - def test_m1(self): - x = matrix("x") - r = x.reshape((x.shape[0], -1)) - - m0 = get_default_mode() - m1 = m0.including("local_useless_reshape") - f1 = function([x], r, mode=m1) - topo = f1.maker.fgraph.toposort() - assert not any(isinstance(n.op, Reshape) for n in topo) - - m2 = m1.excluding("ShapeOpt") - f2 = function([x], r, mode=m2) - topo = f2.maker.fgraph.toposort() - assert not any(isinstance(n.op, Reshape) for n in topo) - - -class TestLocalReshapeToDimshuffle: - def setup_method(self): - self.rng = np.random.default_rng(utt.fetch_seed()) - - def test_1(self): - reshape_lift = out2in(local_reshape_to_dimshuffle) - useless_reshape = out2in(local_useless_reshape) - x = shared(self.rng.standard_normal((4,))) - y = shared(self.rng.standard_normal((5, 6))) - reshape_x = reshape(x, (1, 4)) - reshape_y = reshape(y, (1, 5, 1, 6, 1, 1)) - - g = FunctionGraph([x, y], [reshape_x, reshape_y]) - assert str(g) == ( - "FunctionGraph(Reshape{2}" - "(, " - "TensorConstant{[1 4]}), " - "Reshape{6}" - "(, " - "TensorConstant{[1 5 1 6 1 1]}))" - ) - - reshape_lift.optimize(g) - useless_reshape.optimize(g) - assert str(g) == ( - "FunctionGraph(InplaceDimShuffle{x,0}" - "(), " - "InplaceDimShuffle{x,0,x,1,x,x}" - "(Reshape{2}(, " - "TensorConstant{[5 6]})))" - ) - - # Check stacktrace was copied over correctly after opt was applied - assert check_stack_trace(g, ops_to_check=(DimShuffle, Reshape)) - - -def test_local_reshape_lift(): - x = tensor4() - out = exp(x).reshape([x.size]) - assert out.ndim == 1 - mode = get_default_mode() - mode = mode.including("local_reshape_lift") - f = function([x], out, mode=mode) - f(np.random.random((5, 4, 3, 2)).astype(config.floatX)) - topo = f.maker.fgraph.toposort() - assert isinstance(topo[-2].op, Reshape) - assert isinstance(topo[-1].op, Elemwise) - # Check stacktrace was copied over correctly after opt was applied - assert check_stack_trace(f, ops_to_check="last") - - -class TestLiftTransposeThroughDot: - def simple_optimize(self, g): - out2in(local_useless_elemwise).optimize(g) - out2in(local_lift_transpose_through_dot).optimize(g) - out2in(local_useless_elemwise).optimize(g) - return g - - def test_matrix_matrix(self): - a, b = matrices("ab") - g = self.simple_optimize(FunctionGraph([a, b], [dot(a, b).T])) - sg = "FunctionGraph(dot(InplaceDimShuffle{1,0}(b), InplaceDimShuffle{1,0}(a)))" - assert str(g) == sg, (str(g), sg) - # Check stacktrace was copied over correctly after opt was applied - assert check_stack_trace(g, ops_to_check="all") - - def test_row_matrix(self): - a = vector("a") - b = matrix("b") - g = optimize( - FunctionGraph([a, b], [dot(a.dimshuffle("x", 0), b).T]), - level="stabilize", - ) - sg = "FunctionGraph(dot(InplaceDimShuffle{1,0}(b), InplaceDimShuffle{0,x}(a)))" - assert str(g) == sg, (str(g), sg) - # Check stacktrace was copied over correctly after opt was applied - assert check_stack_trace(g, ops_to_check="all") - - def test_matrix_col(self): - a = vector("a") - b = matrix("b") - g = optimize( - FunctionGraph([a, b], [dot(b, a.dimshuffle(0, "x")).T]), - level="stabilize", - ) - sg = "FunctionGraph(dot(InplaceDimShuffle{x,0}(a), InplaceDimShuffle{1,0}(b)))" - assert str(g) == sg, (str(g), sg) - # Check stacktrace was copied over correctly after opt was applied - assert check_stack_trace(g, ops_to_check="all") - - -def test_local_upcast_elemwise_constant_inputs(): - s = dvector("s") - x = at_sum(log(10**s)) - f = function([s], [aesara.gradient.grad(x, s)]) - f([-42, -2.1, -1, -0.5, 0, 0.2, 1, 2, 12]) - - # This test a corner where the optimization should not be applied. - with config.change_flags(floatX="float32"): - v = lvector() - function([v], true_div(v, 2)) - - -class TestShapeI(utt.InferShapeTester): - def setup_method(self): - super().setup_method() - - def test_perform(self): - rng = np.random.default_rng(utt.fetch_seed()) - - advec = vector() - advec_val = rng.random((3)).astype(config.floatX) - f = function([advec], Shape_i(0)(advec)) - out = f(advec_val) - utt.assert_allclose(out, advec_val.shape[0]) - - admat = matrix() - admat_val = rng.random((4, 3)).astype(config.floatX) - for i in range(2): - f = function([admat], Shape_i(i)(admat)) - out = f(admat_val) - utt.assert_allclose(out, admat_val.shape[i]) - - def test_infer_shape(self): - admat = matrix() - admat_val = np.random.random((3, 4)).astype(config.floatX) - self._compile_and_check([admat], [Shape_i(0)(admat)], [admat_val], Shape_i) - - self._compile_and_check([admat], [Shape_i(1)(admat)], [admat_val], Shape_i) - - -class TestSameShape: - def test_scalar(self): - x = scalar() - cst = at.constant(1) - o = x + cst - fgraph = FunctionGraph([x], [o], clone=False) - shape_feature = ShapeFeature() - fgraph.attach_feature(shape_feature) - assert shape_feature.same_shape(x, o) - - def test_vector(self): - x = vector() - cst = at.constant(1) - o = x + cst - fgraph = FunctionGraph([x], [o], clone=False) - shape_feature = ShapeFeature() - fgraph.attach_feature(shape_feature) - assert shape_feature.same_shape(x, o) - - def test_no_static_shapes(self): - x = vector() - y = vector() - o = x + y - fgraph = FunctionGraph([x, y], [o], clone=False) - shape_feature = ShapeFeature() - fgraph.attach_feature(shape_feature) - # We no longer assume that `x` has the same shape as `y` simply because - # neither has static shape information. Instead, when there is no - # static shape information is available, we assume that `x` and/or `y` - # could have shapes `(1,)` and/or `(n,)`, where `n != 1`, or any - # combination of the two. - assert not shape_feature.same_shape(x, o) - # The following case isn't implemented - assert not shape_feature.same_shape(y, o) - - @pytest.mark.parametrize( - "y_dim_0", - [2, pytest.param(None, marks=pytest.mark.xfail(reason="Not implemented"))], - ) - def test_vector_dim(self, y_dim_0): - x = at.tensor(dtype="floatX", shape=(2, None)) - y = at.tensor(dtype="floatX", shape=(y_dim_0, None)) - o = x + y - fgraph = FunctionGraph([x, y], [o], clone=False) - shape_feature = ShapeFeature() - fgraph.attach_feature(shape_feature) - assert shape_feature.same_shape(x, o, 0, 0) - assert not shape_feature.same_shape(x, o, 1, 1) - - def test_vector_dim_err(self): - x = vector() - y = vector() - o = x + y - fgraph = FunctionGraph([x, y], [o], clone=False) - shape_feature = ShapeFeature() - fgraph.attach_feature(shape_feature) - with pytest.raises(IndexError): - shape_feature.same_shape(x, o, 1, 0) - with pytest.raises(IndexError): - shape_feature.same_shape(x, o, 0, 1) - - -@pytest.mark.parametrize( - "shape", - [lscalar(), iscalar()], -) -def test_local_Shape_of_SpecifyShape(shape): - x = vector() - s = specify_shape(x, shape).shape - - fgraph = FunctionGraph(outputs=[s], clone=False) - _ = optimize_graph(fgraph, clone=False) - - assert x not in fgraph.variables - assert shape in fgraph.variables - - -@pytest.mark.parametrize( - "s1", - [lscalar(), iscalar()], -) -def test_local_Shape_of_SpecifyShape_partial(s1): - x = matrix() - s = specify_shape(x, (s1, None)).shape - - fgraph = FunctionGraph(outputs=[s], clone=False) - assert any(isinstance(apply.op, SpecifyShape) for apply in fgraph.apply_nodes) - - _ = optimize_graph(fgraph, clone=False) - - assert x in fgraph.variables - assert s1 in fgraph.variables - assert not any(isinstance(apply.op, SpecifyShape) for apply in fgraph.apply_nodes) - - -def test_local_Shape_i_of_broadcastable(): - x = tensor(np.float64, [False, True]) - s = Shape_i(1)(x) - - fgraph = FunctionGraph(outputs=[s], clone=False) - _ = optimize_graph(fgraph, clone=False) - - assert x not in fgraph.variables - assert fgraph.outputs[0].data == 1 - - # A test for a non-`TensorType` - class MyType(Type): - ndim = 1 - - def filter(self, *args, **kwargs): - raise NotImplementedError() - - def __eq__(self, other): - return isinstance(other, MyType) and other.thingy == self.thingy - - class MyVariable(Variable): - pass - - x = MyVariable(MyType(), None, None) - s = Shape_i(0)(x) - fgraph = FunctionGraph(outputs=[s], clone=False) - _ = optimize_graph(fgraph, clone=False) - - assert fgraph.outputs[0] == s - - -def test_assert_op_gradient(): - x = vector("x") - assert_op = Assert() - cost = at_sum(assert_op(x, x.size < 2)) - grad = aesara.gradient.grad(cost, x) - func = function([x], grad) - - x_val = np.ones(shape=(1,), dtype=config.floatX) - assert func(x_val) == 1 - - -def test_local_merge_alloc(): - # Add this opt to the default mode, - # otherwise, FAST_COMPILE fails. - default_mode = get_default_mode() - opt_mode = default_mode.including("local_merge_alloc") - - x = iscalar("x") - y = iscalar("y") - y2 = iscalar("y2") - z = iscalar("z") - w = iscalar("w") - m = fscalar("m") - # case 1 - # Alloc(Alloc(m, x, 1, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w) - output = at.alloc(at.alloc(m, 1, y, 1, 1), x, y, z, w) - f = function([m, x, y, z, w], output, mode=opt_mode) - topo = f.maker.fgraph.toposort() - assert len(topo) == 1 - assert isinstance(topo[0].op, Alloc) - o = f(0.0, 1, 2, 3, 4) - assert o.shape == (1, 2, 3, 4) - - # case 2 - # Alloc(Alloc(m, y, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w) - output = at.alloc(at.alloc(m, y, 1, 1), x, y, z, w) - f = function([m, x, y, z, w], output, mode=opt_mode) - topo = f.maker.fgraph.toposort() - assert len(topo) == 1 - assert isinstance(topo[0].op, Alloc) - o = f(0.0, 1, 2, 3, 4) - assert o.shape == (1, 2, 3, 4) - - # case 3 - # Alloc(Alloc(m, y1, 1, 1), x, y2, z, w) -> - # Alloc(m, x, assert(y1, y1==y2), z, w) - output = at.alloc(at.alloc(m, y, 1, 1), x, y2, z, w) - f = function([m, x, y, y2, z, w], output, mode=opt_mode) - topo = f.maker.fgraph.toposort() - assert len(topo) == 3 - assert isinstance(topo[-2].op, Assert) - assert isinstance(topo[-1].op, Alloc) - o = f(0.0, 1, 2, 2, 3, 4) - assert o.shape == (1, 2, 3, 4) - with pytest.raises((AssertionError, ValueError)): - f(0.0, 1, 2, 5, 3, 4) - - -def test_local_useless_alloc(): - - useless_alloc = out2in(local_useless_alloc) - merge_alloc = out2in(local_merge_alloc) - - x = iscalar("x") - y = iscalar("y") - y2 = iscalar("y2") - z = iscalar("z") - w = iscalar("w") - m = fscalar("m") - - # case 1 - # Alloc(Alloc(m, x, 1, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w) - output = at.alloc(at.alloc(m, 1, y, 1, 1), x, y, z, w) - g = FunctionGraph([m, x, y, z, w], [output]) - - useless_alloc.optimize(g) - merge_alloc.optimize(g) - useless_alloc.optimize(g) - - topo = g.toposort() - assert len(topo) == 1 - assert isinstance(topo[0].op, Alloc) - - # case 2 - # Alloc(Alloc(m, y, 1, 1), x, y, z, w) -> Alloc(m, x, y, z, w) - output = at.alloc(at.alloc(m, y, 1, 1), x, y, z, w) - g = FunctionGraph([m, x, y, z, w], [output]) - - useless_alloc.optimize(g) - merge_alloc.optimize(g) - useless_alloc.optimize(g) - - topo = g.toposort() - assert len(topo) == 1 - assert isinstance(topo[0].op, Alloc) - - # case 3 - # Alloc(Alloc(m, y1, 1, 1), x, y2, z, w) -> - # Alloc(m, x, assert(y1, y1==y2), z, w) - output = at.alloc(at.alloc(m, y, 1, 1), x, y2, z, w) - g = FunctionGraph([m, x, y, y2, z, w], [output]) - - useless_alloc.optimize(g) - merge_alloc.optimize(g) - useless_alloc.optimize(g) - - topo = g.toposort() - assert len(topo) == 3 - assert isinstance(topo[-2].op, Assert) - assert isinstance(topo[-1].op, Alloc) - - -@pytest.mark.parametrize("return_index", [False]) -@pytest.mark.parametrize("return_counts", [False]) -@pytest.mark.parametrize("return_inverse", [False]) -def test_local_Unique_scalar(return_index, return_counts, return_inverse): - x = dscalar() - y = unique( - x, - return_index=return_index, - return_counts=return_counts, - return_inverse=return_inverse, - axis=None, - ) - - y_fg = FunctionGraph(outputs=[y], copy_inputs=False) - y_opt_fg = optimize_graph( - y_fg, clone=False, include=["canonicalize", "local_Unique_scalar"] - ) - y_opt = y_opt_fg.outputs[0] - y_opt_start = y_opt - - assert isinstance(y_opt_start.owner.op, DimShuffle) - assert y_opt_start.owner.inputs[0] == x - - default_mode = get_default_mode() - opt_mode = default_mode.excluding("local_Unique_scalar") - y_fn = function([x], [y, y_opt], mode=opt_mode) - - x_val = np.array(-10.0, dtype=np.float64) - y_exp_val, y_val = y_fn(x_val) - assert np.array_equal(y_exp_val, y_val) - - -@pytest.mark.parametrize( - "x_val, axis, new_shape", - [ - (np.array(-10, dtype=np.int64), None, ()), - (np.array(-10, dtype=np.int64), None, (2, 3)), - (np.array([[-10, -3], [-10, 2], [-10, 2]], dtype=np.int64), None, (2, 3, 2)), - ], -) -@pytest.mark.parametrize("return_index", [False]) -@pytest.mark.parametrize("return_counts", [False]) -@pytest.mark.parametrize("return_inverse", [False]) -def test_local_Unique_Alloc_lift( - x_val, axis, new_shape, return_index, return_counts, return_inverse -): - x = as_tensor_variable(x_val).type() - y = unique( - alloc(x, *new_shape), - return_index=return_index, - return_counts=return_counts, - return_inverse=return_inverse, - axis=axis, - ) - - if isinstance(y, list): - y, *_ = y - - # This approach allows us to directly confirm that `x` is in the result. - y_fg = FunctionGraph(outputs=[y], copy_inputs=False) - y_opt_fg = optimize_graph( - y_fg, - clone=False, - include=["canonicalize", "local_Unique_Alloc_lift"], - exclude=["local_Unique_scalar"], - ) - y_opt = y_opt_fg.outputs[0] - y_opt_start = y_opt - - assert isinstance(y_opt_start.owner.op, Unique) - assert y_opt_start.owner.inputs[0] == x - assert not any(isinstance(node.op, Alloc) for node in y_opt_fg.apply_nodes) - - default_mode = get_default_mode() - # The optimization has already been applied to `y_opt`, so we can--and - # should--exclude it from the compilation of both our reference, `y`, and - # the optimized result, `y_opt`. - # The remaining exclusions simply allow us to perform the check below that - # makes sure the original `Alloc` is present in our reference (sub)graph. - opt_mode = default_mode.excluding( - "local_useless_alloc", "local_alloc_sink_dimshuffle", "local_Unique_Alloc_lift" - ) - y_fn = function([x], [y, y_opt], mode=opt_mode) - # Make sure that the original `Alloc` is used to compute the reference `y` - # result - assert any(isinstance(node.op, Alloc) for node in y_fn.maker.fgraph.apply_nodes) - - y_exp_val, y_val = y_fn(x_val) - assert np.array_equal(y_exp_val, y_val) - - -@pytest.mark.parametrize( - "x_val, axis, new_shape", - [ - (np.array(-10, dtype=np.int64), None, (2, 3)), - (np.array([[-10, -3], [-10, 2], [-10, 2]], dtype=np.int64), None, (2, 3, 2)), - ], -) -@pytest.mark.parametrize("return_index", [False]) -@pytest.mark.parametrize("return_counts", [False]) -@pytest.mark.parametrize("return_inverse", [False]) -def test_local_Unique_BroadcastTo( - x_val, axis, new_shape, return_index, return_counts, return_inverse -): - x = as_tensor_variable(x_val).type() - y = unique( - BroadcastTo()(x, tuple(new_shape)), - return_index=return_index, - return_counts=return_counts, - return_inverse=return_inverse, - axis=axis, - ) - - if isinstance(y, list): - y, *_ = y - - # This approach allows us to directly confirm that `x` is in the result. - y_fg = FunctionGraph(outputs=[y], copy_inputs=False) - y_opt_fg = optimize_graph( - y_fg, - clone=False, - include=["canonicalize", "local_Unique_BroadcastTo_lift"], - exclude=["local_Unique_scalar"], - ) - y_opt = y_opt_fg.outputs[0] - y_opt_start = y_opt - - assert isinstance(y_opt_start.owner.op, Unique) - assert y_opt_start.owner.inputs[0] == x - assert not any(isinstance(node.op, BroadcastTo) for node in y_opt_fg.apply_nodes) - - default_mode = get_default_mode() - # The optimization has already been applied to `y_opt`, so we can--and - # should--exclude it from the compilation of both our reference, `y`, and - # the optimized result, `y_opt`. - opt_mode = default_mode.excluding("local_Unique_BroadcastTo_lift") - y_fn = function([x], [y, y_opt], mode=opt_mode) - # Make sure that the original `BroadcastTo` is used to compute the - # reference `y` result - assert any( - isinstance(node.op, BroadcastTo) for node in y_fn.maker.fgraph.apply_nodes - ) - - y_exp_val, y_val = y_fn(x_val) - assert np.array_equal(y_exp_val, y_val) - - -@pytest.mark.parametrize( - "x_val, unique_axis, repeats, repeat_axis", - [ - (np.array([[-10, -3], [-10, 2]], dtype=np.int64), None, (1, 2), 0), - ], -) -@pytest.mark.parametrize("return_index", [False]) -@pytest.mark.parametrize("return_counts", [False]) -@pytest.mark.parametrize("return_inverse", [False]) -def test_local_Unique_Repeat( - x_val, - unique_axis, - repeats, - repeat_axis, - return_index, - return_counts, - return_inverse, -): - x = as_tensor_variable(x_val).type() - y = unique( - repeat(x, tuple(repeats), axis=repeat_axis), - return_index=return_index, - return_counts=return_counts, - return_inverse=return_inverse, - axis=unique_axis, - ) - - if isinstance(y, list): - y, *_ = y - - # This approach allows us to directly confirm that `x` is in the result. - y_fg = FunctionGraph(outputs=[y], copy_inputs=False) - y_opt_fg = optimize_graph( - y_fg, - clone=False, - include=["canonicalize", "local_Unique_Repeat_lift"], - exclude=["local_Unique_scalar"], - ) - y_opt = y_opt_fg.outputs[0] - y_opt_start = y_opt - - assert isinstance(y_opt_start.owner.op, Unique) - assert y_opt_start.owner.inputs[0] == x - assert not any(isinstance(node.op, Repeat) for node in y_opt_fg.apply_nodes) - - default_mode = get_default_mode() - # The optimization has already been applied to `y_opt`, so we can--and - # should--exclude it from the compilation of both our reference, `y`, and - # the optimized result, `y_opt`. - opt_mode = default_mode.excluding("local_Unique_Repeat_lift") - y_fn = function([x], [y, y_opt], mode=opt_mode) - # Make sure that the original `BroadcastTo` is used to compute the - # reference `y` result - assert any(isinstance(node.op, Repeat) for node in y_fn.maker.fgraph.apply_nodes) - - y_exp_val, y_val = y_fn(x_val) - assert np.array_equal(y_exp_val, y_val) - - -@pytest.mark.parametrize( - "x_val, unique_axis, new_shape", - [ - (np.array(-10, dtype=np.int64), None, ()), - (np.array(-10, dtype=np.int64), None, (2, 3)), - (np.array([[-10, -3], [-10, 2], [-10, 2]], dtype=np.int64), None, (2, 3, 2)), - ], -) -@pytest.mark.parametrize("return_index", [False]) -@pytest.mark.parametrize("return_counts", [False]) -@pytest.mark.parametrize("return_inverse", [False]) -def test_local_Unique_second( - x_val, unique_axis, new_shape, return_index, return_counts, return_inverse -): - x = as_tensor_variable(x_val).type() - a = np.zeros(tuple(new_shape), dtype=x.dtype) - y = unique( - second(a, x), - return_index=return_index, - return_counts=return_counts, - return_inverse=return_inverse, - axis=unique_axis, - ) - - if isinstance(y, list): - y, *_ = y - - # This approach allows us to directly confirm that `x` is in the result. - y_fg = FunctionGraph(outputs=[y], copy_inputs=False) - y_opt_fg = optimize_graph( - y_fg, - clone=False, - include=["canonicalize", "local_Unique_second_lift"], - exclude=["local_Unique_scalar", "topo_constant_folding"], - ) - y_opt = y_opt_fg.outputs[0] - y_opt_start = y_opt - - assert isinstance(y_opt_start.owner.op, Unique) - - y_opt_start = y_opt_start.owner.inputs[0] - - if y_opt_start.owner and isinstance(y_opt_start.owner.op, DimShuffle): - y_opt_start = y_opt_start.owner.inputs[0] - - assert y_opt_start == x - assert not any( - isinstance(node.op.scalar_op, aes.Second) - for node in y_opt_fg.apply_nodes - if isinstance(node.op, Elemwise) - ) - - # The optimization has already been applied to `y_opt`, so we can--and - # should--exclude it from the compilation of both our reference, `y`, and - # the optimized result, `y_opt`. - y_fn = function([x], [y, y_opt], mode=Mode(optimizer=OPT_NONE)) - - # Make sure that the original `BroadcastTo` is used to compute the - # reference `y` result - assert any( - isinstance(node.op.scalar_op, aes.Second) - for node in y_fn.maker.fgraph.apply_nodes - if isinstance(node.op, Elemwise) - ) - - y_exp_val, y_val = y_fn(x_val) - assert np.array_equal(y_exp_val, y_val) - - -def test_local_merge_consecutive_specify_shape(): - x = matrix() - s = at.as_tensor([iscalar(), iscalar()]) - y = specify_shape(specify_shape(x, s), s) - - y_fg = FunctionGraph(outputs=[y], copy_inputs=False) - y_opt_fg = optimize_graph( - y_fg, - clone=False, - include=["canonicalize", "local_merge_consecutive_specify_shape"], - ) - y_opt = y_opt_fg.outputs[0] - - assert isinstance(y_opt.owner.op, SpecifyShape) - assert y_opt.owner.inputs[0] == x - - -def test_local_merge_consecutive_specify_shape2(): - x = tensor3() - s1, s2, s3, s4 = iscalars("s1", "s2", "s3", "s4") - y = specify_shape(specify_shape(x, [s1, s2, None]), [None, s3, s4]) - - y_fg = FunctionGraph(outputs=[y], copy_inputs=False) - y_opt_fg = optimize_graph( - y_fg, - clone=False, - include=["canonicalize", "local_merge_consecutive_specify_shape"], - ) - y_opt = y_opt_fg.outputs[0] - - assert isinstance(y_opt.owner.op, SpecifyShape) - assert tuple(y_opt.owner.inputs) == (x, s1, s3, s4) - - -def test_printing(): - a, b = scalars("ab") - mv = MakeVector(config.floatX) - v = mv(a, b) - assert pprint(v) == "[a, b]" - - -def test_local_remove_scalar_BroadcastTo(): - x = dscalar() - y = BroadcastTo()(x, ()) - - assert isinstance(y.owner.op, BroadcastTo) - - res = optimize_graph( - y, clone=False, include=["canonicalize", "local_remove_scalar_BroadcastTo"] - ) - - assert res is x - - -def test_local_useless_dimshuffle_makevector(): - a = scalar() - x = MakeVector(config.floatX)(a) - y = x.dimshuffle(()) - - y_fg = FunctionGraph(outputs=[y], copy_inputs=False) - - y_opt_fg = optimize_graph( - y_fg, - clone=False, - include=["canonicalize", "local_useless_dimshuffle_makevector"], - ) - - assert y_opt_fg.outputs[0] == a - - -def test_Shape_i_canonicalize(): - """Make sure the canonicalizations work together to produce the correct graphs for shapes in a single dimension. - - In other words, ``shape(x)[i]`` should result in a simple ``Shape_i(0)(x)`` - and nothing else. The rewrites `local_shape_to_shape_i`, - `local_subtensor_remove_broadcastable_index`, and - `local_useless_dimshuffle_makevector` need to work together to accomplish - this, and we confirm that here. - """ - x = vector() - y = shape(x)[0] - - y_fg = FunctionGraph(outputs=[y], copy_inputs=False, features=[ShapeFeature()]) - - y_opt_fg = optimize_graph( - y_fg, - clone=False, - include=[ - "canonicalize", - ], - ) - - y_opt = y_opt_fg.outputs[0] - - assert isinstance(y_opt.owner.op, Shape_i) - assert y_opt.owner.op.i == 0 - assert y_opt.owner.inputs[0] == x - - -@pytest.mark.parametrize( - "expr, x_shape, y_shape", - [ - pytest.param( - lambda x, y: at.mul(y, at.alloc(1, x)), - (), - (), - marks=pytest.mark.xfail(reason="Not implemented"), - ), - (lambda x, y: at.mul(at.alloc(x, 15, 1), y), (15, 1), (15, 1)), - (lambda x, y: at.mul(at.alloc(x, 15, 2), y), (15, 2), (15, 2)), - (lambda x, y: at.mul(at.alloc(x, 15, 1), at.alloc(y, 15, 1)), (15, 1), (15, 1)), - (lambda x, y: at.mul(at.alloc(x, 15, 2), at.alloc(y, 15, 2)), (15, 2), (15, 2)), - (lambda x, y: at.mul(at.alloc(x, 15, 2).dimshuffle(1, 0), y), (15, 2), (2, 15)), - (lambda x, y: at.mul(at.alloc(x, 1, 15, 2), y), (15, 2), (15, 2)), - ( - lambda x, y: at.mul(at.alloc(x, 1, 15, 2).dimshuffle(0, 2, 1), y), - (15, 2), - (2, 15), - ), - ], -) -def test_local_elemwise_alloc(expr, x_shape, y_shape): - x = at.tensor("int64", (False,) * len(x_shape)) - y = at.tensor("int64", (False,) * len(y_shape)) - z = expr(x, y) - - z_opt = aesara.function( - [x, y], - z, - mode=get_default_mode().including("local_elemwise_alloc"), - on_unused_input="ignore", - ) - - assert not any(isinstance(node.op, Alloc) for node in z_opt.maker.fgraph.toposort()) - - z_no_opt = aesara.function( - [x, y], - z, - mode=get_default_mode().excluding("local_elemwise_alloc"), - on_unused_input="ignore", - ) - - x_val = np.arange(np.prod(x_shape), dtype=np.int64).reshape(x_shape) - y_val = np.arange(np.prod(y_shape), dtype=np.int64).reshape(y_shape) - - res = z_opt(x_val, y_val) - exp_res = z_no_opt(x_val, y_val) - assert np.array_equal(res, exp_res) - - -def test_local_elemwise_alloc_single_input(): - # Test that rewrite is not triggered when there is only one Alloc in an Elemwise - x = at.matrix("x") - z = at.exp(at.alloc(x, 15, 1)) - - z_fg = FunctionGraph(outputs=[z], copy_inputs=False, features=[ShapeFeature()]) - - z_opt_fg = optimize_graph(z_fg, clone=False, include=["local_elemwise_alloc"]) - assert any(isinstance(node.op, Alloc) for node in z_opt_fg.apply_nodes) diff --git a/tests/tensor/test_blas.py b/tests/tensor/test_blas.py index 3cb011c478..96de2667ab 100644 --- a/tests/tensor/test_blas.py +++ b/tests/tensor/test_blas.py @@ -17,7 +17,7 @@ from aesara.configdefaults import config from aesara.gradient import grad from aesara.graph.fg import FunctionGraph -from aesara.graph.opt import in2out +from aesara.graph.rewriting.basic import in2out from aesara.graph.utils import InconsistencyError from aesara.misc.safe_asarray import _asarray from aesara.tensor import inplace diff --git a/tests/tensor/test_elemwise.py b/tests/tensor/test_elemwise.py index d4638e5b74..6bd514f277 100644 --- a/tests/tensor/test_elemwise.py +++ b/tests/tensor/test_elemwise.py @@ -17,10 +17,11 @@ from aesara.link.c.basic import CLinker, OpWiseCLinker from aesara.tensor import as_tensor_variable from aesara.tensor.basic import second -from aesara.tensor.basic_opt import ShapeError from aesara.tensor.elemwise import CAReduce, CAReduceDtype, DimShuffle, Elemwise +from aesara.tensor.exceptions import ShapeError from aesara.tensor.math import all as at_all from aesara.tensor.math import any as at_any +from aesara.tensor.math import exp from aesara.tensor.type import ( TensorType, bmatrix, @@ -854,6 +855,40 @@ def test_shape_types(self): assert all(isinstance(v.type, TensorType) for v in out_shape) + def test_static_shape_unary(self): + x = tensor("float64", shape=(None, 0, 1, 5)) + exp(x).type.shape == (None, 0, 1, 5) + + def test_static_shape_binary(self): + x = tensor("float64", shape=(None, 5)) + y = tensor("float64", shape=(None, 5)) + assert (x + y).type.shape == (None, 5) + + x = tensor("float64", shape=(None, 5)) + y = tensor("float64", shape=(10, 5)) + assert (x + y).type.shape == (10, 5) + + x = tensor("float64", shape=(1, 5)) + y = tensor("float64", shape=(10, 5)) + assert (x + y).type.shape == (10, 5) + + x = tensor("float64", shape=(None, 1)) + y = tensor("float64", shape=(1, 1)) + assert (x + y).type.shape == (None, 1) + + x = tensor("float64", shape=(0, 0, 0)) + y = tensor("float64", shape=(0, 1, None)) + assert (x + y).type.shape == (0, 0, 0) + + def test_invalid_static_shape(self): + x = tensor("float64", shape=(2,)) + y = tensor("float64", shape=(3,)) + with pytest.raises( + ValueError, + match=re.escape("Incompatible Elemwise input shapes [(2,), (3,)]"), + ): + x + y + def test_not_implemented_elemwise_grad(): # Regression test for unimplemented gradient in an Elemwise Op. diff --git a/tests/tensor/test_extra_ops.py b/tests/tensor/test_extra_ops.py index 359b0571bd..22759226bf 100644 --- a/tests/tensor/test_extra_ops.py +++ b/tests/tensor/test_extra_ops.py @@ -8,8 +8,8 @@ from aesara import tensor as at from aesara.compile.mode import Mode from aesara.configdefaults import config -from aesara.graph.basic import applys_between -from aesara.graph.optdb import OptimizationQuery +from aesara.graph.basic import Constant, applys_between +from aesara.graph.rewriting.db import RewriteDatabaseQuery from aesara.raise_op import Assert from aesara.tensor.elemwise import DimShuffle from aesara.tensor.extra_ops import ( @@ -1143,6 +1143,35 @@ def shape_tuple(x, use_bcast=True): assert isinstance(b_at[-1].owner.op, Assert) +def test_broadcast_shape_constants(): + """Make sure `broadcast_shape` uses constants when it can.""" + x1_shp_at = iscalar("x1") + y2_shp_at = iscalar("y2") + b_at = broadcast_shape((x1_shp_at, 2), (3, y2_shp_at), arrays_are_shapes=True) + assert len(b_at) == 2 + assert isinstance(b_at[0].owner.op, Assert) + assert b_at[0].owner.inputs[0].value.item() == 3 + assert isinstance(b_at[1].owner.op, Assert) + assert b_at[1].owner.inputs[0].value.item() == 2 + + b_at = broadcast_shape((1, 2), (3, 2), arrays_are_shapes=True) + assert len(b_at) == 2 + assert all(isinstance(x, Constant) for x in b_at) + assert b_at[0].value.item() == 3 + assert b_at[1].value.item() == 2 + + b_at = broadcast_shape((1,), (1, 1), arrays_are_shapes=True) + assert len(b_at) == 2 + assert all(isinstance(x, Constant) for x in b_at) + assert b_at[0].value.item() == 1 + assert b_at[1].value.item() == 1 + + b_at = broadcast_shape((1,), (1,), arrays_are_shapes=True) + assert len(b_at) == 1 + assert all(isinstance(x, Constant) for x in b_at) + assert b_at[0].value.item() == 1 + + @pytest.mark.parametrize( ("s1_vals", "s2_vals", "exp_res"), [ @@ -1169,6 +1198,31 @@ def test_broadcast_shape_symbolic(s1_vals, s2_vals, exp_res): assert tuple(res.eval(eval_point)) == exp_res +def test_broadcast_shape_symbolic_one_symbolic(): + """Test case for a constant non-broadcast shape and a symbolic shape.""" + one_at = at.as_tensor(1, dtype=np.int64) + three_at = at.as_tensor(3, dtype=np.int64) + int_div = one_at / one_at + + assert int_div.owner.op == at.true_div + + index_shapes = [ + (one_at, one_at, three_at), + (one_at, int_div, one_at), + (one_at, one_at, int_div), + ] + + res_shape = broadcast_shape(*index_shapes, arrays_are_shapes=True) + + from aesara.graph.rewriting.utils import rewrite_graph + + res_shape = rewrite_graph(res_shape) + + assert res_shape[0].data == 1 + assert res_shape[1].data == 1 + assert res_shape[2].data == 3 + + class TestBroadcastTo(utt.InferShapeTester): def setup_method(self): super().setup_method() @@ -1188,24 +1242,70 @@ def test_avoid_useless_subtensors(self): assert y.owner.inputs[1].owner is None assert y.owner.inputs[2].owner is None - @config.change_flags(compute_test_value="raise") - def test_perform(self): - a = scalar() - a.tag.test_value = 5 + @pytest.mark.parametrize("linker", ["cvm", "py"]) + def test_perform(self, linker): + a = aesara.shared(5) s_1 = iscalar("s_1") - s_1.tag.test_value = 4 shape = (s_1, 1) bcast_res = broadcast_to(a, shape) - assert bcast_res.broadcastable == (False, True) + bcast_fn = aesara.function( + [s_1], bcast_res, mode=Mode(optimizer=None, linker=linker) + ) + bcast_fn.vm.allow_gc = False + + bcast_at = bcast_fn(4) bcast_np = np.broadcast_to(5, (4, 1)) - bcast_at = bcast_res.get_test_value() assert np.array_equal(bcast_at, bcast_np) - assert np.shares_memory(bcast_at, a.get_test_value()) + + bcast_var = bcast_fn.maker.fgraph.outputs[0].owner.inputs[0] + bcast_in = bcast_fn.vm.storage_map[a] + bcast_out = bcast_fn.vm.storage_map[bcast_var] + + if linker != "py": + assert np.shares_memory(bcast_out[0], bcast_in[0]) + + @pytest.mark.skipif( + not config.cxx, reason="G++ not available, so we need to skip this test." + ) + def test_memory_leak(self): + import gc + import tracemalloc + + from aesara.link.c.cvm import CVM + + n = 100_000 + x = aesara.shared(np.ones(n, dtype=np.float64)) + y = broadcast_to(x, (5, n)) + + f = aesara.function([], y, mode=Mode(optimizer=None, linker="cvm")) + assert isinstance(f.vm, CVM) + + assert len(f.maker.fgraph.apply_nodes) == 2 + assert any( + isinstance(node.op, BroadcastTo) for node in f.maker.fgraph.apply_nodes + ) + + tracemalloc.start() + + blocks_last = None + block_diffs = [] + for i in range(1, 50): + x.set_value(np.ones(n)) + _ = f() + _ = gc.collect() + blocks_i, _ = tracemalloc.get_traced_memory() + if blocks_last is not None: + blocks_diff = (blocks_i - blocks_last) // 10**3 + block_diffs.append(blocks_diff) + blocks_last = blocks_i + + tracemalloc.stop() + assert np.allclose(np.mean(block_diffs), 0) @pytest.mark.parametrize( "fn,input_dims", @@ -1256,7 +1356,7 @@ def test_inplace(self): q = b[np.r_[0, 1, 3]] e = at.set_subtensor(q, np.r_[0, 0, 0]) - opts = OptimizationQuery(include=["inplace"]) + opts = RewriteDatabaseQuery(include=["inplace"]) py_mode = Mode("py", opts) e_fn = function([d], e, mode=py_mode) diff --git a/tests/tensor/test_math.py b/tests/tensor/test_math.py index ba3e17f7aa..71114a03dd 100644 --- a/tests/tensor/test_math.py +++ b/tests/tensor/test_math.py @@ -1,7 +1,6 @@ import builtins import operator import pickle -import warnings from copy import copy from functools import reduce from itertools import product @@ -144,6 +143,7 @@ from tests import unittest_tools as utt from tests.link.test_link import make_function from tests.tensor.utils import ( + ALL_DTYPES, _bad_build_broadcast_binary_normal, _bad_runtime_broadcast_binary_normal, _bad_runtime_reciprocal, @@ -177,7 +177,6 @@ copymod, div_grad_rtol, eval_outputs, - get_numeric_types, ignore_isfinite_mode, inplace_func, integers, @@ -2225,139 +2224,137 @@ class TestArithmeticCast: """ - def test_arithmetic_cast(self): - dtypes = get_numeric_types(with_complex=True) + @pytest.mark.parametrize( + "op", + [ + operator.add, + operator.sub, + operator.mul, + operator.truediv, + operator.floordiv, + ], + ) + @pytest.mark.parametrize("a_type", ALL_DTYPES) + @pytest.mark.parametrize("b_type", ALL_DTYPES) + @pytest.mark.parametrize( + "combo", + [ + ("scalar", "scalar"), + ("array", "array"), + ("scalar", "array"), + ("array", "scalar"), + ("i_scalar", "i_scalar"), + ], + ) + def test_arithmetic_cast(self, op, a_type, b_type, combo): + if op is operator.floordiv and ( + a_type.startswith("complex") or b_type.startswith("complex") + ): + pytest.skip("Not supported by NumPy") # Here: # scalar == scalar stored as a 0d array # array == 1d array # i_scalar == scalar type used internally by Aesara - def Aesara_scalar(dtype): + def aesara_scalar(dtype): return scalar(dtype=str(dtype)) def numpy_scalar(dtype): return np.array(1, dtype=dtype) - def Aesara_array(dtype): + def aesara_array(dtype): return vector(dtype=str(dtype)) def numpy_array(dtype): return np.array([1], dtype=dtype) - def Aesara_i_scalar(dtype): + def aesara_i_scalar(dtype): return aes.ScalarType(str(dtype))() def numpy_i_scalar(dtype): return numpy_scalar(dtype) - with warnings.catch_warnings(): - # Avoid deprecation warning during tests. - warnings.simplefilter("ignore", category=DeprecationWarning) - for cfg in ("numpy+floatX",): # Used to test 'numpy' as well. - with config.change_flags(cast_policy=cfg): - for op in ( - operator.add, - operator.sub, - operator.mul, - operator.truediv, - operator.floordiv, - ): - for a_type in dtypes: - for b_type in dtypes: - - # We will test all meaningful combinations of - # scalar and array operations. - for combo in ( - ("scalar", "scalar"), - ("array", "array"), - ("scalar", "array"), - ("array", "scalar"), - ("i_scalar", "i_scalar"), - ): - - Aesara_args = list( - map(eval, [f"Aesara_{c}" for c in combo]) - ) - numpy_args = list( - map(eval, [f"numpy_{c}" for c in combo]) - ) - Aesara_dtype = op( - Aesara_args[0](a_type), - Aesara_args[1](b_type), - ).type.dtype - - # For numpy we have a problem: - # http://projects.scipy.org/numpy/ticket/1827 - # As a result we only consider the highest data - # type that numpy may return. - numpy_dtypes = [ - op( - numpy_args[0](a_type), numpy_args[1](b_type) - ).dtype, - op( - numpy_args[1](b_type), numpy_args[0](a_type) - ).dtype, - ] - numpy_dtype = aes.upcast( - *list(map(str, numpy_dtypes)) - ) - if numpy_dtype == Aesara_dtype: - # Same data type found, all is good! - continue - if ( - cfg == "numpy+floatX" - and config.floatX == "float32" - and a_type != "float64" - and b_type != "float64" - and numpy_dtype == "float64" - ): - # We should keep float32. - assert Aesara_dtype == "float32" - continue - if "array" in combo and "scalar" in combo: - # For mixed scalar / array operations, - # Aesara may differ from numpy as it does - # not try to prevent the scalar from - # upcasting the array. - array_type, scalar_type = ( - (a_type, b_type)[list(combo).index(arg)] - for arg in ("array", "scalar") - ) - up_type = aes.upcast(array_type, scalar_type) - if ( - # The two data types are different. - scalar_type != array_type - and - # The array type is not enough to hold - # the scalar type as well. - array_type != up_type - and - # Aesara upcasted the result array. - Aesara_dtype == up_type - and - # But Numpy kept its original type. - array_type == numpy_dtype - ): - # Then we accept this difference in - # behavior. - continue - - if ( - cfg == "numpy+floatX" - and a_type == "complex128" - and (b_type == "float32" or b_type == "float16") - and combo == ("scalar", "array") - and Aesara_dtype == "complex128" - and numpy_dtype == "complex64" - ): - # In numpy 1.6.x adding a complex128 with - # a float32 may result in a complex64. As - # of 1.9.2. this is still the case so it is - # probably by design - pytest.skip("Known issue with" "numpy see #761") - # In any other situation: something wrong is - # going on! - raise AssertionError() + with config.change_flags(cast_policy="numpy+floatX"): + # We will test all meaningful combinations of + # scalar and array operations. + aesara_args = list(map(eval, [f"aesara_{c}" for c in combo])) + numpy_args = list(map(eval, [f"numpy_{c}" for c in combo])) + aesara_arg_1 = aesara_args[0](a_type) + aesara_arg_2 = aesara_args[1](b_type) + aesara_dtype = op( + aesara_arg_1, + aesara_arg_2, + ).type.dtype + + # For numpy we have a problem: + # http://projects.scipy.org/numpy/ticket/1827 + # As a result we only consider the highest data + # type that numpy may return. + numpy_arg_1 = numpy_args[0](a_type) + numpy_arg_2 = numpy_args[1](b_type) + numpy_dtypes = [ + op(numpy_arg_1, numpy_arg_2).dtype, + op(numpy_arg_2, numpy_arg_1).dtype, + ] + numpy_dtype = aes.upcast(*list(map(str, numpy_dtypes))) + + if numpy_dtype == aesara_dtype: + # Same data type found, all is good! + return + + if ( + config.floatX == "float32" + and a_type != "float64" + and b_type != "float64" + and numpy_dtype == "float64" + ): + # We should keep float32. + assert aesara_dtype == "float32" + return + + if "array" in combo and "scalar" in combo: + # For mixed scalar / array operations, + # Aesara may differ from numpy as it does + # not try to prevent the scalar from + # upcasting the array. + array_type, scalar_type = ( + (a_type, b_type)[list(combo).index(arg)] + for arg in ("array", "scalar") + ) + up_type = aes.upcast(array_type, scalar_type) + if ( + # The two data types are different. + scalar_type != array_type + and + # The array type is not enough to hold + # the scalar type as well. + array_type != up_type + and + # Aesara upcasted the result array. + aesara_dtype == up_type + and + # But Numpy kept its original type. + array_type == numpy_dtype + ): + # Then we accept this difference in + # behavior. + return + + if ( + {a_type, b_type} == {"complex128", "float32"} + or {a_type, b_type} == {"complex128", "float16"} + and set(combo) == {"scalar", "array"} + and aesara_dtype == "complex128" + and numpy_dtype == "complex64" + ): + # In numpy 1.6.x adding a complex128 with + # a float32 may result in a complex64. As + # of 1.9.2. this is still the case so it is + # probably by design + pytest.skip("Known issue with" "numpy see #761") + # In any other situation: something wrong is + # going on! + raise AssertionError() def test_divmod(): diff --git a/tests/tensor/test_math_scipy.py b/tests/tensor/test_math_scipy.py index 9671e6c60d..b1647e77a0 100644 --- a/tests/tensor/test_math_scipy.py +++ b/tests/tensor/test_math_scipy.py @@ -53,6 +53,7 @@ def scipy_special_gammal(k, x): expected_erfc = scipy.special.erfc expected_erfinv = scipy.special.erfinv expected_erfcinv = scipy.special.erfcinv +expected_owenst = scipy.special.owens_t expected_gamma = scipy.special.gamma expected_gammaln = scipy.special.gammaln expected_psi = scipy.special.psi @@ -146,6 +147,55 @@ def scipy_special_gammal(k, x): mode=mode_no_scipy, ) +rng = np.random.default_rng(seed=utt.fetch_seed()) +_good_broadcast_binary_owenst = dict( + normal=( + random_ranged(-5, 5, (2, 3), rng=rng), + random_ranged(-5, 5, (2, 3), rng=rng), + ), + empty=(np.asarray([], dtype=config.floatX), np.asarray([], dtype=config.floatX)), + int=( + integers_ranged(-5, 5, (2, 3), rng=rng), + integers_ranged(-5, 5, (2, 3), rng=rng), + ), + uint8=( + integers_ranged(1, 6, (2, 3), rng=rng).astype("uint8"), + integers_ranged(1, 6, (2, 3), rng=rng).astype("uint8"), + ), + uint16=( + integers_ranged(1, 10, (2, 3), rng=rng).astype("uint16"), + integers_ranged(1, 10, (2, 3), rng=rng).astype("uint16"), + ), + uint64=( + integers_ranged(1, 10, (2, 3), rng=rng).astype("uint64"), + integers_ranged(1, 10, (2, 3), rng=rng).astype("uint64"), + ), +) + +_grad_broadcast_binary_owenst = dict( + normal=( + random_ranged(-5, 5, (2, 3), rng=rng), + random_ranged(-5, 5, (2, 3), rng=rng), + ) +) + +TestOwensTBroadcast = makeBroadcastTester( + op=at.owens_t, + expected=expected_owenst, + good=_good_broadcast_binary_owenst, + grad=_grad_broadcast_binary_owenst, + eps=2e-10, + mode=mode_no_scipy, +) +TestOwensTInplaceBroadcast = makeBroadcastTester( + op=inplace.owens_t_inplace, + expected=expected_owenst, + good=_good_broadcast_binary_owenst, + eps=2e-10, + mode=mode_no_scipy, + inplace=True, +) + rng = np.random.default_rng(seed=utt.fetch_seed()) _good_broadcast_unary_gammaln = dict( normal=(random_ranged(-1 + 1e-2, 10, (2, 3), rng=rng),), diff --git a/tests/tensor/test_merge.py b/tests/tensor/test_merge.py index 765ee5dc1c..0879d30e06 100644 --- a/tests/tensor/test_merge.py +++ b/tests/tensor/test_merge.py @@ -4,7 +4,7 @@ from aesara.graph.basic import Apply, Variable from aesara.graph.fg import FunctionGraph from aesara.graph.op import Op -from aesara.graph.opt import MergeOptimizer +from aesara.graph.rewriting.basic import MergeOptimizer from aesara.graph.type import Type @@ -72,7 +72,7 @@ def test_merge_with_weird_eq(): x = at.constant(np.asarray(1), name="x") y = at.constant(np.asarray(1), name="y") g = FunctionGraph([x, y], [x + y]) - MergeOptimizer().optimize(g) + MergeOptimizer().rewrite(g) assert len(g.apply_nodes) == 1 node = list(g.apply_nodes)[0] @@ -84,7 +84,7 @@ def test_merge_with_weird_eq(): x = at.constant(np.ones(5), name="x") y = at.constant(np.ones(5), name="y") g = FunctionGraph([x, y], [x + y]) - MergeOptimizer().optimize(g) + MergeOptimizer().rewrite(g) assert len(g.apply_nodes) == 1 node = list(g.apply_nodes)[0] diff --git a/tests/tensor/test_shape.py b/tests/tensor/test_shape.py index 153ea2ef49..1db5d510ec 100644 --- a/tests/tensor/test_shape.py +++ b/tests/tensor/test_shape.py @@ -11,8 +11,8 @@ from aesara.misc.safe_asarray import _asarray from aesara.tensor import as_tensor_variable, get_vector_length, row from aesara.tensor.basic import MakeVector, constant -from aesara.tensor.basic_opt import ShapeFeature from aesara.tensor.elemwise import DimShuffle, Elemwise +from aesara.tensor.rewriting.shape import ShapeFeature from aesara.tensor.shape import ( Reshape, Shape_i, diff --git a/tests/tensor/test_subtensor.py b/tests/tensor/test_subtensor.py index ccfee38005..bc57c97950 100644 --- a/tests/tensor/test_subtensor.py +++ b/tests/tensor/test_subtensor.py @@ -13,7 +13,7 @@ from aesara.compile.io import In from aesara.configdefaults import config from aesara.graph.op import get_test_value -from aesara.graph.opt_utils import is_same_graph +from aesara.graph.rewriting.utils import is_same_graph from aesara.printing import pprint from aesara.scalar.basic import as_scalar from aesara.tensor import get_vector_length @@ -351,7 +351,7 @@ def test_err_bounds(self): t = n[7] assert isinstance(t.owner.op, Subtensor) # Silence expected error messages - _logger = logging.getLogger("aesara.graph.opt") + _logger = logging.getLogger("aesara.graph.rewriting.basic") oldlevel = _logger.level _logger.setLevel(logging.CRITICAL) try: @@ -432,7 +432,7 @@ def test_err_bounds0(self): t = n[idx] assert isinstance(t.owner.op, Subtensor) # Silence expected warnings - _logger = logging.getLogger("aesara.graph.opt") + _logger = logging.getLogger("aesara.graph.rewriting.basic") oldlevel = _logger.level _logger.setLevel(logging.CRITICAL) try: @@ -2463,90 +2463,83 @@ def test_basic_shape(): assert get_test_value(res) == (2,) -@config.change_flags(compute_test_value="raise") -def test_indexed_result_shape(): - _test_idx = np.ix_(np.array([True, True]), np.array([True]), np.array([True, True])) - - test_shape = (5, 6, 7, 8) - test_array = np.arange(np.prod(test_shape)).reshape(test_shape) - - def idx_as_tensor(x): - if isinstance(x, (slice, type(None))): - return x - else: - return at.as_tensor(x) - - def bcast_shape_tuple(x): - if not hasattr(x, "shape"): - return x - return tuple( - s if not bcast else 1 for s, bcast in zip(tuple(x.shape), x.broadcastable) - ) - - def compare_index_shapes(test_array, test_idx): - res = indexed_result_shape( - at.as_tensor(test_array).shape, [idx_as_tensor(i) for i in test_idx] - ) - exp_res = test_array[test_idx].shape - assert np.array_equal(tuple(get_test_value(r) for r in res), exp_res) - - # Test shape-only version - res = indexed_result_shape( - at.as_tensor(test_array).shape, - [bcast_shape_tuple(idx_as_tensor(i)) for i in test_idx], - indices_are_shapes=True, - ) - exp_res = test_array[test_idx].shape - assert np.array_equal(tuple(get_test_value(r) for r in res), exp_res) - - # Simple basic indices - test_idx = (slice(None, None),) - compare_index_shapes(test_array, test_idx) - - # Advanced indices - test_idx = (2,) - compare_index_shapes(test_array, test_idx) - - test_idx = _test_idx[:1] - compare_index_shapes(test_array, test_idx) +def idx_as_tensor(x): + if isinstance(x, (slice, type(None))): + return x + else: + return at.as_tensor(x) - test_idx = _test_idx[:2] - compare_index_shapes(test_array, test_idx) - # A Mix of advanced and basic indices - test_idx = _test_idx[:2] + (slice(None, None),) - compare_index_shapes(test_array, test_idx) - - test_idx = (slice(None, None),) + _test_idx[1:] - compare_index_shapes(test_array, test_idx) - - test_idx = (slice(None, None), None) + _test_idx[1:2] - compare_index_shapes(test_array, test_idx) +def bcast_shape_tuple(x): + if not hasattr(x, "shape"): + return x + return tuple( + s if not bcast else 1 for s, bcast in zip(tuple(x.shape), x.broadcastable) + ) - test_idx = (np.array(1), slice(None, None), None) - compare_index_shapes(test_array, test_idx) - test_idx = (slice(None, None), None, np.array(1)) - compare_index_shapes(test_array, test_idx) +test_idx = np.ix_(np.array([True, True]), np.array([True]), np.array([True, True])) - test_idx = _test_idx[:1] + (slice(None, None),) + _test_idx[1:2] - compare_index_shapes(test_array, test_idx) - test_idx = ( - _test_idx[:1] + (slice(None, None),) + _test_idx[1:2] + (slice(None, None),) +@pytest.mark.parametrize( + "test_array, test_idx", + [ + (np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)), (slice(None, None),)), + (np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)), (2,)), + (np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)), test_idx[:1]), + (np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)), test_idx[:2]), + ( + np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)), + test_idx[:2] + (slice(None, None),), + ), + ( + np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)), + (slice(None, None),) + test_idx[:1], + ), + ( + np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)), + (slice(None, None), None) + test_idx[1:2], + ), + ( + np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)), + (np.array(1), slice(None, None), None), + ), + ( + np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)), + (slice(None, None), None, np.array(1)), + ), + ( + np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)), + test_idx[:1] + (slice(None, None),) + test_idx[1:2], + ), + ( + np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)), + test_idx[:1] + (slice(None, None),) + test_idx[1:2] + (slice(None, None),), + ), + ( + np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)), + test_idx[:1] + (None,) + test_idx[1:2], + ), + (np.arange(np.prod((5, 4))).reshape((5, 4)), ([1, 3, 2], slice(1, 3))), + (np.arange(np.prod((5, 4))).reshape((5, 4)), (slice(1, 3), [1, 3, 2])), + ], +) +@config.change_flags(compute_test_value="raise") +def test_indexed_result_shape(test_array, test_idx): + res = indexed_result_shape( + at.as_tensor(test_array).shape, [idx_as_tensor(i) for i in test_idx] ) - compare_index_shapes(test_array, test_idx) - - test_idx = _test_idx[:1] + (None,) + _test_idx[1:2] - compare_index_shapes(test_array, test_idx) - - test_shape = (5, 4) - test_array = np.arange(np.prod(test_shape)).reshape(test_shape) - test_idx = ([1, 3, 2], slice(1, 3)) - compare_index_shapes(test_array, test_idx) - - test_idx = (slice(1, 3), [1, 3, 2]) - compare_index_shapes(test_array, test_idx) + exp_res = test_array[test_idx].shape + assert np.array_equal(tuple(get_test_value(r) for r in res), exp_res) + + # Test shape-only version + res = indexed_result_shape( + at.as_tensor(test_array).shape, + [bcast_shape_tuple(idx_as_tensor(i)) for i in test_idx], + indices_are_shapes=True, + ) + exp_res = test_array[test_idx].shape + assert np.array_equal(tuple(get_test_value(r) for r in res), exp_res) def test_symbolic_slice(): @@ -2622,3 +2615,8 @@ def test_index_vars_to_types(): res = index_vars_to_types(iscalar) assert isinstance(res, scal.ScalarType) + + x = scal.constant(1, dtype=np.uint8) + assert isinstance(x.type, scal.ScalarType) + res = index_vars_to_types(x) + assert res == x.type diff --git a/tests/tensor/test_type.py b/tests/tensor/test_type.py index 2665fb7dc6..eed93e2527 100644 --- a/tests/tensor/test_type.py +++ b/tests/tensor/test_type.py @@ -10,9 +10,18 @@ from aesara.tensor.type import TensorType -def test_numpy_dtype(): - test_type = TensorType(np.int32, []) - assert test_type.dtype == "int32" +@pytest.mark.parametrize( + "dtype, exp_dtype", + [ + (np.int32, "int32"), + (np.dtype(np.int32), "int32"), + ("int32", "int32"), + ("floatX", config.floatX), + ], +) +def test_numpy_dtype(dtype, exp_dtype): + test_type = TensorType(dtype, []) + assert test_type.dtype == exp_dtype def test_in_same_class(): @@ -62,6 +71,18 @@ def test_convert_variable(): assert res is const_var +def test_convert_variable_mixed_specificity(): + type1 = TensorType(config.floatX, shape=(1, None, 3)) + type2 = TensorType(config.floatX, shape=(None, 5, 3)) + type3 = TensorType(config.floatX, shape=(1, 5, 3)) + + test_var1 = type1() + test_var2 = type2() + + assert type1.convert_variable(test_var2).type == type3 + assert type2.convert_variable(test_var1).type == type3 + + def test_filter_variable(): test_type = TensorType(config.floatX, []) diff --git a/tests/tensor/test_var.py b/tests/tensor/test_var.py index c67e70eaba..05127cb3d2 100644 --- a/tests/tensor/test_var.py +++ b/tests/tensor/test_var.py @@ -1,14 +1,17 @@ +from copy import copy + import numpy as np import pytest -from numpy.testing import assert_equal, assert_string_equal +from numpy.testing import assert_array_equal, assert_equal, assert_string_equal import aesara import tests.unittest_tools as utt +from aesara.compile.mode import get_default_mode from aesara.graph.basic import Constant, equal_computations from aesara.tensor import get_vector_length from aesara.tensor.basic import constant from aesara.tensor.elemwise import DimShuffle -from aesara.tensor.math import dot +from aesara.tensor.math import dot, eq from aesara.tensor.subtensor import AdvancedSubtensor, Subtensor from aesara.tensor.type import ( TensorType, @@ -18,16 +21,22 @@ dvector, iscalar, ivector, + matrices, matrix, + scalar, tensor3, ) -from aesara.tensor.type_other import MakeSlice +from aesara.tensor.type_other import MakeSlice, NoneConst from aesara.tensor.var import ( DenseTensorConstant, DenseTensorVariable, TensorConstant, TensorVariable, ) +from tests.tensor.utils import random + + +pytestmark = pytest.mark.filterwarnings("error") @pytest.mark.parametrize( @@ -212,6 +221,7 @@ def test_print_constant(): [ (tensor3(), (np.newaxis, slice(None), np.newaxis), ("x", 0, "x", 1, 2)), (cscalar(), (np.newaxis,), ("x",)), + (cscalar(), (NoneConst,), ("x",)), (matrix(), (np.newaxis,), ("x", 0, 1)), (matrix(), (np.newaxis, np.newaxis), ("x", "x", 0, 1)), (matrix(), (np.newaxis, slice(None)), ("x", 0, 1)), @@ -264,3 +274,125 @@ def test_dense_types(): x = constant(1) assert not isinstance(x, DenseTensorVariable) assert isinstance(x, DenseTensorConstant) + + +class TestTensorConstantSignature: + vals = [ + [np.nan, np.inf, 0, 1], + [np.nan, np.inf, -np.inf, 1], + [0, np.inf, -np.inf, 1], + [0, 3, -np.inf, 1], + [0, 3, np.inf, 1], + [np.nan, 3, 4, 1], + [0, 3, 4, 1], + np.nan, + np.inf, + -np.inf, + 0, + 1, + ] + + @pytest.mark.parametrize("val_1", vals) + @pytest.mark.parametrize("val_2", vals) + def test_nan_inf_constant_signature(self, val_1, val_2): + # Test that the signature of a constant tensor containing NaN and Inf + # values is correct. + # We verify that signatures of two rows i, j in the matrix above are + # equal if and only if i == j. + x = constant(val_1) + y = constant(val_2) + assert (x.signature() == y.signature()) == (val_1 is val_2) + + def test_nan_nan(self): + # Also test that nan !=0 and nan != nan. + x = scalar() + mode = get_default_mode() + if isinstance(mode, aesara.compile.debugmode.DebugMode): + # Disable the check preventing usage of NaN / Inf values. + # We first do a copy of the mode to avoid side effects on other tests. + mode = copy(mode) + mode.check_isfinite = False + f = aesara.function([x], eq(x, np.nan), mode=mode) + + assert f(0) == 0 + assert f(np.nan) == 0 + + def test_empty_hash(self): + x = constant(np.array([], dtype=np.int64)) + y = constant(np.array([], dtype=np.int64)) + + x_sig = x.signature() + y_sig = y.signature() + + assert hash(x_sig) == hash(y_sig) + + +class TestTensorInstanceMethods: + def setup_method(self): + self.vars = matrices("X", "Y") + self.vals = [ + m.astype(aesara.config.floatX) for m in [random(2, 2), random(2, 2)] + ] + + def test_repeat(self): + X, _ = self.vars + x, _ = self.vals + assert_array_equal(X.repeat(2).eval({X: x}), x.repeat(2)) + + def test_trace(self): + X, _ = self.vars + x, _ = self.vals + assert_array_equal(X.trace().eval({X: x}), x.trace()) + + def test_ravel(self): + X, _ = self.vars + x, _ = self.vals + assert_array_equal(X.ravel().eval({X: x}), x.ravel()) + + def test_diagonal(self): + X, _ = self.vars + x, _ = self.vals + assert_array_equal(X.diagonal().eval({X: x}), x.diagonal()) + assert_array_equal(X.diagonal(1).eval({X: x}), x.diagonal(1)) + assert_array_equal(X.diagonal(-1).eval({X: x}), x.diagonal(-1)) + for offset, axis1, axis2 in [(1, 0, 1), (-1, 0, 1), (0, 1, 0), (-2, 1, 0)]: + assert_array_equal( + X.diagonal(offset, axis1, axis2).eval({X: x}), + x.diagonal(offset, axis1, axis2), + ) + + def test_take(self): + X, _ = self.vars + x, _ = self.vals + indices = [1, 0, 3] + assert_array_equal(X.take(indices).eval({X: x}), x.take(indices)) + indices = [1, 0, 1] + assert_array_equal(X.take(indices, 1).eval({X: x}), x.take(indices, 1)) + indices = np.array([-10, 5, 12], dtype="int32") + assert_array_equal( + X.take(indices, 1, mode="wrap").eval({X: x}), + x.take(indices, 1, mode="wrap"), + ) + assert_array_equal( + X.take(indices, -1, mode="wrap").eval({X: x}), + x.take(indices, -1, mode="wrap"), + ) + assert_array_equal( + X.take(indices, 1, mode="clip").eval({X: x}), + x.take(indices, 1, mode="clip"), + ) + assert_array_equal( + X.take(indices, -1, mode="clip").eval({X: x}), + x.take(indices, -1, mode="clip"), + ) + # Test error handling + with pytest.raises(IndexError): + X.take(indices).eval({X: x}) + with pytest.raises(IndexError): + (2 * X.take(indices)).eval({X: x}) + with pytest.raises(TypeError): + X.take([0.0]) + indices = [[1, 0, 1], [0, 1, 1]] + assert_array_equal(X.take(indices, 1).eval({X: x}), x.take(indices, 1)) + # Test equivalent advanced indexing + assert_array_equal(X[:, indices].eval({X: x}), x[:, indices]) diff --git a/tests/tensor/utils.py b/tests/tensor/utils.py index adb64efaa1..c59c99d99d 100644 --- a/tests/tensor/utils.py +++ b/tests/tensor/utils.py @@ -109,89 +109,6 @@ def eval_outputs(outputs, ops=(), mode=None): return variables -def get_numeric_subclasses(cls=np.number, ignore=None): - """Return subclasses of `cls` in the numpy scalar hierarchy. - - We only return subclasses that correspond to unique data types. The - hierarchy can be seen here: - http://docs.scipy.org/doc/numpy/reference/arrays.scalars.html - """ - if ignore is None: - ignore = [] - rval = [] - dtype = np.dtype(cls) - dtype_num = dtype.num - if dtype_num not in ignore: - # Safety check: we should be able to represent 0 with this data type. - np.array(0, dtype=dtype) - rval.append(cls) - ignore.append(dtype_num) - for sub_ in cls.__subclasses__(): - rval += [c for c in get_numeric_subclasses(sub_, ignore=ignore)] - return rval - - -def get_numeric_types( - with_int=True, with_float=True, with_complex=False, only_aesara_types=True -): - """Return NumPy numeric data types. - - Parameters - ---------- - with_int - Whether to include integer types. - with_float - Whether to include floating point types. - with_complex - Whether to include complex types. - only_aesara_types - If ``True``, then numpy numeric data types that are not supported by - Aesara are ignored (i.e. those that are not declared in - ``scalar/basic.py``). - - Returns - ------- - A list of unique data type objects. Note that multiple data types may share - the same string representation, but can be differentiated through their - `num` attribute. - - Note that when `only_aesara_types` is True we could simply return the list - of types defined in the `scalar` module. However with this function we can - test more unique dtype objects, and in the future we may use it to - automatically detect new data types introduced in numpy. - """ - if only_aesara_types: - aesara_types = [d.dtype for d in aesara.scalar.all_types] - rval = [] - - def is_within(cls1, cls2): - # Return True if scalars defined from `cls1` are within the hierarchy - # starting from `cls2`. - # The third test below is to catch for instance the fact that - # one can use ``dtype=numpy.number`` and obtain a float64 scalar, even - # though `numpy.number` is not under `numpy.floating` in the class - # hierarchy. - return ( - cls1 is cls2 - or issubclass(cls1, cls2) - or isinstance(np.array([0], dtype=cls1)[0], cls2) - ) - - for cls in get_numeric_subclasses(): - dtype = np.dtype(cls) - if ( - (not with_complex and is_within(cls, np.complexfloating)) - or (not with_int and is_within(cls, np.integer)) - or (not with_float and is_within(cls, np.floating)) - or (only_aesara_types and dtype not in aesara_types) - ): - # Ignore this class. - continue - rval.append([str(dtype), dtype, dtype.num]) - # We sort it to be deterministic, then remove the string and num elements. - return [x[1] for x in sorted(rval, key=str)] - - def _numpy_checker(x, y): """Checks if `x.data` and `y.data` have the same contents. diff --git a/tests/test_gradient.py b/tests/test_gradient.py index 7d7c92a391..50dcf8170a 100644 --- a/tests/test_gradient.py +++ b/tests/test_gradient.py @@ -14,8 +14,6 @@ NullTypeGradError, Rop, UndefinedGrad, - consider_constant, - consider_constant_, disconnected_grad, disconnected_grad_, grad, @@ -769,37 +767,45 @@ def test_subgraph_grad(): class TestConsiderConstant: - def setup_method(self): - self.rng = np.random.default_rng(seed=utt.fetch_seed()) - def test_op_removed(self): + from aesara.gradient import ConsiderConstant, consider_constant + x = matrix("x") - y = x * consider_constant(x) + + with pytest.deprecated_call(): + y = x * consider_constant(x) + f = aesara.function([x], y) - # need to refer to aesara.consider_constant_ here, - # aesara.consider_constant is a wrapper function! - assert consider_constant_ not in [node.op for node in f.maker.fgraph.toposort()] + + assert ConsiderConstant not in [ + type(node.op) for node in f.maker.fgraph.toposort() + ] def test_grad(self): - a = np.asarray(self.rng.standard_normal((5, 5)), dtype=config.floatX) + from aesara.gradient import consider_constant - x = matrix("x") + rng = np.random.default_rng(seed=utt.fetch_seed()) - expressions_gradients = [ - (x * consider_constant(x), x), - (x * consider_constant(exp(x)), exp(x)), - (consider_constant(x), at.constant(0.0)), - (x**2 * consider_constant(x), 2 * x**2), - ] + a = np.asarray(rng.standard_normal((5, 5)), dtype=config.floatX) - for expr, expr_grad in expressions_gradients: - g = grad(expr.sum(), x) - # gradient according to aesara - f = aesara.function([x], g, on_unused_input="ignore") - # desired gradient - f2 = aesara.function([x], expr_grad, on_unused_input="ignore") + x = matrix("x") - assert np.allclose(f(a), f2(a)) + with pytest.deprecated_call(): + expressions_gradients = [ + (x * consider_constant(x), x), + (x * consider_constant(exp(x)), exp(x)), + (consider_constant(x), at.constant(0.0)), + (x**2 * consider_constant(x), 2 * x**2), + ] + + for expr, expr_grad in expressions_gradients: + g = grad(expr.sum(), x) + # gradient according to aesara + f = aesara.function([x], g, on_unused_input="ignore") + # desired gradient + f2 = aesara.function([x], expr_grad, on_unused_input="ignore") + + assert np.allclose(f(a), f2(a)) class TestZeroGrad: diff --git a/tests/test_ifelse.py b/tests/test_ifelse.py index 1eea3217a6..5eda9f2fba 100644 --- a/tests/test_ifelse.py +++ b/tests/test_ifelse.py @@ -6,6 +6,7 @@ import aesara import aesara.ifelse +import aesara.sparse import aesara.tensor.basic as at from aesara import function from aesara.compile.mode import Mode, get_mode @@ -14,15 +15,19 @@ from aesara.ifelse import IfElse, ifelse from aesara.link.c.type import generic from aesara.tensor.math import eq -from aesara.tensor.type import col, iscalar, matrix, row, scalar, tensor3, vector +from aesara.tensor.type import ( + col, + iscalar, + ivector, + matrix, + row, + scalar, + tensor3, + vector, +) from tests import unittest_tools as utt -__docformat__ = "restructedtext en" -__authors__ = "Razvan Pascanu " "PyMC Development Team " "Aesara Developers " -__copyright__ = "(c) 2010, Universite de Montreal" - - class TestIfelse(utt.OptimizationTestMixin): mode = None dtype = aesara.config.floatX @@ -41,7 +46,7 @@ def test_wrong_n_outs(self): with pytest.raises(ValueError): IfElse(0)(c, x, x) - def test_const_Op_argument(self): + def test_const_false_branch(self): x = vector("x", dtype=self.dtype) y = np.array([2.0, 3.0], dtype=self.dtype) c = iscalar("c") @@ -320,25 +325,24 @@ def test_broadcast_mismatch(self): with pytest.raises(TypeError): ifelse(cond, y, x) - def test_sparse_tensor_error(self): - pytest.importorskip("scipy", minversion="0.7.0") + def test_sparse_conversions(self): - import aesara.sparse + from aesara.sparse import matrix rng = np.random.default_rng(utt.fetch_seed()) data = rng.random((2, 3)).astype(self.dtype) x = self.shared(data) - y = aesara.sparse.matrix("csc", dtype=self.dtype, name="y") - z = aesara.sparse.matrix("csr", dtype=self.dtype, name="z") + y = matrix("csc", dtype=self.dtype, name="y") + z = matrix("csr", dtype=self.dtype, name="z") cond = iscalar("cond") - with pytest.raises(NotImplementedError): + with pytest.raises(TypeError, match=".*do not match."): ifelse(cond, x, y) - with pytest.raises(NotImplementedError): + with pytest.raises(TypeError, match=".*do not match."): ifelse(cond, y, x) - with pytest.raises(NotImplementedError): + with pytest.raises(TypeError): ifelse(cond, x, z) - with pytest.raises(NotImplementedError): + with pytest.raises(TypeError): ifelse(cond, z, x) with pytest.raises(TypeError): ifelse(cond, y, z) @@ -527,6 +531,39 @@ def test_str(self): res.owner.op.as_view = True assert str(res.owner).startswith("if{name,inplace}") + @pytest.mark.parametrize( + "x_shape, y_shape, x_val, y_val, exp_shape", + [ + ((2,), (3,), np.r_[1.0, 2.0], np.r_[1.0, 2.0, 3.0], (None,)), + ((None,), (3,), np.r_[1.0, 2.0], np.r_[1.0, 2.0, 3.0], (None,)), + ((3,), (None,), np.r_[1.0, 2.0, 3.0], np.r_[1.0, 2.0], (None,)), + ((2, 1), (None, 1), np.c_[[1.0, 2.0]], np.c_[[1.0, 2.0, 3.0]], (None, 1)), + ((3,), (3,), np.r_[1.0, 2.0, 3.0], np.r_[1.0, 2.0, 3.0], (3,)), + ((1,), (3,), np.r_[1.0], np.r_[1.0, 2.0, 3.0], (None,)), + ], + ) + def test_static_branch_shapes(self, x_shape, y_shape, x_val, y_val, exp_shape): + + x = at.tensor(dtype=self.dtype, shape=x_shape, name="x") + y = at.tensor(dtype=self.dtype, shape=y_shape, name="y") + c = iscalar("c") + z = IfElse(1)(c, x, y) + assert z.type.shape == exp_shape + + f = function([c, x, y], z, mode=self.mode) + + x_val = x_val.astype(self.dtype) + y_val = y_val.astype(self.dtype) + val = f(0, x_val, y_val) + assert np.array_equal(val, y_val) + + def test_nonscalar_condition(self): + x = vector("x") + y = vector("y") + c = ivector("c") + with pytest.raises(TypeError, match="The condition argument"): + IfElse(1)(c, x, y) + class IfElseIfElseIf(Op): def __init__(self, inplace=False): diff --git a/tests/test_raise_op.py b/tests/test_raise_op.py index c5911ebfeb..8236ba28f1 100644 --- a/tests/test_raise_op.py +++ b/tests/test_raise_op.py @@ -7,6 +7,7 @@ from aesara.compile.mode import OPT_FAST_RUN, Mode from aesara.graph.basic import Constant, equal_computations from aesara.raise_op import Assert, CheckAndRaise, assert_op +from aesara.scalar.basic import ScalarType, float64 from aesara.sparse import as_sparse_variable from tests import unittest_tools as utt @@ -94,6 +95,50 @@ def test_CheckAndRaise_basic_c(linker): assert np.array_equal(y_fn(x_val), [x_val]) +@pytest.mark.parametrize( + "linker", + [ + pytest.param( + "cvm", + marks=pytest.mark.skipif( + not aesara.config.cxx, + reason="G++ not available, so we need to skip this test.", + ), + ), + "py", + ], +) +def test_perform_CheckAndRaise_scalar(linker): + exc_msg = "this is the exception" + check_and_raise = CheckAndRaise(CustomException, exc_msg) + + val = float64("val") + conds = (val > 0, val > 3) + y = check_and_raise(val, *conds) + + assert all(isinstance(i.type, ScalarType) for i in y.owner.inputs) + assert isinstance(y.type, ScalarType) + + mode = Mode(linker=linker) + y_fn = aesara.function([val], y, mode=mode) + + with pytest.raises(CustomException, match=exc_msg): + y_fn(0.0) + + assert y_fn(4.0) == 4.0 + + if linker == "cvm": + assert isinstance( + y_fn.maker.fgraph.outputs[0].owner.inputs[0].owner.op, CheckAndRaise + ) + assert hasattr(y_fn.vm.thunks[-2], "cthunk") + + (y_grad,) = aesara.grad(y, [val]) + y_fn = aesara.function([val], y_grad, mode=Mode(linker, OPT_FAST_RUN)) + + assert np.array_equal(y_fn(4.0), 1.0) + + class TestCheckAndRaiseInferShape(utt.InferShapeTester): def setup_method(self): super().setup_method() @@ -117,6 +162,16 @@ def test_infer_shape(self): [admat, adscal, bdscal], [out], [admat_val, adscal_val, bdscal_val], Assert ) + def test_infer_shape_scalar(self): + adscal = float64("adscal") + bdscal = float64("bdscal") + adscal_val = np.random.random() + bdscal_val = np.random.random() + 1 + out = assert_op(adscal, bdscal) + self._compile_and_check( + [adscal, bdscal], [out], [adscal_val, bdscal_val], Assert + ) + def test_CheckAndRaise_sparse_variable(): check_and_raise = CheckAndRaise(ValueError, "sparse_check") diff --git a/tests/typed_list/test_opt.py b/tests/typed_list/test_rewriting.py similarity index 88% rename from tests/typed_list/test_opt.py rename to tests/typed_list/test_rewriting.py index df23d3f484..167424cfb8 100644 --- a/tests/typed_list/test_opt.py +++ b/tests/typed_list/test_rewriting.py @@ -17,7 +17,9 @@ def test_reverse_inplace(self): )() z = Reverse()(mySymbolicMatricesList) - m = aesara.compile.mode.get_default_mode().including("typed_list_inplace_opt") + m = aesara.compile.mode.get_default_mode().including( + "typed_list_inplace_rewrite" + ) f = aesara.function( [In(mySymbolicMatricesList, borrow=True, mutable=True)], z, @@ -38,7 +40,9 @@ def test_append_inplace(self): )() mySymbolicMatrix = matrix() z = Append()(mySymbolicMatricesList, mySymbolicMatrix) - m = aesara.compile.mode.get_default_mode().including("typed_list_inplace_opt") + m = aesara.compile.mode.get_default_mode().including( + "typed_list_inplace_rewrite" + ) f = aesara.function( [ In(mySymbolicMatricesList, borrow=True, mutable=True), @@ -66,7 +70,9 @@ def test_extend_inplace(self): )() z = Extend()(mySymbolicMatricesList1, mySymbolicMatricesList2) - m = aesara.compile.mode.get_default_mode().including("typed_list_inplace_opt") + m = aesara.compile.mode.get_default_mode().including( + "typed_list_inplace_rewrite" + ) f = aesara.function( [ In(mySymbolicMatricesList1, borrow=True, mutable=True), @@ -91,7 +97,9 @@ def test_insert_inplace(self): mySymbolicMatrix = matrix() z = Insert()(mySymbolicMatricesList, mySymbolicIndex, mySymbolicMatrix) - m = aesara.compile.mode.get_default_mode().including("typed_list_inplace_opt") + m = aesara.compile.mode.get_default_mode().including( + "typed_list_inplace_rewrite" + ) f = aesara.function( [ @@ -117,7 +125,9 @@ def test_remove_inplace(self): )() mySymbolicMatrix = matrix() z = Remove()(mySymbolicMatricesList, mySymbolicMatrix) - m = aesara.compile.mode.get_default_mode().including("typed_list_inplace_opt") + m = aesara.compile.mode.get_default_mode().including( + "typed_list_inplace_rewrite" + ) f = aesara.function( [ In(mySymbolicMatricesList, borrow=True, mutable=True),