Skip to content

Commit

Permalink
Introduce aesara.as_symbolic
Browse files Browse the repository at this point in the history
`aesara.as_symbolic` is a new function that converts all eligible objects to
`Variable`s.  Unlike `aesara.tensor.as_tensor_variable`, `as_symbolic` will
convert `None`s and `slice`s, or any other types that have equivalent Aesara `Type`s.
  • Loading branch information
brandonwillard committed Jan 11, 2022
1 parent 3edbbc4 commit a90c14d
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 22 deletions.
58 changes: 55 additions & 3 deletions aesara/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
import logging
import os
import sys
from functools import singledispatch
from typing import Any, NoReturn, Optional


aesara_logger = logging.getLogger("aesara")
Expand Down Expand Up @@ -76,6 +78,51 @@ def disable_log_handler(logger=aesara_logger, handler=logging_default_handler):
# very rarely.
__api_version__ = 1

# isort: off
from aesara.graph.basic import Variable, clone_replace

# isort: on


def as_symbolic(x: Any, name: Optional[str] = None, **kwargs) -> Variable:
"""Convert `x` into an equivalent Aesara `Variable`.
Parameters
----------
x
The object to be converted into a ``Variable`` type. A
``numpy.ndarray`` argument will not be copied, but a list of numbers
will be copied to make an ``numpy.ndarray``.
name
If a new ``Variable`` instance is created, it will be named with this
string.
kwargs
Options passed to the appropriate sub-dispatch functions. For example,
`ndim` and `dtype` can be passed when `x` is an `numpy.ndarray` or
`Number` type.
Raises
------
TypeError
If `x` cannot be converted to a `Variable`.
"""
if isinstance(x, Variable):
return x

res = _as_symbolic(x, **kwargs)
res.name = name
return res


@singledispatch
def _as_symbolic(x, **kwargs) -> Variable:
from aesara.tensor import as_tensor_variable

return as_tensor_variable(x, **kwargs)


