Change infer_broadcastable to infer_static_shape
brandonwillard committed Oct 16, 2022
1 parent c7ff283 commit 63ca73d
Showing 4 changed files with 48 additions and 26 deletions.
43 changes: 31 additions & 12 deletions aesara/tensor/basic.py
@@ -10,7 +10,7 @@
 from collections.abc import Sequence
 from functools import partial
 from numbers import Number
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
 from typing import Sequence as TypeSequence
 from typing import Tuple, Union
 from typing import cast as type_cast
@@ -68,6 +68,10 @@
 from aesara.tensor.var import TensorConstant, TensorVariable, get_unique_value


+if TYPE_CHECKING:
+    from aesara.tensor import TensorLike
+
+
 def __oplist_tag(thing, tag):
     tags = getattr(thing, "__oplist_tags", [])
     tags.append(tag)
@@ -1334,11 +1338,25 @@ def identity_like(x, dtype: Optional[Union[str, np.generic, np.dtype]] = None):
     return eye(_x.shape[0], _x.shape[1], k=0, dtype=dtype)


-def infer_broadcastable(shape):
-    """Infer the broadcastable dimensions for `shape`.
-
-    `shape` will be validated and constant folded in order to determine
-    which dimensions are broadcastable (i.e. equal to ``1``).
+def infer_static_shape(
+    shape: Union[Variable, TypeSequence[Union[Variable, int]]]
+) -> Tuple[TypeSequence["TensorLike"], TypeSequence[Optional[int]]]:
+    """Infer the static shapes implied by the potentially symbolic elements in `shape`.
+
+    `shape` will be validated and constant folded. As a result, this function
+    can be expensive and shouldn't be used unless absolutely necessary.
+
+    It mostly exists as a hold-over from pre-static shape times, when it was
+    required in order to produce correct broadcastable arrays and prevent
+    some graphs from being unusable. Now, it is no longer strictly required,
+    so don't use it unless you want the same shape graphs to be rewritten
+    multiple times during graph construction.
+
+    Returns
+    -------
+    A validated sequence of symbolic shape values, and a sequence of
+    ``None``/``int`` values that can be used as `TensorType.shape` values.
+
     """
     from aesara.tensor.rewriting.basic import topo_constant_folding
     from aesara.tensor.rewriting.shape import ShapeFeature
@@ -1362,9 +1380,10 @@ def check_type(s):
         clone=True,
     )
     folded_shape = rewrite_graph(shape_fg, custom_rewrite=topo_constant_folding).outputs

-    bcast = tuple(getattr(s, "data", s) == 1 for s in folded_shape)
-    return sh, bcast
+    static_shape = tuple(
+        s.data.item() if isinstance(s, Constant) else None for s in folded_shape
+    )
+    return sh, static_shape


 class Alloc(COp):
@@ -1394,15 +1413,15 @@ class Alloc(COp):

     def make_node(self, value, *shape):
         v = as_tensor_variable(value)
-        sh, bcast = infer_broadcastable(shape)
+        sh, static_shape = infer_static_shape(shape)
         if v.ndim > len(sh):
             raise TypeError(
                 "The Alloc value to use has more dimensions"
                 " than the specified dimensions",
                 v.ndim,
                 len(sh),
             )
-        otype = TensorType(dtype=v.dtype, shape=bcast)
+        otype = TensorType(dtype=v.dtype, shape=static_shape)
         return Apply(self, [v] + sh, [otype()])

     def perform(self, node, inputs, out_):
@@ -3823,8 +3842,8 @@ def typecode(self):
         return np.dtype(self.dtype).num

     def make_node(self, *_shape):
-        _shape, bcast = infer_broadcastable(_shape)
-        otype = TensorType(dtype=self.dtype, shape=bcast)
+        _shape, static_shape = infer_static_shape(_shape)
+        otype = TensorType(dtype=self.dtype, shape=static_shape)
         output = otype()

         output.tag.values_eq_approx = values_eq_approx_always_true
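For reference, a minimal sketch of the renamed helper's contract (assuming the post-commit `aesara.tensor.basic` API; the variables are illustrative, not part of the diff): entries of `shape` that constant-fold become `int`s in the returned static shape, while genuinely symbolic entries become `None`.

```python
import aesara.tensor as at
from aesara.tensor.basic import infer_static_shape

# One symbolic and one constant dimension: the constant entry folds to 2,
# while the symbolic scalar stays unknown and maps to None.
x = at.iscalar("x")
sh, static_shape = infer_static_shape((x, at.constant(2)))
print(static_shape)  # expected: (None, 2)
```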
13 changes: 8 additions & 5 deletions aesara/tensor/extra_ops.py
@@ -1613,9 +1613,9 @@ def __call__(self, a, shape, **kwargs):
     def make_node(self, a, *shape):
         a = at.as_tensor_variable(a)

-        shape, bcast = at.infer_broadcastable(shape)
+        shape, static_shape = at.infer_static_shape(shape)

-        out = TensorType(dtype=a.type.dtype, shape=bcast)()
+        out = TensorType(dtype=a.type.dtype, shape=static_shape)()

         # Attempt to prevent in-place operations on this view-based output
         out.tag.indestructible = True
@@ -1637,11 +1637,14 @@ def grad(self, inputs, outputs_gradients):
         d_wrt_a = broadcast_to(dout, shape).sum(axis=new_dims)

         # Determine the dimensions that were broadcast
-        _, shape_bcast = at.infer_broadcastable(shape)
+        _, static_shape = at.infer_static_shape(shape)
+
+        # TODO: This needs to be performed at run-time when static shape
+        # information isn't available.
         bcast_sums = [
             i
-            for i, (a_b, s_b) in enumerate(zip(a.broadcastable, shape_bcast[-a.ndim :]))
-            if a_b and not s_b
+            for i, (a_s, s_s) in enumerate(zip(a.type.shape, static_shape[-a.ndim :]))
+            if a_s == 1 and s_s != 1
         ]

         if bcast_sums:
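The `grad` change above decides which axes to sum over by comparing static shapes rather than broadcastable flags. A small self-contained illustration of that index selection, mirroring the committed list comprehension with hypothetical shape values:

```python
# Static shape of the input vs. the (right-aligned) broadcast target shape.
a_static_shape = (1, 3)  # axis 0 has length 1, so it broadcasts
target_shape = (5, 3)    # the shape passed to broadcast_to

bcast_sums = [
    i
    for i, (a_s, s_s) in enumerate(
        zip(a_static_shape, target_shape[-len(a_static_shape) :])
    )
    if a_s == 1 and s_s != 1
]
print(bcast_sums)  # [0]: the gradient must be summed over axis 0
```

As the new TODO comment notes, entries of `a.type.shape` can also be `None` when the static shape is unknown, a case this graph-construction-time comparison cannot yet handle.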
6 changes: 3 additions & 3 deletions aesara/tensor/random/op.py
@@ -14,7 +14,7 @@
     constant,
     get_scalar_constant_value,
     get_vector_length,
-    infer_broadcastable,
+    infer_static_shape,
 )
 from aesara.tensor.random.type import RandomGeneratorType, RandomStateType, RandomType
 from aesara.tensor.random.utils import normalize_size_param, params_broadcast_shapes
@@ -322,7 +322,7 @@ def make_node(self, rng, size, dtype, *dist_params):
             )

         shape = self._infer_shape(size, dist_params)
-        _, bcast = infer_broadcastable(shape)
+        _, static_shape = infer_static_shape(shape)
         dtype = self.dtype or dtype

         if dtype == "floatX":
@@ -336,7 +336,7 @@ def make_node(self, rng, size, dtype, *dist_params):
             dtype_idx = constant(dtype, dtype="int64")
             dtype = all_dtypes[dtype_idx.data]

-        outtype = TensorType(dtype=dtype, shape=bcast)
+        outtype = TensorType(dtype=dtype, shape=static_shape)
         out_var = outtype()
         inputs = (rng, size, dtype_idx) + dist_params
         outputs = (rng.type(), out_var)
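One practical effect of using `infer_static_shape` in `RandomVariable.make_node`: draws with a constant `size` should now carry a fully known static output shape. A quick sketch, assuming the post-commit behavior (the exact namespace import is illustrative):

```python
import aesara.tensor.random as atr

# With a constant size, the folded shape becomes the output type's static
# shape rather than just a tuple of broadcast flags.
x = atr.normal(0, 1, size=(2, 3))
print(x.type.shape)  # expected: (2, 3)
```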
12 changes: 6 additions & 6 deletions tests/tensor/test_basic.py
@@ -55,7 +55,7 @@
     get_vector_length,
     horizontal_stack,
     identity_like,
-    infer_broadcastable,
+    infer_static_shape,
     inverse_permutation,
     join,
     make_vector,
@@ -796,20 +796,20 @@ def test_full(self):

 def test_infer_broadcastable():
     with pytest.raises(TypeError, match="^Shapes must be scalar integers.*"):
-        infer_broadcastable([constant(1.0)])
+        infer_static_shape([constant(1.0)])

     with config.change_flags(exception_verbosity="high"), pytest.raises(
         TypeError, match=r"A\. x"
     ):
-        infer_broadcastable([dscalar("x")])
+        infer_static_shape([dscalar("x")])

     with pytest.raises(ValueError, match=".*could not be cast to have 0 dimensions"):
-        infer_broadcastable((as_tensor_variable([[1, 2]]),))
+        infer_static_shape((as_tensor_variable([[1, 2]]),))

     constant_size = constant([1])
     specify_size = specify_shape(constant_size, [1])
-    sh, bcast = infer_broadcastable(specify_size)
-    assert bcast == (True,)
+    sh, static_shape = infer_static_shape(specify_size)
+    assert static_shape == (1,)


 # This is slow for the ('int8', 3) version.
