Refactor get_scalar_constant_value #643

Merged
Changes from all commits (37 commits)
7a5b1ff
Fix strict type equality in aesara.tensor.subtensor_opt
brandonwillard Oct 30, 2021
5954079
Use Type instead of CType in get_scalar_constant_value
brandonwillard Oct 30, 2021
cd3e134
Fix strict type checking in aesara.tensor.basic_opt
brandonwillard Oct 30, 2021
2f1b3eb
Fix local_join_make_vector docstring
brandonwillard Oct 30, 2021
2586e67
Clarify assert statement in test_get_vector_length
brandonwillard Oct 30, 2021
9b8f91e
Remove unnecessary print statements and comments
brandonwillard Oct 30, 2021
d471db3
Add local Subtensor of Shape canonicalization
brandonwillard Oct 30, 2021
4241f8a
Add local Shape of SpecifyShape canonicalization
brandonwillard Oct 30, 2021
5b4a857
Add local Shape_i of broadcastable dimension canonicalization
brandonwillard Oct 30, 2021
e77a484
Add missing info to Rebroadcast str representation
brandonwillard Oct 30, 2021
5e3113f
Fix local_upcast_elemwise_constant_inputs docstring
brandonwillard Oct 30, 2021
8e57cc8
Remove use of numpy_scalar in get_scalar_constant_value
brandonwillard Oct 30, 2021
cf0b2c9
Remove impossible condition in get_scalar_constant_value logic
brandonwillard Oct 30, 2021
c8ca89f
Make aesara.tensor.basic exceptions use f-strings
brandonwillard Oct 31, 2021
d142225
Simplify as_tensor_variable imports in aesara.tensor.elemwise
brandonwillard Oct 31, 2021
b1a8fa9
Replace old style exception in aesara.tensor.subtensor_opt
brandonwillard Oct 31, 2021
c68677e
Add type annotations to aesara.tensor.subtensor.as_index_constant
brandonwillard Oct 31, 2021
350268b
Refactor get_canonical_form_slice so that it uses as_index_literal
brandonwillard Oct 31, 2021
3e2c472
Use pytest's parameterize for TestSubtensor.test_ellipsis
brandonwillard Oct 31, 2021
3b5ec3d
Fix flaky RNG usage in tests.tensor.test_math.test_cov
brandonwillard Nov 1, 2021
6825e02
Simplify construction of aesara.tensor.basic_opt.local_elemwise_alloc
brandonwillard Nov 2, 2021
7315bc7
Refactor tests.tensor.test_basic_opt tests
brandonwillard Nov 2, 2021
582f19f
Move _fill_chain to aesara.tensor.math_opt
brandonwillard Nov 9, 2021
e5df6aa
Use register_* decorators in basic_opt and math_opt
brandonwillard Nov 9, 2021
4359ca7
Correctly set tracked Op on local_track_shape_i
brandonwillard Nov 9, 2021
7b9b445
Clean up comments in basic_opt and math_opt
brandonwillard Nov 9, 2021
066709c
Prevent unnecessary shadowing of builtin input
brandonwillard Nov 10, 2021
d7358d3
Extract get_constant from AlgebraicCanonizer
brandonwillard Nov 10, 2021
d26b992
Refactor local_add_specialize and test_local_add_specialize
brandonwillard Nov 10, 2021
e792667
Update verbose optimizer format and print individual steps in LocalOp…
brandonwillard Nov 14, 2021
20601b7
Remove debug print message in aesara.ifelse.CondMerge
brandonwillard Nov 14, 2021
63dd23a
Move scalarconsts_rest to math_opt
brandonwillard Nov 10, 2021
9144137
Add missing get_scalar_constant_value tests
brandonwillard Nov 10, 2021
2563350
Refactor test_local_zero_div
brandonwillard Nov 10, 2021
1d14ee9
Make sure repeats argument is int64
brandonwillard Nov 10, 2021
287f7a0
Implement basic rewrites for Unique
brandonwillard Nov 10, 2021
3a3b122
Remove scan__debug config option
brandonwillard Nov 14, 2021
7 changes: 0 additions & 7 deletions aesara/configdefaults.py
@@ -1462,13 +1462,6 @@ def add_scan_configvars():
in_c_key=False,
)

config.add(
"scan__debug",
"If True, enable extra verbose output related to scan",
BoolParam(False),
in_c_key=False,
)


def _get_default_gpuarray__cache_path():
return os.path.join(config.compiledir, "gpuarray_kernels")
30 changes: 6 additions & 24 deletions aesara/graph/features.py
@@ -553,12 +553,9 @@ def replace_all_validate(
self, fgraph, replacements, reason=None, verbose=None, **kwargs
):
chk = fgraph.checkpoint()

if verbose is None:
verbose = config.optimizer_verbose
if config.scan__debug:
from aesara.scan.op import Scan

scans = [n for n in fgraph.apply_nodes if isinstance(n.op, Scan)]

for r, new_r in replacements:
try:
@@ -596,29 +593,14 @@ def replace_all_validate(
except Exception as e:
fgraph.revert(chk)
if verbose:
print(f"validate failed on node {r}.\n Reason: {reason}, {e}")
raise
if config.scan__debug:
from aesara.scan.op import Scan

scans2 = [n for n in fgraph.apply_nodes if isinstance(n.op, Scan)]
nb = len(scans)
nb2 = len(scans2)
if nb2 > nb:
print(
"Extra scan introduced",
nb,
nb2,
getattr(reason, "name", reason),
r,
new_r,
)
elif nb2 < nb:
print(
"Scan removed", nb, nb2, getattr(reason, "name", reason), r, new_r
f"optimizer: validate failed on node {r}.\n Reason: {reason}, {e}"
)
raise

if verbose:
print(reason, r, new_r)
print(f"optimizer: rewrite {reason} replaces {r} with {new_r}")

# The return is needed by replace_all_validate_remove
return chk

2 changes: 1 addition & 1 deletion aesara/graph/fg.py
@@ -510,7 +510,7 @@ def replace(
if verbose is None:
verbose = config.optimizer_verbose
if verbose:
print(reason, var, new_var)
print(f"optimizer: rewrite {reason} replaces {var} with {new_var}")

new_var = var.type.filter_variable(new_var, allow_convert=True)

4 changes: 4 additions & 0 deletions aesara/graph/opt.py
@@ -1347,6 +1347,10 @@ def transform(self, fgraph, node):
new_vars = new_repl
else: # It must be a dict
new_vars = list(new_repl.values())

if config.optimizer_verbose:
print(f"optimizer: rewrite {opt} replaces {node} with {new_repl}")

if self.profile:
self.node_created[opt] += len(
list(applys_between(fgraph.variables, new_vars))
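Together with the changes to features.py and fg.py above, this makes config.optimizer_verbose report every applied rewrite in a uniform "optimizer: rewrite <name> replaces <old> with <new>" format, replacing the removed scan__debug output. A minimal sketch of how the flag might be exercised (the toy graph is only an illustration; any compilation that triggers rewrites will do):

```python
import aesara
import aesara.tensor as at

# Enable the existing optimizer_verbose flag; with this PR each applied
# rewrite is printed as "optimizer: rewrite <name> replaces <old> with <new>".
aesara.config.optimizer_verbose = True

x = at.vector("x")
y = (x + 0.0) * 1.0  # a graph the canonicalization passes are likely to simplify

# Compiling the function runs the rewrites and emits one line per replacement.
f = aesara.function([x], y)
```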
1 change: 0 additions & 1 deletion aesara/ifelse.py
@@ -634,7 +634,6 @@ def apply(self, fgraph):
gpu=False,
name=mn_name + "&" + pl_name,
)
print("here")
new_outs = new_ifelse(*new_ins, return_list=True)
new_outs = [clone_replace(x) for x in new_outs]
old_outs = []
89 changes: 21 additions & 68 deletions aesara/tensor/basic.py
@@ -8,7 +8,6 @@
import builtins
import logging
import warnings
from collections import OrderedDict
from collections.abc import Sequence
from functools import partial
from numbers import Number
@@ -26,7 +25,7 @@
from aesara.graph.basic import Apply, Constant, Variable
from aesara.graph.op import COp, Op
from aesara.graph.params_type import ParamsType
from aesara.graph.type import CType
from aesara.graph.type import Type
from aesara.misc.safe_asarray import _asarray
from aesara.printing import min_informative_str, pprint
from aesara.scalar import int32
@@ -38,7 +37,7 @@
get_vector_length,
)
from aesara.tensor.elemwise import DimShuffle, Elemwise, scalar_elemwise
from aesara.tensor.exceptions import EmptyConstantError, NotScalarConstantError
from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.shape import (
Shape,
Shape_i,
@@ -121,7 +120,7 @@ def _as_tensor_Scalar(x, name, ndim, **kwargs):
def _as_tensor_Variable(x, name, ndim, **kwargs):
if not isinstance(x.type, TensorType):
raise TypeError(
"Tensor type field must be a TensorType; found {}.".format(type(x.type))
f"Tensor type field must be a TensorType; found {type(x.type)}."
)

if ndim is None:
@@ -135,9 +134,7 @@ def _as_tensor_Variable(x, name, ndim, **kwargs):
x = x.dimshuffle(list(range(x.ndim))[first_non_broadcastable:])
if x.ndim > ndim:
raise ValueError(
"Tensor of type {} could not be cast to have {} dimensions".format(
x.type, ndim
)
f"Tensor of type {x.type} could not be cast to have {ndim} dimensions"
)
return x
elif x.type.ndim < ndim:
@@ -258,31 +255,6 @@ def _obj_is_wrappable_as_tensor(x):
return False


def numpy_scalar(data):
"""Return a scalar stored in a numpy ndarray.

Raises
------
NotScalarConstantError
If the numpy ndarray is not a scalar.
EmptyConstantError

"""

# handle case where data is numpy.array([])
if data.ndim > 0 and (len(data.shape) == 0 or builtins.max(data.shape) == 0):
assert np.all(np.array([]) == data)
raise EmptyConstantError()
try:
complex(data) # works for all numeric scalars
return data
except Exception:
raise NotScalarConstantError(
"v.data is non-numeric, non-scalar, or has more than one" " unique value",
data,
)


_scalar_constant_value_elemwise_ops = (
aes.Cast,
aes.Switch,
@@ -345,7 +317,10 @@ def get_scalar_constant_value(
return np.asarray(v)

if isinstance(v, np.ndarray):
return numpy_scalar(v).copy()
try:
return np.array(v.item(), dtype=v.dtype)
except ValueError:
raise NotScalarConstantError()

if isinstance(v, Constant):
if getattr(v.tag, "unique_value", None) is not None:
@@ -354,7 +329,10 @@
data = v.data

if isinstance(data, np.ndarray):
return numpy_scalar(data).copy()
try:
return np.array(data.item(), dtype=v.dtype)
except ValueError:
raise NotScalarConstantError()
else:
return data
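The removed numpy_scalar helper is replaced by ndarray.item(), which NumPy only allows on arrays holding exactly one element; the ValueError it raises otherwise is what gets translated into NotScalarConstantError. A NumPy-only sketch of the behaviour being relied on:

```python
import numpy as np

v = np.array(5)
print(np.array(v.item(), dtype=v.dtype))  # 5; the single element is extracted

for bad in (np.array([]), np.array([1, 2])):
    try:
        bad.item()
    except ValueError:
        # Empty and multi-element arrays cannot be reduced to one scalar;
        # get_scalar_constant_value maps this onto NotScalarConstantError.
        print(f"{bad!r} is not a scalar constant")
```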

@@ -461,28 +439,14 @@ def get_scalar_constant_value(
and isinstance(v.owner.inputs[0].owner.op, Join)
and len(v.owner.op.idx_list) == 1
):
# Ensure the Join is joining only scalar variables (so that
# the constant value can be found at the same index as the
# one used in the sub-tensor).
if builtins.all(
var.ndim == 0 for var in v.owner.inputs[0].owner.inputs[1:]
):
idx = v.owner.op.idx_list[0]
if isinstance(idx, CType):
idx = get_scalar_constant_value(
v.owner.inputs[1], max_recur=max_recur
)
# Note the '+ 1' is because the first argument to Join
# is the axis.
ret = v.owner.inputs[0].owner.inputs[idx + 1]
ret = get_scalar_constant_value(ret, max_recur=max_recur)
# join can cast implicitly its input in some case.
return _asarray(ret, dtype=v.type.dtype)
# Ensure the Join is joining only (effectively) scalar
# variables (so that the constant value can be found at the
# same index as the one used in the sub-tensor).
if builtins.all(
var.ndim == 1 for var in v.owner.inputs[0].owner.inputs[1:]
):
idx = v.owner.op.idx_list[0]
if isinstance(idx, CType):
if isinstance(idx, Type):
idx = get_scalar_constant_value(
v.owner.inputs[1], max_recur=max_recur
)
@@ -517,7 +481,7 @@ def get_scalar_constant_value(
):

idx = v.owner.op.idx_list[0]
if isinstance(idx, CType):
if isinstance(idx, Type):
idx = get_scalar_constant_value(
v.owner.inputs[1], max_recur=max_recur
)
@@ -539,7 +503,7 @@ def get_scalar_constant_value(
op = owner.op
idx_list = op.idx_list
idx = idx_list[0]
if isinstance(idx, CType):
if isinstance(idx, Type):
idx = get_scalar_constant_value(
owner.inputs[1], max_recur=max_recur
)
@@ -576,12 +540,7 @@ def get_scalar_constant_value(
if isinstance(grandparent, Constant):
return np.asarray(np.shape(grandparent.data)[idx])

raise NotScalarConstantError(v)


#########################
# Casting Operations
#########################
raise NotScalarConstantError()


class TensorFromScalar(Op):
@@ -702,7 +661,7 @@ class Rebroadcast(COp):
def __init__(self, *axis):
# Sort them to make sure we merge all possible case.
items = sorted(axis)
self.axis = OrderedDict(items)
self.axis = dict(items)
for axis, broad in self.axis.items():
if not isinstance(axis, (np.integer, int)):
raise TypeError(f"Rebroadcast needs integer axes. Got {axis}")
@@ -719,13 +678,7 @@ def __hash__(self):
return hash((type(self), tuple(items)))

def __str__(self):
if len(self.axis) == 0:
broadcast_pattern = []
else:
broadcast_pattern = ["?" for i in range(1 + max(self.axis.keys()))]
for k, v in self.axis.items():
broadcast_pattern[k] = str(int(v))
return f"{self.__class__.__name__}{{{','.join(broadcast_pattern)}}}"
return f"{self.__class__.__name__}{{{','.join(str(i) for i in self.axis.items())}}}"

def make_node(self, x):
if self.axis.keys() and (x.ndim <= max(self.axis.keys())):
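For reference, a short sketch of the refactored contract (assuming only the public aesara.tensor API touched in this diff): constant expressions that reduce to a single value come back as NumPy scalars/0-d arrays, and everything else raises NotScalarConstantError.

```python
import numpy as np
import aesara.tensor as at
from aesara.tensor.basic import get_scalar_constant_value
from aesara.tensor.exceptions import NotScalarConstantError

# A 0-d constant reduces to its underlying value.
c = at.as_tensor_variable(np.array(3.0))
assert get_scalar_constant_value(c) == 3.0

# A purely symbolic input cannot be reduced and raises instead.
x = at.vector("x")
try:
    get_scalar_constant_value(x)
except NotScalarConstantError:
    pass
```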