Generalize broadcastables inference and fix Alloc broadcastables case
Closes #692
brandonwillard committed Dec 9, 2021
1 parent 934596e commit c67e71b
Showing 6 changed files with 65 additions and 73 deletions.
aesara/gpuarray/basic_ops.py (6 changes: 3 additions & 3 deletions)
@@ -19,7 +19,7 @@
from aesara.link.c.interface import HideC
from aesara.scalar import bool as bool_t
from aesara.scalar import int32 as int32_t
- from aesara.tensor.basic import Alloc, AllocEmpty, Join, Split, alloc_validate_shape
+ from aesara.tensor.basic import Alloc, AllocEmpty, Join, Split, infer_broadcastable
from aesara.tensor.shape import Reshape
from aesara.tensor.type import TensorType, values_eq_approx_always_true

@@ -909,7 +909,7 @@ def __str__(self):

def make_node(self, value, *shape):
value = as_gpuarray_variable(value, context_name=self.context_name)
- sh, bcast = alloc_validate_shape(shape)
+ sh, bcast = infer_broadcastable(shape)
if value.ndim > len(sh):
TypeError(
"The GpuAlloc value to use has more dimensions "
@@ -1071,7 +1071,7 @@ def get_params(self, node):
)

def make_node(self, *shape):
- sh, bcast = alloc_validate_shape(shape)
+ sh, bcast = infer_broadcastable(shape)
output = GpuArrayType(
dtype=self.dtype, broadcastable=bcast, context_name=self.context_name
)()
aesara/tensor/basic.py (67 changes: 33 additions & 34 deletions)
@@ -22,7 +22,9 @@
from aesara import scalar as aes
from aesara.gradient import DisconnectedType, grad_not_implemented, grad_undefined
from aesara.graph.basic import Apply, Constant, Variable
+ from aesara.graph.fg import FunctionGraph
from aesara.graph.op import COp, Op
+ from aesara.graph.opt_utils import optimize_graph
from aesara.graph.params_type import ParamsType
from aesara.graph.type import Type
from aesara.misc.safe_asarray import _asarray
@@ -1324,43 +1326,44 @@ def identity_like(x):
    return eye(x.shape[0], x.shape[1], k=0, dtype=x.dtype)


- def alloc_validate_shape(shape):
-     sh = [as_tensor_variable(s) for s in shape]
-     bcast = []
-     for i, s in enumerate(sh):
-
-         def err_str():
-             if config.exception_verbosity == "high":
-                 return "\n" + min_informative_str(s)
-             else:
-                 return str(s)
-
-         if s.type.dtype not in integer_dtypes:
-             s_as_str = err_str()
-             raise TypeError(
-                 "Shape arguments to Alloc must be integers, "
-                 f"but argument {i} is not for apply node: {s_as_str}"
-             )
-         if s.ndim != 0:
-             s_as_str = err_str()
-             raise TypeError(
-                 "Each shape dimension to Alloc must be a scalar, ",
-                 f"but dimension {i} have {int(s.ndim)} dimensions for apply node: {s_as_str}",
-             )
-
-         # if s is constant 1, then we're broadcastable in that dim
-         try:
-             const_shp = get_scalar_constant_value(s)
-         except NotScalarConstantError:
-             const_shp = None
-         bcast.append(1 == const_shp)
+ def infer_broadcastable(shape):
+     """Infer the broadcastable dimensions for `shape`.
+
+     `shape` will be validated and constant folded in order to determine
+     which dimensions are broadcastable (i.e. equal to ``1``).
+     """
+     from aesara.tensor.basic_opt import ShapeFeature, topo_constant_folding
+
+     def check_type(s):
+         if s.type.dtype in integer_dtypes:
+             return s
+
+         if config.exception_verbosity == "high":
+             s_as_str = "\n" + min_informative_str(s)
+         else:
+             s_as_str = str(s)
+
+         raise TypeError(f"Shapes must be scalar integers; got {s_as_str}")
+
+     sh = [check_type(as_tensor_variable(s, ndim=0)) for s in shape]
+
+     shape_fg = FunctionGraph(
+         outputs=sh,
+         features=[ShapeFeature()],
+         clone=True,
+     )
+     folded_shape = optimize_graph(shape_fg, custom_opt=topo_constant_folding).outputs
+
+     bcast = tuple(getattr(s, "data", s) == 1 for s in folded_shape)
    return sh, bcast


class Alloc(COp):
"""Create a `TensorVariable` from an initial value and a desired shape.
Usage:
alloc(value, shape0, shape1, ..., shapeN)
Returns an N-dimensional tensor initialized by a value, using something
equivalent to
@@ -1380,12 +1383,9 @@ class Alloc(COp):
_f16_ok = True
__props__ = ()

- def validate_shape(self, shape):
-     return alloc_validate_shape(shape)

def make_node(self, value, *shape):
v = as_tensor_variable(value)
- sh, bcast = alloc_validate_shape(shape)
+ sh, bcast = infer_broadcastable(shape)
if v.ndim > len(sh):
raise TypeError(
"The Alloc value to use has more dimensions"
@@ -4102,7 +4102,7 @@ def typecode(self):
return np.dtype(self.dtype).num

def make_node(self, *_shape):
- _shape, bcast = alloc_validate_shape(_shape)
+ _shape, bcast = infer_broadcastable(_shape)
otype = TensorType(dtype=self.dtype, broadcastable=bcast)
output = otype()

@@ -4363,7 +4363,6 @@ def take_along_axis(arr, indices, axis=0):
"tensor_copy",
"transfer",
"alloc",
"alloc_validate_shape",
"identity_like",
"eye",
"triu",
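Note: the following is a minimal usage sketch of the new helper and is not part of the diff; it assumes an Aesara checkout containing this commit. Shape entries that constant-fold to 1 are reported as broadcastable, while entries that stay symbolic are not. The scalar name `n` is only illustrative.

import aesara.tensor as aet
from aesara.tensor.basic import constant, infer_broadcastable

# The literal 1 constant-folds, so its dimension is inferred as broadcastable;
# the symbolic scalar `n` cannot be folded, so its dimension is not.
sh, bcast = infer_broadcastable([constant(1), aet.iscalar("n")])
assert bcast == (True, False)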
aesara/tensor/extra_ops.py (4 changes: 2 additions & 2 deletions)
@@ -1585,7 +1585,7 @@ def make_node(self, a, *shape):
a = aet.as_tensor_variable(a)
shape = aet.as_tensor_variable(shape, ndim=1)

- shape, bcast = aet.alloc_validate_shape(shape)
+ shape, bcast = aet.infer_broadcastable(shape)

out = type(a.type)(dtype=a.type.dtype, broadcastable=bcast)()

@@ -1609,7 +1609,7 @@ def grad(self, inputs, outputs_gradients):
d_wrt_a = broadcast_to(dout, shape).sum(axis=new_dims)

# Determine the dimensions that were broadcast
- _, shape_bcast = aet.alloc_validate_shape(shape)
+ _, shape_bcast = aet.infer_broadcastable(shape)
bcast_sums = [
i
for i, (a_b, s_b) in enumerate(zip(a.broadcastable, shape_bcast[-a.ndim :]))
aesara/tensor/random/op.py (32 changes: 3 additions & 29 deletions)
@@ -7,18 +7,16 @@
import aesara
from aesara.configdefaults import config
from aesara.graph.basic import Apply, Variable
- from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op
- from aesara.graph.opt_utils import optimize_graph
from aesara.misc.safe_asarray import _asarray
from aesara.scalar import ScalarVariable
from aesara.tensor.basic import (
as_tensor_variable,
constant,
get_scalar_constant_value,
get_vector_length,
+ infer_broadcastable,
)
- from aesara.tensor.basic_opt import ShapeFeature, topo_constant_folding
from aesara.tensor.random.type import RandomType
from aesara.tensor.random.utils import normalize_size_param, params_broadcast_shapes
from aesara.tensor.shape import shape_tuple
@@ -276,31 +274,6 @@ def slice_ind_dims(p, ps, n):

return shape

- @config.change_flags(compute_test_value="off")
- def compute_bcast(self, dist_params, size):
-     """Compute the broadcast array for this distribution's `TensorType`.
-
-     Parameters
-     ----------
-     dist_params: list
-         Distribution parameters.
-     size: int or Sequence (optional)
-         Numpy-like size of the output (i.e. replications).
-
-     """
-     shape = self._infer_shape(size, dist_params)
-
-     shape_fg = FunctionGraph(
-         outputs=[as_tensor_variable(s, ndim=0) for s in shape],
-         features=[ShapeFeature()],
-         clone=True,
-     )
-     folded_shape = optimize_graph(
-         shape_fg, custom_opt=topo_constant_folding
-     ).outputs
-
-     return [getattr(s, "data", s) == 1 for s in folded_shape]

def infer_shape(self, fgraph, node, input_shapes):
_, size, _, *dist_params = node.inputs
_, size_shape, _, *param_shapes = input_shapes
@@ -362,7 +335,8 @@ def make_node(self, rng, size, dtype, *dist_params):
"The type of rng should be an instance of either RandomGeneratorType or RandomStateType"
)

- bcast = self.compute_bcast(dist_params, size)
+ shape = self._infer_shape(size, dist_params)
+ _, bcast = infer_broadcastable(shape)
dtype = self.dtype or dtype

if dtype == "floatX":
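Note: with `compute_bcast` removed, `RandomVariable.make_node` now reuses `_infer_shape` and lets `infer_broadcastable` constant-fold the result, so callers read the pattern off the output variable instead of a separate method. The sketch below is not part of the diff and assumes the stock `normal` RandomVariable from aesara.tensor.random.basic:

import numpy as np
import aesara.tensor as aet
from aesara.tensor.random.basic import normal

# A size entry that constant-folds to 1 marks that output dimension as
# broadcastable; the remaining dimensions stay non-broadcastable.
size = aet.as_tensor((1, 2, 3), dtype=np.int64)
x = normal(0, 1, size=size)
assert x.broadcastable == (True, False, False)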
tests/tensor/random/test_op.py (8 changes: 4 additions & 4 deletions)
@@ -129,12 +129,12 @@ def test_RandomVariable_bcast():
s3.tag.test_value = 3
s3 = Assert("testing")(s3, eq(s1, 1))

- res = rv.compute_bcast([mu, sd], (s1, s2, s3))
- assert res == [False] * 3
+ res = rv(mu, sd, size=(s1, s2, s3))
+ assert res.broadcastable == (False,) * 3

size = aet.as_tensor((1, 2, 3), dtype=np.int32).astype(np.int64)
- res = rv.compute_bcast([mu, sd], size)
- assert res == [True, False, False]
+ res = rv(mu, sd, size=size)
+ assert res.broadcastable == (True, False, False)

res = rv(0, 1, size=aet.as_tensor(1, dtype=np.int64))
assert res.broadcastable == (True,)
tests/tensor/test_basic.py (21 changes: 20 additions & 1 deletion)
@@ -59,6 +59,7 @@
get_scalar_constant_value,
get_vector_length,
horizontal_stack,
+ infer_broadcastable,
inverse_permutation,
join,
make_vector,
@@ -90,7 +91,7 @@
from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.math import dense_dot, eq
from aesara.tensor.math import sum as aet_sum
- from aesara.tensor.shape import Reshape, Shape, Shape_i, shape_padright
+ from aesara.tensor.shape import Reshape, Shape, Shape_i, shape_padright, specify_shape
from aesara.tensor.type import (
TensorType,
bvector,
@@ -658,6 +659,24 @@ def test_full(self):
assert np.array_equal(res, np.full((2, 3), 3, dtype="int64"))


+ def test_infer_broadcastable():
+     with pytest.raises(TypeError, match="^Shapes must be scalar integers.*"):
+         infer_broadcastable([constant(1.0)])
+
+     with config.change_flags(exception_verbosity="high"), pytest.raises(
+         TypeError, match=r"A\. x"
+     ):
+         infer_broadcastable([dscalar("x")])
+
+     with pytest.raises(ValueError, match=".*could not be cast to have 0 dimensions"):
+         infer_broadcastable((as_tensor_variable([[1, 2]]),))
+
+     constant_size = constant([1])
+     specify_size = specify_shape(constant_size, [1])
+     sh, bcast = infer_broadcastable(specify_size)
+     assert bcast == (True,)


# This is slow for the ('int8', 3) version.
def test_eye():
def check(dtype, N, M_=None, k=0):