Prevent broadcast_to from creating useless Ops
brandonwillard committed Jan 11, 2022
1 parent 8af9aa2 commit 3edbbc4
Showing 3 changed files with 40 additions and 12 deletions.
35 changes: 33 additions & 2 deletions aesara/tensor/extra_ops.py
@@ -10,7 +10,7 @@
    disconnected_type,
    grad_undefined,
)
from aesara.graph.basic import Apply, equal_computations
from aesara.graph.basic import Apply, Variable, equal_computations
from aesara.graph.op import COp, Op
from aesara.graph.params_type import ParamsType
from aesara.graph.type import EnumList, Generic
@@ -19,6 +19,7 @@
from aesara.scalar import int32 as int_t
from aesara.scalar import upcast
from aesara.tensor import basic as at
from aesara.tensor import get_vector_length
from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.math import abs as at_abs
from aesara.tensor.math import all as at_all
@@ -1627,7 +1628,37 @@ def infer_shape(self, fgraph, node, ins_shapes):
        return [node.inputs[1:]]


broadcast_to = BroadcastTo()
broadcast_to_ = BroadcastTo()


def broadcast_to(
    x: TensorVariable, shape: Union[TensorVariable, Tuple[Variable, ...]]
) -> TensorVariable:
    """Broadcast an array to a new shape.

    Parameters
    ----------
    x
        The array to broadcast.
    shape
        The shape of the desired array.

    Returns
    -------
    broadcast
        A readonly view on the original array with the given shape. It is
        typically not contiguous. Furthermore, more than one element of a
        broadcasted array may refer to a single memory location.

    """
    x = at.as_tensor(x)
    shape = at.as_tensor(shape, ndim=1, dtype="int64")
    shape_len = get_vector_length(shape)

    if x.ndim == 0 and shape_len == 0:
        return x

    return broadcast_to_(x, shape)


def broadcast_arrays(*args: TensorVariable) -> Tuple[TensorVariable, ...]:
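For context, a minimal sketch (assumed usage, not part of the diff) of how the new broadcast_to helper is expected to behave: a zero-dimensional input broadcast to an empty shape is returned unchanged, while any non-trivial target shape still goes through the BroadcastTo Op.

# Sketch of the intended behavior after this change (assumed usage, not from
# the diff itself): no BroadcastTo node is created when there is nothing to
# broadcast.
import aesara.tensor as at
from aesara.tensor.extra_ops import BroadcastTo, broadcast_to

x = at.scalar("x")

# 0-d input, empty target shape: the helper returns the input itself.
assert broadcast_to(x, ()) is x

# Non-trivial target shape: the BroadcastTo Op is still applied.
z = broadcast_to(x, (2, 3))
assert isinstance(z.owner.op, BroadcastTo)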
12 changes: 2 additions & 10 deletions tests/tensor/test_basic_opt.py
@@ -51,14 +51,7 @@
    register_specialize,
)
from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.extra_ops import (
    BroadcastTo,
    Repeat,
    Unique,
    broadcast_to,
    repeat,
    unique,
)
from aesara.tensor.extra_ops import BroadcastTo, Repeat, Unique, repeat, unique
from aesara.tensor.math import (
    add,
    bitwise_and,
@@ -3359,7 +3352,6 @@ def test_local_Unique_Alloc_lift(
@pytest.mark.parametrize(
"x_val, axis, new_shape",
[
(np.array(-10, dtype=np.int64), None, ()),
(np.array(-10, dtype=np.int64), None, (2, 3)),
(np.array([[-10, -3], [-10, 2], [-10, 2]], dtype=np.int64), None, (2, 3, 2)),
],
@@ -3372,7 +3364,7 @@ def test_local_Unique_BroadcastTo(
):
    x = as_tensor_variable(x_val).type()
    y = unique(
        broadcast_to(x, tuple(new_shape)),
        BroadcastTo()(x, tuple(new_shape)),
        return_index=return_index,
        return_counts=return_counts,
        return_inverse=return_inverse,
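A note on the test change above, with a hedged snippet: after this commit the broadcast_to helper no longer builds a BroadcastTo node for a 0-d input and an empty target shape, which is why that parametrization was dropped and the remaining cases call the Op directly, guaranteeing a node exists for local_Unique_BroadcastTo to act on.

# Assumed illustration (not from the diff) of why the scalar/empty-shape case
# was removed from the parametrization: the helper now short-circuits it,
# leaving no BroadcastTo node in the graph to rewrite.
import numpy as np
from aesara.tensor import as_tensor_variable
from aesara.tensor.extra_ops import broadcast_to

x = as_tensor_variable(np.array(-10, dtype=np.int64)).type()
assert broadcast_to(x, ()) is x  # no BroadcastTo node is created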
5 changes: 5 additions & 0 deletions tests/tensor/test_extra_ops.py
@@ -1095,6 +1095,11 @@ def setup_method(self):
        self.op_class = BroadcastTo
        self.op = broadcast_to

    def test_avoid_useless_scalars(self):
        x = scalar()
        y = broadcast_to(x, ())
        assert y is x

    @config.change_flags(compute_test_value="raise")
    def test_perform(self):
        a = scalar()
