
Commit

Remove BroadcastTo
ricardoV94 committed Aug 7, 2023
1 parent 5f809cf commit 6898f74
Showing 8 changed files with 14 additions and 541 deletions.
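In short, the user-facing broadcast_to helper survives, but it is now lowered to an Alloc node instead of the dedicated BroadcastTo Op. A minimal sketch of the post-commit behaviour (variable names are illustrative; Alloc and broadcast_to are the objects touched in the diff below):

import pytensor.tensor as pt
from pytensor.tensor.basic import Alloc
from pytensor.tensor.extra_ops import broadcast_to

x = pt.vector("x")
y = broadcast_to(x, (3, 2))
# With this commit applied, the graph should hold an Alloc node, not BroadcastTo
assert isinstance(y.owner.op, Alloc)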
14 changes: 0 additions & 14 deletions pytensor/link/jax/dispatch/extra_ops.py
@@ -3,10 +3,8 @@
import jax.numpy as jnp

from pytensor.link.jax.dispatch.basic import jax_funcify
from pytensor.tensor.basic import infer_static_shape
from pytensor.tensor.extra_ops import (
Bartlett,
BroadcastTo,
CumOp,
FillDiagonal,
FillDiagonalOffset,
@@ -102,18 +100,6 @@ def ravelmultiindex(*inp, mode=mode, order=order):
return ravelmultiindex


@jax_funcify.register(BroadcastTo)
def jax_funcify_BroadcastTo(op, node, **kwargs):
shape = node.inputs[1:]
static_shape = infer_static_shape(shape)[1]

def broadcast_to(x, *shape):
shape = tuple(st if st is not None else s for s, st in zip(shape, static_shape))
return jnp.broadcast_to(x, shape)

return broadcast_to


@jax_funcify.register(FillDiagonal)
def jax_funcify_FillDiagonal(op, **kwargs):
def filldiagonal(value, diagonal):
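For reference, the dispatcher removed above was a thin wrapper around jnp.broadcast_to; a self-contained sketch of the underlying call it made (array values chosen arbitrarily):

import jax.numpy as jnp
import numpy as np

x = np.arange(2.0)                 # shape (2,)
out = jnp.broadcast_to(x, (3, 2))  # the call the removed dispatcher forwarded to
assert out.shape == (3, 2)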
25 changes: 0 additions & 25 deletions pytensor/link/numba/dispatch/extra_ops.py
@@ -2,15 +2,13 @@

import numba
import numpy as np
from numba.misc.special import literal_unroll

from pytensor import config
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import get_numba_type, numba_funcify
from pytensor.raise_op import CheckAndRaise
from pytensor.tensor.extra_ops import (
Bartlett,
BroadcastTo,
CumOp,
FillDiagonal,
FillDiagonalOffset,
@@ -353,29 +351,6 @@ def searchsorted(a, v):
return searchsorted


@numba_funcify.register(BroadcastTo)
def numba_funcify_BroadcastTo(op, node, **kwargs):
create_zeros_tuple = numba_basic.create_tuple_creator(
lambda _: 0, len(node.inputs) - 1
)

# TODO broadcastable checks
@numba_basic.numba_njit
def broadcast_to(x, *shape):
scalars_shape = create_zeros_tuple()

i = 0
for s_i in literal_unroll(shape):
scalars_shape = numba_basic.tuple_setitem(
scalars_shape, i, numba_basic.to_scalar(s_i)
)
i += 1

return np.broadcast_to(x, scalars_shape)

return broadcast_to


@numba_funcify.register(CheckAndRaise)
def numba_funcify_CheckAndRaise(op, node, **kwargs):
error = op.exc_type
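Likewise, the Numba dispatcher removed above only unpacked the scalar shape inputs before handing them to np.broadcast_to inside an njit-compiled function. A stripped-down sketch with a hard-coded target shape (illustrative only, and assuming a Numba version that implements np.broadcast_to, as the removed code already did):

import numpy as np
from numba import njit

@njit
def broadcast_to_3x2(x):
    # np.broadcast_to is what the removed implementation called under the hood
    return np.broadcast_to(x, (3, 2))

assert broadcast_to_3x2(np.arange(2.0)).shape == (3, 2)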
145 changes: 2 additions & 143 deletions pytensor/tensor/extra_ops.py
@@ -23,7 +23,7 @@
from pytensor.scalar import upcast
from pytensor.tensor import as_tensor_variable
from pytensor.tensor import basic as at
from pytensor.tensor.basic import get_vector_length, second
from pytensor.tensor.basic import alloc, second
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import abs as pt_abs
from pytensor.tensor.math import all as pt_all
@@ -1584,141 +1584,6 @@ def broadcast_shape_iter(
return tuple(result_dims)


class BroadcastTo(COp):
"""An `Op` for `numpy.broadcast_to`."""

_output_type_depends_on_input_value = True

__props__ = ()

view_map = {0: [0]}

def __call__(self, a, shape, **kwargs):
return super().__call__(a, *shape, **kwargs)

def make_node(self, a, *shape):
a = at.as_tensor_variable(a)

shape, static_shape = at.infer_static_shape(shape)

if len(shape) < a.ndim:
raise ValueError(
f"Broadcast target shape has {len(shape)} dims, which is shorter than input with {a.ndim} dims"
)

out = TensorType(dtype=a.type.dtype, shape=static_shape)()

# Attempt to prevent in-place operations on this view-based output
out.tag.indestructible = True

return Apply(self, [a] + shape, [out])

def perform(self, node, inputs, output_storage):
a, *shape = inputs
z = output_storage[0]
z[0] = np.broadcast_to(a, shape)

def grad(self, inputs, outputs_gradients):
a, *shape = inputs
(dout,) = outputs_gradients

# Determine the dimensions that were added by broadcasting
new_dims = list(range(dout.ndim - a.ndim))

d_wrt_a = broadcast_to(dout, shape).sum(axis=new_dims)

# Determine the dimensions that were broadcast
_, static_shape = at.infer_static_shape(shape)

# TODO: This needs to be performed at run-time when static shape
# information isn't available.
bcast_sums = [
i
for i, (a_s, s_s) in enumerate(zip(a.type.shape, static_shape[-a.ndim :]))
if a_s == 1 and s_s != 1
]

if bcast_sums:
d_wrt_a = d_wrt_a.sum(axis=bcast_sums, keepdims=True)

return [d_wrt_a] + [
grad_undefined(self, i, shp) for i, shp in enumerate(shape, 1)
]

def infer_shape(self, fgraph, node, ins_shapes):
return [node.inputs[1:]]

def c_code(self, node, name, inputs, outputs, sub):
inp_dims = node.inputs[0].ndim
out_dims = node.outputs[0].ndim
new_dims = out_dims - inp_dims

(x, *shape) = inputs
(out,) = outputs
fail = sub["fail"]

# TODO: Could just use `PyArray_Return`, no?
dims_array = ", ".join(
[
f"((dtype_{shape}*)(PyArray_DATA({shape})))[0]"
for i, shape in enumerate(shape)
]
)

src = (
"""
npy_intp itershape[%(out_dims)s] = {%(dims_array)s};
NpyIter *iter;
PyArrayObject *ops[1] = {%(x)s};
npy_uint32 flags = NPY_ITER_MULTI_INDEX | NPY_ITER_REFS_OK | NPY_ITER_ZEROSIZE_OK;
npy_uint32 op_flags[1] = {NPY_ITER_READONLY};
PyArray_Descr *op_dtypes[1] = {NULL};
int oa_ndim = %(out_dims)s;
int* op_axes[1] = {NULL};
npy_intp buffersize = 0;
for(int i = 0; i < %(inp_dims)s; i++)
{
if ((PyArray_DIMS(%(x)s)[i] != 1) && (PyArray_DIMS(%(x)s)[i] != itershape[i + %(new_dims)s]))
{
PyErr_Format(PyExc_ValueError,
"Shape mismatch in broadcast_to: target shape[%%i] = %%lld is incompatible with input shape = %%lld.",
i,
(long long int) itershape[i + %(new_dims)s],
(long long int) PyArray_DIMS(%(x)s)[i]
);
%(fail)s
}
}
iter = NpyIter_AdvancedNew(
1, ops, flags, NPY_CORDER, NPY_NO_CASTING, op_flags, op_dtypes, oa_ndim, op_axes, itershape, buffersize
);
%(out)s = NpyIter_GetIterView(iter, 0);
if(%(out)s == NULL){
NpyIter_Deallocate(iter);
%(fail)s;
}
if (NpyIter_Deallocate(iter) != NPY_SUCCEED) {
%(fail)s;
}
"""
% locals()
)

return src

def c_code_cache_version(self):
return (2,)


broadcast_to_ = BroadcastTo()


def geomspace(start, end, steps, base=10.0):
from pytensor.tensor.math import log

@@ -1762,13 +1627,7 @@ def broadcast_to(
broadcasted array may refer to a single memory location.
"""
x = at.as_tensor(x)
shape_len = get_vector_length(shape)

if x.ndim == 0 and shape_len == 0:
return x

return broadcast_to_(x, shape)
return alloc(x, *shape)


def broadcast_arrays(*args: TensorVariable) -> Tuple[TensorVariable, ...]:
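The rewritten helper is now a one-liner over alloc, so broadcasting semantics come from Alloc rather than the view-producing Op deleted above. A quick numerical check of the new behaviour against NumPy (a sketch; dvector is used to sidestep floatX casting):

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.extra_ops import broadcast_to

x = pt.dvector("x")
f = pytensor.function([x], broadcast_to(x, (3, 2)))
np.testing.assert_array_equal(
    f(np.arange(2.0)), np.broadcast_to(np.arange(2.0), (3, 2))
)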
48 changes: 1 addition & 47 deletions pytensor/tensor/rewriting/extra_ops.py
@@ -2,7 +2,7 @@
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.basic import Alloc, as_tensor_variable
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.extra_ops import BroadcastTo, Repeat, Unique
from pytensor.tensor.extra_ops import Repeat, Unique
from pytensor.tensor.rewriting.basic import register_canonicalize, register_useless


@@ -60,39 +60,6 @@ def local_Unique_Alloc_lift(fgraph, node):
return [new_x]


@register_useless
@register_canonicalize
@node_rewriter([Unique])
def local_Unique_BroadcastTo_lift(fgraph, node):
"""Convert ``unique(broadcast_to(x, ...), axis=None)`` to ``unique(x, axis=None)``.
This isn't really so much a lift as a "reduction/consumption".
"""
if not isinstance(node.op, Unique):
return False

if (
node.op.return_index
or node.op.return_inverse
or node.op.return_counts
or node.op.axis is not None
):
return False

bcast_var = node.inputs[0]

if not (bcast_var.owner and isinstance(bcast_var.owner.op, BroadcastTo)):
return False

bcasted_var, *bcast_shape = bcast_var.owner.inputs

new_unique, *_ = node.op.make_node(bcasted_var).outputs

old_out = node.outputs[0]
new_x = as_tensor_variable(new_unique, ndim=old_out.ndim, dtype=old_out.dtype)
return [new_x]


@register_useless
@register_canonicalize
@node_rewriter([Unique])
@@ -161,16 +128,3 @@ def local_Unique_second(fgraph, node):
old_out = node.outputs[0]
new_x = as_tensor_variable(new_unique, ndim=old_out.ndim, dtype=old_out.dtype)
return [new_x]


@register_useless
@register_canonicalize
@node_rewriter([BroadcastTo])
def local_remove_scalar_BroadcastTo(fgraph, node):
bcast_shape = node.inputs[1:]

if not bcast_shape:
bcasted_var = node.inputs[0]
# If this isn't true, the graph is invalid
assert bcasted_var.ndim == 0
return [bcasted_var]
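The removed local_Unique_BroadcastTo_lift rewrite relied on the fact that broadcasting only repeats existing values, so a flattened unique is unchanged; a plain-NumPy illustration of that invariant:

import numpy as np

x = np.array([1.0, 2.0])
# Broadcasting repeats values without introducing new ones, so
# unique(broadcast_to(x, ...), axis=None) == unique(x, axis=None)
assert np.array_equal(np.unique(np.broadcast_to(x, (3, 2))), np.unique(x))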
25 changes: 1 addition & 24 deletions tests/link/jax/test_extra_ops.py
@@ -7,7 +7,7 @@
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import get_test_value
from pytensor.tensor import extra_ops as at_extra_ops
from pytensor.tensor.type import matrix, vector
from pytensor.tensor.type import matrix
from tests.link.jax.test_basic import compare_jax_and_py


@@ -63,29 +63,6 @@ def test_extra_ops():
)


@pytest.mark.parametrize(
"x, shape",
[
(
set_test_value(
vector("x"), np.random.random(size=(2,)).astype(config.floatX)
),
[at.as_tensor(3, dtype=np.int64), at.as_tensor(2, dtype=np.int64)],
),
(
set_test_value(
vector("x"), np.random.random(size=(2,)).astype(config.floatX)
),
[at.as_tensor(3, dtype=np.int8), at.as_tensor(2, dtype=np.int64)],
),
],
)
def test_BroadcastTo(x, shape):
out = at_extra_ops.broadcast_to(x, shape)
fgraph = FunctionGraph(outputs=[out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])


@pytest.mark.xfail(
version_parse(jax.__version__) >= version_parse("0.2.12"),
reason="Omnistaging cannot be disabled",
35 changes: 0 additions & 35 deletions tests/link/numba/test_extra_ops.py
@@ -36,41 +36,6 @@ def test_Bartlett(val):
)


@pytest.mark.parametrize(
"x, shape",
[
(
set_test_value(at.vector(), rng.random(size=(2,)).astype(config.floatX)),
[set_test_value(at.lscalar(), np.array(v)) for v in [3, 2]],
),
(
set_test_value(at.vector(), rng.random(size=(2,)).astype(config.floatX)),
[at.as_tensor(3, dtype=np.int64), at.as_tensor(2, dtype=np.int64)],
),
(
set_test_value(at.vector(), rng.random(size=(2,)).astype(config.floatX)),
at.as_tensor([set_test_value(at.lscalar(), np.array(v)) for v in [3, 2]]),
),
(
set_test_value(at.vector(), rng.random(size=(2,)).astype(config.floatX)),
[at.as_tensor(3, dtype=np.int8), at.as_tensor(2, dtype=np.int64)],
),
],
)
def test_BroadcastTo(x, shape):
g = extra_ops.BroadcastTo()(x, shape)
g_fg = FunctionGraph(outputs=[g])

compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, (SharedVariable, Constant))
],
)


@pytest.mark.parametrize(
"val, axis, mode",
[