From b088cc8f8f056aaa03d2d12e621ad26ddffc4320 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Tue, 11 Jan 2022 14:58:49 -0600 Subject: [PATCH] Introduce aesara.as_symbolic `aesara.as_symbolic` is a new function that converts all eligible objects to `Variable`s. Unlike `aesara.tensor.as_tensor_variable`, `as_symbolic` will convert `None`s and `slice`s, or any other types that have equivalent Aesara `Type`s. --- aesara/__init__.py | 58 +++++++++++++++++++++++++++++++-- aesara/graph/opt.py | 3 +- aesara/tensor/__init__.py | 28 ++++++++-------- aesara/tensor/type_other.py | 21 +++++++++--- tests/tensor/test_type_other.py | 21 +++++++++++- 5 files changed, 109 insertions(+), 22 deletions(-) diff --git a/aesara/__init__.py b/aesara/__init__.py index 715b7b0626..506f066c62 100644 --- a/aesara/__init__.py +++ b/aesara/__init__.py @@ -26,6 +26,8 @@ import logging import os import sys +from functools import singledispatch +from typing import Any, NoReturn, Optional aesara_logger = logging.getLogger("aesara") @@ -76,6 +78,51 @@ def disable_log_handler(logger=aesara_logger, handler=logging_default_handler): # very rarely. __api_version__ = 1 +# isort: off +from aesara.graph.basic import Variable, clone_replace + +# isort: on + + +def as_symbolic(x: Any, name: Optional[str] = None, **kwargs) -> Variable: + """Convert `x` into an equivalent Aesara `Variable`. + + Parameters + ---------- + x + The object to be converted into a ``Variable`` type. A + ``numpy.ndarray`` argument will not be copied, but a list of numbers + will be copied to make an ``numpy.ndarray``. + name + If a new ``Variable`` instance is created, it will be named with this + string. + kwargs + Options passed to the appropriate sub-dispatch functions. For example, + `ndim` and `dtype` can be passed when `x` is an `numpy.ndarray` or + `Number` type. + + Raises + ------ + TypeError + If `x` cannot be converted to a `Variable`. + + """ + if isinstance(x, Variable): + return x + + res = _as_symbolic(x, **kwargs) + res.name = name + return res + + +@singledispatch +def _as_symbolic(x, **kwargs) -> Variable: + from aesara.tensor import as_tensor_variable + + return as_tensor_variable(x, **kwargs) + + +# isort: off from aesara import scalar, tensor from aesara.compile import ( In, @@ -95,6 +142,8 @@ def disable_log_handler(logger=aesara_logger, handler=logging_default_handler): from aesara.printing import pp, pprint from aesara.updates import OrderedUpdates +# isort: on + if ( config.device.startswith("cuda") @@ -126,13 +175,16 @@ def get_scalar_constant_value(v): return tensor.get_scalar_constant_value(v) +# isort: off import aesara.tensor.random.var -from aesara.graph.basic import clone_replace from aesara.scan import checkpoints from aesara.scan.basic import scan from aesara.scan.views import foldl, foldr, map, reduce +# isort: on + -# Some config variables are registered by submodules. Only after all those imports -# were executed, we can warn about remaining flags provided by the user through AESARA_FLAGS. +# Some config variables are registered by submodules. Only after all those +# imports were executed, we can warn about remaining flags provided by the user +# through AESARA_FLAGS. config.warn_unused_flags() diff --git a/aesara/graph/opt.py b/aesara/graph/opt.py index 805a072435..f4fc296d8c 100644 --- a/aesara/graph/opt.py +++ b/aesara/graph/opt.py @@ -37,7 +37,6 @@ from aesara.graph.op import Op from aesara.graph.utils import AssocList from aesara.misc.ordered_set import OrderedSet -from aesara.raise_op import CheckAndRaise from aesara.utils import flatten @@ -789,6 +788,8 @@ def add_requirements(self, fgraph): fgraph.attach_feature(MergeFeature()) def apply(self, fgraph): + from aesara.raise_op import CheckAndRaise + # Constant and non-constant are now applied in the same phase. # I am not sure why, but it seems to be faster this way. sched = fgraph.merge_feature.scheduled diff --git a/aesara/tensor/__init__.py b/aesara/tensor/__init__.py index 43c6d16621..fcd7f1b313 100644 --- a/aesara/tensor/__init__.py +++ b/aesara/tensor/__init__.py @@ -9,32 +9,34 @@ def as_tensor_variable( x: Any, name: Optional[str] = None, ndim: Optional[int] = None, **kwargs -) -> Callable: - """Convert `x` into the appropriate ``TensorType``. +) -> Variable: + """Convert `x` into an equivalent `TensorVariable`. - This function is often used by ``make_node`` methods of ``Op`` subclasses - to turn ndarrays, numbers, ``Scalar`` instances, ``Apply`` instances and - ``TensorType`` instances into valid input list elements. + This function can be used to turn ndarrays, numbers, `Scalar` instances, + `Apply` instances and `TensorVariable` instances into valid input list + elements. + + See `aesara.as_symbolic` for a more general conversion function. Parameters ---------- x - The object to be converted into a ``Variable`` type. A - ``numpy.ndarray`` argument will not be copied, but a list of numbers - will be copied to make an ``numpy.ndarray``. + The object to be converted into a `Variable` type. A + `numpy.ndarray` argument will not be copied, but a list of numbers + will be copied to make an `numpy.ndarray`. name - If a new ``Variable`` instance is created, it will be named with this + If a new `Variable` instance is created, it will be named with this string. ndim - Return a ``Variable`` with this many dimensions. + Return a `Variable` with this many dimensions. dtype - The dtype to use for the resulting ``Variable``. If `x` is already - a ``Variable`` type, then the dtype will not be changed. + The dtype to use for the resulting `Variable`. If `x` is already + a `Variable` type, then the dtype will not be changed. Raises ------ TypeError - If `x` cannot be converted to a ``TensorType`` Variable. + If `x` cannot be converted to a `TensorVariable`. """ return _as_tensor_variable(x, name, ndim, **kwargs) diff --git a/aesara/tensor/type_other.py b/aesara/tensor/type_other.py index c51b1ad43c..37c747d2ab 100644 --- a/aesara/tensor/type_other.py +++ b/aesara/tensor/type_other.py @@ -5,8 +5,9 @@ import numpy as np import aesara +from aesara import _as_symbolic from aesara.gradient import DisconnectedType -from aesara.graph.basic import Apply, Constant +from aesara.graph.basic import Apply, Constant, Variable from aesara.graph.op import Op from aesara.graph.type import Generic, Type from aesara.tensor.type import integer_dtypes @@ -108,6 +109,15 @@ def __str__(self): SliceType.Constant = SliceConstant +@_as_symbolic.register(slice) +def as_symbolic_slice(x, **kwargs): + + if any(isinstance(i, Variable) for i in (x.start, x.stop, x.step)): + return make_slice(x) + + return SliceConstant(slicetype, x) + + class NoneTypeT(Generic): """ Inherit from Generic to have c code working. @@ -129,9 +139,12 @@ def may_share_memory(a, b): none_type_t = NoneTypeT() -# This is a variable instance. It can be used only once per fgraph. -# So use NoneConst.clone() before using it in an Aesara graph. -# Use NoneConst.equals(x) to check if two variable are NoneConst. NoneConst = Constant(none_type_t, None, name="NoneConst") + +@_as_symbolic.register(type(None)) +def as_symbolic_None(x, **kwargs): + return NoneConst + + __all__ = ["make_slice", "slicetype", "none_type_t", "NoneConst"] diff --git a/tests/tensor/test_type_other.py b/tests/tensor/test_type_other.py index 344bc302f2..a5cbad6e1e 100644 --- a/tests/tensor/test_type_other.py +++ b/tests/tensor/test_type_other.py @@ -1,10 +1,17 @@ """ This file don't test everything. It only test one past crash error.""" import aesara +from aesara import as_symbolic from aesara.graph.basic import Constant from aesara.tensor.math import argmax from aesara.tensor.type import iscalar, vector -from aesara.tensor.type_other import MakeSlice, NoneConst, NoneTypeT, make_slice +from aesara.tensor.type_other import ( + MakeSlice, + NoneConst, + NoneTypeT, + SliceConstant, + make_slice, +) def test_make_slice_merge(): @@ -44,3 +51,15 @@ def test_none_Constant(): kwargs = {"mode": "FAST_RUN"} f = aesara.function([x], [y], **kwargs) pickle.loads(pickle.dumps(f)) + + +def test_as_symbolic(): + + res = as_symbolic(None) + assert res is NoneConst + + res = as_symbolic(slice(iscalar())) + assert res.owner.op == make_slice + + res = as_symbolic(slice(1, 2)) + assert isinstance(res, SliceConstant)