# isort: off
from aesara import scalar, tensor
from aesara.compile import (
In,
Expand All @@ -95,6 +142,8 @@ def disable_log_handler(logger=aesara_logger, handler=logging_default_handler):
from aesara.printing import pp, pprint
from aesara.updates import OrderedUpdates

# isort: on


if (
config.device.startswith("cuda")
Expand Down Expand Up @@ -126,13 +175,16 @@ def get_scalar_constant_value(v):
return tensor.get_scalar_constant_value(v)


# isort: off
import aesara.tensor.random.var
from aesara.graph.basic import clone_replace
from aesara.scan import checkpoints
from aesara.scan.basic import scan
from aesara.scan.views import foldl, foldr, map, reduce

# isort: on


# Some config variables are registered by submodules. Only after all those imports
# were executed, we can warn about remaining flags provided by the user through AESARA_FLAGS.
# Some config variables are registered by submodules. Only after all those
# imports were executed, we can warn about remaining flags provided by the user
# through AESARA_FLAGS.
config.warn_unused_flags()
3 changes: 2 additions & 1 deletion aesara/graph/opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
from aesara.graph.op import Op
from aesara.graph.utils import AssocList
from aesara.misc.ordered_set import OrderedSet
from aesara.raise_op import CheckAndRaise
from aesara.utils import flatten


Expand Down Expand Up @@ -789,6 +788,8 @@ def add_requirements(self, fgraph):
fgraph.attach_feature(MergeFeature())

def apply(self, fgraph):
from aesara.raise_op import CheckAndRaise

# Constant and non-constant are now applied in the same phase.
# I am not sure why, but it seems to be faster this way.
sched = fgraph.merge_feature.scheduled
Expand Down
28 changes: 15 additions & 13 deletions aesara/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,32 +9,34 @@

def as_tensor_variable(
x: Any, name: Optional[str] = None, ndim: Optional[int] = None, **kwargs
) -> Callable:
"""Convert `x` into the appropriate ``TensorType``.
) -> Variable:
"""Convert `x` into an equivalent `TensorVariable`.
This function is often used by ``make_node`` methods of ``Op`` subclasses
to turn ndarrays, numbers, ``Scalar`` instances, ``Apply`` instances and
``TensorType`` instances into valid input list elements.
This function can be used to turn ndarrays, numbers, `Scalar` instances,
`Apply` instances and `TensorVariable` instances into valid input list
elements.
See `aesara.as_symbolic` for a more general conversion function.
Parameters
----------
x
The object to be converted into a ``Variable`` type. A
``numpy.ndarray`` argument will not be copied, but a list of numbers
will be copied to make an ``numpy.ndarray``.
The object to be converted into a `Variable` type. A
`numpy.ndarray` argument will not be copied, but a list of numbers
will be copied to make an `numpy.ndarray`.
name
If a new ``Variable`` instance is created, it will be named with this
If a new `Variable` instance is created, it will be named with this
string.
ndim
Return a ``Variable`` with this many dimensions.
Return a `Variable` with this many dimensions.
dtype
The dtype to use for the resulting ``Variable``. If `x` is already
a ``Variable`` type, then the dtype will not be changed.
The dtype to use for the resulting `Variable`. If `x` is already
a `Variable` type, then the dtype will not be changed.
Raises
------
TypeError
If `x` cannot be converted to a ``TensorType`` Variable.
If `x` cannot be converted to a `TensorVariable`.
"""
return _as_tensor_variable(x, name, ndim, **kwargs)
Expand Down
21 changes: 17 additions & 4 deletions aesara/tensor/type_other.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
import numpy as np

import aesara
from aesara import _as_symbolic
from aesara.gradient import DisconnectedType
from aesara.graph.basic import Apply, Constant
from aesara.graph.basic import Apply, Constant, Variable
from aesara.graph.op import Op
from aesara.graph.type import Generic, Type
from aesara.tensor.type import integer_dtypes
Expand Down Expand Up @@ -108,6 +109,15 @@ def __str__(self):
SliceType.Constant = SliceConstant


@_as_symbolic.register(slice)
def as_symbolic_slice(x, **kwargs):

if any(isinstance(i, Variable) for i in (x.start, x.stop, x.step)):
return make_slice(x)

return SliceConstant(slicetype, x)


class NoneTypeT(Generic):
"""
Inherit from Generic to have c code working.
Expand All @@ -129,9 +139,12 @@ def may_share_memory(a, b):

none_type_t = NoneTypeT()

# This is a variable instance. It can be used only once per fgraph.
# So use NoneConst.clone() before using it in an Aesara graph.
# Use NoneConst.equals(x) to check if two variable are NoneConst.
NoneConst = Constant(none_type_t, None, name="NoneConst")


@_as_symbolic.register(type(None))
def as_symbolic_None(x, **kwargs):
return NoneConst


__all__ = ["make_slice", "slicetype", "none_type_t", "NoneConst"]
21 changes: 20 additions & 1 deletion tests/tensor/test_type_other.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
""" This file don't test everything. It only test one past crash error."""

import aesara
from aesara import as_symbolic
from aesara.graph.basic import Constant
from aesara.tensor.math import argmax
from aesara.tensor.type import iscalar, vector
from aesara.tensor.type_other import MakeSlice, NoneConst, NoneTypeT, make_slice
from aesara.tensor.type_other import (
MakeSlice,
NoneConst,
NoneTypeT,
SliceConstant,
make_slice,
)


def test_make_slice_merge():
Expand Down Expand Up @@ -44,3 +51,15 @@ def test_none_Constant():
kwargs = {"mode": "FAST_RUN"}
f = aesara.function([x], [y], **kwargs)
pickle.loads(pickle.dumps(f))


def test_as_symbolic():

res = as_symbolic(None)
assert res is NoneConst

res = as_symbolic(slice(iscalar()))
assert res.owner.op == make_slice

res = as_symbolic(slice(1, 2))
assert isinstance(res, SliceConstant)

0 comments on commit a90c14d

Please sign in to comment.