diff --git a/pytensor/link/jax/dispatch/extra_ops.py b/pytensor/link/jax/dispatch/extra_ops.py
index bfce752434..a9e36667ef 100644
--- a/pytensor/link/jax/dispatch/extra_ops.py
+++ b/pytensor/link/jax/dispatch/extra_ops.py
@@ -3,10 +3,8 @@
 import jax.numpy as jnp
 
 from pytensor.link.jax.dispatch.basic import jax_funcify
-from pytensor.tensor.basic import infer_static_shape
 from pytensor.tensor.extra_ops import (
     Bartlett,
-    BroadcastTo,
     CumOp,
     FillDiagonal,
     FillDiagonalOffset,
@@ -102,18 +100,6 @@ def ravelmultiindex(*inp, mode=mode, order=order):
     return ravelmultiindex
 
 
-@jax_funcify.register(BroadcastTo)
-def jax_funcify_BroadcastTo(op, node, **kwargs):
-    shape = node.inputs[1:]
-    static_shape = infer_static_shape(shape)[1]
-
-    def broadcast_to(x, *shape):
-        shape = tuple(st if st is not None else s for s, st in zip(shape, static_shape))
-        return jnp.broadcast_to(x, shape)
-
-    return broadcast_to
-
-
 @jax_funcify.register(FillDiagonal)
 def jax_funcify_FillDiagonal(op, **kwargs):
     def filldiagonal(value, diagonal):
diff --git a/pytensor/link/numba/dispatch/extra_ops.py b/pytensor/link/numba/dispatch/extra_ops.py
index ce275fd031..a3a489deaa 100644
--- a/pytensor/link/numba/dispatch/extra_ops.py
+++ b/pytensor/link/numba/dispatch/extra_ops.py
@@ -2,7 +2,6 @@
 
 import numba
 import numpy as np
-from numba.misc.special import literal_unroll
 
 from pytensor import config
 from pytensor.link.numba.dispatch import basic as numba_basic
@@ -10,7 +9,6 @@
 from pytensor.raise_op import CheckAndRaise
 from pytensor.tensor.extra_ops import (
     Bartlett,
-    BroadcastTo,
     CumOp,
     FillDiagonal,
     FillDiagonalOffset,
@@ -353,29 +351,6 @@ def searchsorted(a, v):
     return searchsorted
 
 
-@numba_funcify.register(BroadcastTo)
-def numba_funcify_BroadcastTo(op, node, **kwargs):
-    create_zeros_tuple = numba_basic.create_tuple_creator(
-        lambda _: 0, len(node.inputs) - 1
-    )
-
-    # TODO broadcastable checks
-    @numba_basic.numba_njit
-    def broadcast_to(x, *shape):
-        scalars_shape = create_zeros_tuple()
-
-        i = 0
-        for s_i in literal_unroll(shape):
-            scalars_shape = numba_basic.tuple_setitem(
-                scalars_shape, i, numba_basic.to_scalar(s_i)
-            )
-            i += 1
-
-        return np.broadcast_to(x, scalars_shape)
-
-    return broadcast_to
-
-
 @numba_funcify.register(CheckAndRaise)
 def numba_funcify_CheckAndRaise(op, node, **kwargs):
     error = op.exc_type
diff --git a/pytensor/tensor/extra_ops.py b/pytensor/tensor/extra_ops.py
index 09e8bf5551..6a7e8d38bc 100644
--- a/pytensor/tensor/extra_ops.py
+++ b/pytensor/tensor/extra_ops.py
@@ -23,7 +23,7 @@
 from pytensor.scalar import upcast
 from pytensor.tensor import as_tensor_variable
 from pytensor.tensor import basic as at
-from pytensor.tensor.basic import get_vector_length, second
+from pytensor.tensor.basic import alloc, second
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.math import abs as pt_abs
 from pytensor.tensor.math import all as pt_all
@@ -1584,141 +1584,6 @@ def broadcast_shape_iter(
     return tuple(result_dims)
 
 
-class BroadcastTo(COp):
-    """An `Op` for `numpy.broadcast_to`."""
-
-    _output_type_depends_on_input_value = True
-
-    __props__ = ()
-
-    view_map = {0: [0]}
-
-    def __call__(self, a, shape, **kwargs):
-        return super().__call__(a, *shape, **kwargs)
-
-    def make_node(self, a, *shape):
-        a = at.as_tensor_variable(a)
-
-        shape, static_shape = at.infer_static_shape(shape)
-
-        if len(shape) < a.ndim:
-            raise ValueError(
-                f"Broadcast target shape has {len(shape)} dims, which is shorter than input with {a.ndim} dims"
-            )
-
-        out = TensorType(dtype=a.type.dtype, shape=static_shape)()
-
-        # Attempt to prevent in-place operations on this view-based output
-        out.tag.indestructible = True
-
-        return Apply(self, [a] + shape, [out])
-
-    def perform(self, node, inputs, output_storage):
-        a, *shape = inputs
-        z = output_storage[0]
-        z[0] = np.broadcast_to(a, shape)
-
-    def grad(self, inputs, outputs_gradients):
-        a, *shape = inputs
-        (dout,) = outputs_gradients
-
-        # Determine the dimensions that were added by broadcasting
-        new_dims = list(range(dout.ndim - a.ndim))
-
-        d_wrt_a = broadcast_to(dout, shape).sum(axis=new_dims)
-
-        # Determine the dimensions that were broadcast
-        _, static_shape = at.infer_static_shape(shape)
-
-        # TODO: This needs to be performed at run-time when static shape
-        # information isn't available.
-        bcast_sums = [
-            i
-            for i, (a_s, s_s) in enumerate(zip(a.type.shape, static_shape[-a.ndim :]))
-            if a_s == 1 and s_s != 1
-        ]
-
-        if bcast_sums:
-            d_wrt_a = d_wrt_a.sum(axis=bcast_sums, keepdims=True)
-
-        return [d_wrt_a] + [
-            grad_undefined(self, i, shp) for i, shp in enumerate(shape, 1)
-        ]
-
-    def infer_shape(self, fgraph, node, ins_shapes):
-        return [node.inputs[1:]]
-
-    def c_code(self, node, name, inputs, outputs, sub):
-        inp_dims = node.inputs[0].ndim
-        out_dims = node.outputs[0].ndim
-        new_dims = out_dims - inp_dims
-
-        (x, *shape) = inputs
-        (out,) = outputs
-        fail = sub["fail"]
-
-        # TODO: Could just use `PyArray_Return`, no?
-        dims_array = ", ".join(
-            [
-                f"((dtype_{shape}*)(PyArray_DATA({shape})))[0]"
-                for i, shape in enumerate(shape)
-            ]
-        )
-
-        src = (
-            """
-        npy_intp itershape[%(out_dims)s] = {%(dims_array)s};
-
-        NpyIter *iter;
-        PyArrayObject *ops[1] = {%(x)s};
-        npy_uint32 flags = NPY_ITER_MULTI_INDEX | NPY_ITER_REFS_OK | NPY_ITER_ZEROSIZE_OK;
-        npy_uint32 op_flags[1] = {NPY_ITER_READONLY};
-        PyArray_Descr *op_dtypes[1] = {NULL};
-        int oa_ndim = %(out_dims)s;
-        int* op_axes[1] = {NULL};
-        npy_intp buffersize = 0;
-
-        for(int i = 0; i < %(inp_dims)s; i++)
-        {
-            if ((PyArray_DIMS(%(x)s)[i] != 1) && (PyArray_DIMS(%(x)s)[i] != itershape[i + %(new_dims)s]))
-            {
-                PyErr_Format(PyExc_ValueError,
-                             "Shape mismatch in broadcast_to: target shape[%%i] = %%lld is incompatible with input shape = %%lld.",
-                             i,
-                             (long long int) itershape[i + %(new_dims)s],
-                             (long long int) PyArray_DIMS(%(x)s)[i]
-                );
-                %(fail)s
-            }
-        }
-
-        iter = NpyIter_AdvancedNew(
-            1, ops, flags, NPY_CORDER, NPY_NO_CASTING, op_flags, op_dtypes, oa_ndim, op_axes, itershape, buffersize
-        );
-        %(out)s = NpyIter_GetIterView(iter, 0);
-
-        if(%(out)s == NULL){
-            NpyIter_Deallocate(iter);
-            %(fail)s;
-        }
-
-        if (NpyIter_Deallocate(iter) != NPY_SUCCEED) {
-            %(fail)s;
-        }
-
-        """
-            % locals()
-        )
-
-        return src
-
-    def c_code_cache_version(self):
-        return (2,)
-
-
-broadcast_to_ = BroadcastTo()
-
-
 def geomspace(start, end, steps, base=10.0):
     from pytensor.tensor.math import log
 
@@ -1762,13 +1627,7 @@ def broadcast_to(
         broadcasted array may refer to a single memory location.
 
     """
-    x = at.as_tensor(x)
-    shape_len = get_vector_length(shape)
-
-    if x.ndim == 0 and shape_len == 0:
-        return x
-
-    return broadcast_to_(x, shape)
+    return alloc(x, *shape)
 
 
 def broadcast_arrays(*args: TensorVariable) -> Tuple[TensorVariable, ...]:
diff --git a/pytensor/tensor/rewriting/extra_ops.py b/pytensor/tensor/rewriting/extra_ops.py
index aa20334abc..945433f2a4 100644
--- a/pytensor/tensor/rewriting/extra_ops.py
+++ b/pytensor/tensor/rewriting/extra_ops.py
@@ -2,7 +2,7 @@
 from pytensor.graph.rewriting.basic import node_rewriter
 from pytensor.tensor.basic import Alloc, as_tensor_variable
 from pytensor.tensor.elemwise import Elemwise
-from pytensor.tensor.extra_ops import BroadcastTo, Repeat, Unique
+from pytensor.tensor.extra_ops import Repeat, Unique
 from pytensor.tensor.rewriting.basic import register_canonicalize, register_useless
 
 
@@ -60,39 +60,6 @@ def local_Unique_Alloc_lift(fgraph, node):
     return [new_x]
 
 
-@register_useless
-@register_canonicalize
-@node_rewriter([Unique])
-def local_Unique_BroadcastTo_lift(fgraph, node):
-    """Convert ``unique(broadcast_to(x, ...), axis=None)`` to ``unique(x, axis=None)``.
-
-    This isn't really so much a lift as a "reduction/consumption".
-    """
-    if not isinstance(node.op, Unique):
-        return False
-
-    if (
-        node.op.return_index
-        or node.op.return_inverse
-        or node.op.return_counts
-        or node.op.axis is not None
-    ):
-        return False
-
-    bcast_var = node.inputs[0]
-
-    if not (bcast_var.owner and isinstance(bcast_var.owner.op, BroadcastTo)):
-        return False
-
-    bcasted_var, *bcast_shape = bcast_var.owner.inputs
-
-    new_unique, *_ = node.op.make_node(bcasted_var).outputs
-
-    old_out = node.outputs[0]
-    new_x = as_tensor_variable(new_unique, ndim=old_out.ndim, dtype=old_out.dtype)
-    return [new_x]
-
-
 @register_useless
 @register_canonicalize
 @node_rewriter([Unique])
@@ -161,16 +128,3 @@ def local_Unique_second(fgraph, node):
     old_out = node.outputs[0]
     new_x = as_tensor_variable(new_unique, ndim=old_out.ndim, dtype=old_out.dtype)
     return [new_x]
-
-
-@register_useless
-@register_canonicalize
-@node_rewriter([BroadcastTo])
-def local_remove_scalar_BroadcastTo(fgraph, node):
-    bcast_shape = node.inputs[1:]
-
-    if not bcast_shape:
-        bcasted_var = node.inputs[0]
-        # If this isn't true, the graph is invalid
-        assert bcasted_var.ndim == 0
-        return [bcasted_var]
diff --git a/tests/link/jax/test_extra_ops.py b/tests/link/jax/test_extra_ops.py
index 73c6e4249c..78abd671b8 100644
--- a/tests/link/jax/test_extra_ops.py
+++ b/tests/link/jax/test_extra_ops.py
@@ -7,7 +7,7 @@
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import get_test_value
 from pytensor.tensor import extra_ops as at_extra_ops
-from pytensor.tensor.type import matrix, vector
+from pytensor.tensor.type import matrix
 from tests.link.jax.test_basic import compare_jax_and_py
 
 
@@ -63,29 +63,6 @@ def test_extra_ops():
     )
 
 
-@pytest.mark.parametrize(
-    "x, shape",
-    [
-        (
-            set_test_value(
-                vector("x"), np.random.random(size=(2,)).astype(config.floatX)
-            ),
-            [at.as_tensor(3, dtype=np.int64), at.as_tensor(2, dtype=np.int64)],
-        ),
-        (
-            set_test_value(
-                vector("x"), np.random.random(size=(2,)).astype(config.floatX)
-            ),
-            [at.as_tensor(3, dtype=np.int8), at.as_tensor(2, dtype=np.int64)],
-        ),
-    ],
-)
-def test_BroadcastTo(x, shape):
-    out = at_extra_ops.broadcast_to(x, shape)
-    fgraph = FunctionGraph(outputs=[out])
-    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
-
-
 @pytest.mark.xfail(
     version_parse(jax.__version__) >= version_parse("0.2.12"),
     reason="Omnistaging cannot be disabled",
diff --git a/tests/link/numba/test_extra_ops.py b/tests/link/numba/test_extra_ops.py
index 30b62ba225..36a67cfff0 100644
--- a/tests/link/numba/test_extra_ops.py
+++ b/tests/link/numba/test_extra_ops.py
@@ -36,41 +36,6 @@ def test_Bartlett(val):
     )
 
 
-@pytest.mark.parametrize(
-    "x, shape",
-    [
-        (
-            set_test_value(at.vector(), rng.random(size=(2,)).astype(config.floatX)),
-            [set_test_value(at.lscalar(), np.array(v)) for v in [3, 2]],
-        ),
-        (
-            set_test_value(at.vector(), rng.random(size=(2,)).astype(config.floatX)),
-            [at.as_tensor(3, dtype=np.int64), at.as_tensor(2, dtype=np.int64)],
-        ),
-        (
-            set_test_value(at.vector(), rng.random(size=(2,)).astype(config.floatX)),
-            at.as_tensor([set_test_value(at.lscalar(), np.array(v)) for v in [3, 2]]),
-        ),
-        (
-            set_test_value(at.vector(), rng.random(size=(2,)).astype(config.floatX)),
-            [at.as_tensor(3, dtype=np.int8), at.as_tensor(2, dtype=np.int64)],
-        ),
-    ],
-)
-def test_BroadcastTo(x, shape):
-    g = extra_ops.BroadcastTo()(x, shape)
-    g_fg = FunctionGraph(outputs=[g])
-
-    compare_numba_and_py(
-        g_fg,
-        [
-            i.tag.test_value
-            for i in g_fg.inputs
-            if not isinstance(i, (SharedVariable, Constant))
-        ],
-    )
-
-
 @pytest.mark.parametrize(
     "val, axis, mode",
     [
diff --git a/tests/tensor/rewriting/test_extra_ops.py b/tests/tensor/rewriting/test_extra_ops.py
index d0aac80249..15f5870e5b 100644
--- a/tests/tensor/rewriting/test_extra_ops.py
+++ b/tests/tensor/rewriting/test_extra_ops.py
@@ -8,7 +8,7 @@
 from pytensor.graph.rewriting.utils import rewrite_graph
 from pytensor.tensor.basic import Alloc, alloc, as_tensor_variable, second
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
-from pytensor.tensor.extra_ops import BroadcastTo, Repeat, Unique, repeat, unique
+from pytensor.tensor.extra_ops import Repeat, Unique, repeat, unique
 from pytensor.tensor.type import dscalar
 
 
@@ -103,64 +103,6 @@ def test_local_Unique_Alloc_lift(
     assert np.array_equal(y_exp_val, y_val)
 
 
-@pytest.mark.parametrize(
-    "x_val, axis, new_shape",
-    [
-        (np.array(-10, dtype=np.int64), None, (2, 3)),
-        (np.array([[-10, -3], [-10, 2], [-10, 2]], dtype=np.int64), None, (2, 3, 2)),
-    ],
-)
-@pytest.mark.parametrize("return_index", [False])
-@pytest.mark.parametrize("return_counts", [False])
-@pytest.mark.parametrize("return_inverse", [False])
-def test_local_Unique_BroadcastTo(
-    x_val, axis, new_shape, return_index, return_counts, return_inverse
-):
-    x = as_tensor_variable(x_val).type()
-    y = unique(
-        BroadcastTo()(x, tuple(new_shape)),
-        return_index=return_index,
-        return_counts=return_counts,
-        return_inverse=return_inverse,
-        axis=axis,
-    )
-
-    if isinstance(y, list):
-        y, *_ = y
-
-    # This approach allows us to directly confirm that `x` is in the result.
-    y_fg = FunctionGraph(outputs=[y], copy_inputs=False)
-    y_rewritten_fg = rewrite_graph(
-        y_fg,
-        clone=False,
-        include=["canonicalize", "local_Unique_BroadcastTo_lift"],
-        exclude=["local_Unique_scalar"],
-    )
-    y_rewritten = y_rewritten_fg.outputs[0]
-    y_rewritten_start = y_rewritten
-
-    assert isinstance(y_rewritten_start.owner.op, Unique)
-    assert y_rewritten_start.owner.inputs[0] == x
-    assert not any(
-        isinstance(node.op, BroadcastTo) for node in y_rewritten_fg.apply_nodes
-    )
-
-    default_mode = get_default_mode()
-    # The rewrite has already been applied to `y_rewritten`, so we can--and
-    # should--exclude it from the compilation of both our reference, `y`, and
-    # the rewritten result, `y_rewritten`.
-    rewrite_mode = default_mode.excluding("local_Unique_BroadcastTo_lift")
-    y_fn = function([x], [y, y_rewritten], mode=rewrite_mode)
-    # Make sure that the original `BroadcastTo` is used to compute the
-    # reference `y` result
-    assert any(
-        isinstance(node.op, BroadcastTo) for node in y_fn.maker.fgraph.apply_nodes
-    )
-
-    y_exp_val, y_val = y_fn(x_val)
-    assert np.array_equal(y_exp_val, y_val)
-
-
 @pytest.mark.parametrize(
     "x_val, unique_axis, repeats, repeat_axis",
     [
@@ -287,16 +229,3 @@ def test_local_Unique_second(
 
     y_exp_val, y_val = y_fn(x_val)
     assert np.array_equal(y_exp_val, y_val)
-
-
-def test_local_remove_scalar_BroadcastTo():
-    x = dscalar()
-    y = BroadcastTo()(x, ())
-
-    assert isinstance(y.owner.op, BroadcastTo)
-
-    res = rewrite_graph(
-        y, clone=False, include=["canonicalize", "local_remove_scalar_BroadcastTo"]
-    )
-
-    assert res is x
diff --git a/tests/tensor/test_extra_ops.py b/tests/tensor/test_extra_ops.py
index 4d2c3fec9e..e103567564 100644
--- a/tests/tensor/test_extra_ops.py
+++ b/tests/tensor/test_extra_ops.py
@@ -8,14 +8,12 @@
 from pytensor import tensor as at
 from pytensor.compile.mode import Mode
 from pytensor.configdefaults import config
-from pytensor.graph.basic import Constant, applys_between
-from pytensor.graph.replace import clone_replace
-from pytensor.graph.rewriting.db import RewriteDatabaseQuery
+from pytensor.graph.basic import Constant, applys_between, equal_computations
 from pytensor.raise_op import Assert
+from pytensor.tensor import alloc
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.extra_ops import (
     Bartlett,
-    BroadcastTo,
     CpuContiguous,
     CumOp,
     FillDiagonal,
@@ -47,7 +45,6 @@
     to_one_hot,
     unravel_index,
 )
-from pytensor.tensor.subtensor import AdvancedIncSubtensor
 from pytensor.tensor.type import (
     TensorType,
     dmatrix,
@@ -61,7 +58,6 @@
     lscalar,
     matrix,
     scalar,
-    tensor,
     tensor3,
     vector,
 )
@@ -1246,183 +1242,15 @@ def test_broadcast_shape_symbolic_one_symbolic():
     assert res_shape[2].data == 3
 
 
-class TestBroadcastTo(utt.InferShapeTester):
-    def setup_method(self):
-        super().setup_method()
-        self.op_class = BroadcastTo
-        self.op = broadcast_to
-
-    def test_avoid_useless_scalars(self):
-        x = scalar()
-        y = broadcast_to(x, ())
-        assert y is x
-
-    def test_avoid_useless_subtensors(self):
-        x = scalar()
-        y = broadcast_to(x, (1, 2))
-        # There shouldn't be any unnecessary `Subtensor` operations
-        # (e.g. from `at.as_tensor((1, 2))[0]`)
-        assert y.owner.inputs[1].owner is None
-        assert y.owner.inputs[2].owner is None
-
-    @pytest.mark.parametrize("linker", ["cvm", "py"])
-    def test_perform(self, linker):
-        a = pytensor.shared(np.full((3, 1, 1), 5))
-        s_0 = iscalar("s_0")
-        s_1 = iscalar("s_1")
-        shape = (s_0, s_1, 1)
-
-        bcast_res = broadcast_to(a, shape)
-        assert bcast_res.broadcastable == (False, False, True)
-
-        bcast_fn = pytensor.function(
-            [s_0, s_1], bcast_res, mode=Mode(optimizer=None, linker=linker)
-        )
-        bcast_fn.vm.allow_gc = False
-
-        bcast_at = bcast_fn(3, 4)
-        bcast_np = np.broadcast_to(5, (3, 4, 1))
-
-        assert np.array_equal(bcast_at, bcast_np)
-
-        with pytest.raises(ValueError):
-            bcast_fn(5, 4)
-
-        if linker != "py":
-            bcast_var = bcast_fn.maker.fgraph.outputs[0].owner.inputs[0]
-            bcast_in = bcast_fn.vm.storage_map[a]
-            bcast_out = bcast_fn.vm.storage_map[bcast_var]
-            assert np.shares_memory(bcast_out[0], bcast_in[0])
-
-    def test_make_node_error_handling(self):
-        with pytest.raises(
-            ValueError,
-            match="Broadcast target shape has 1 dims, which is shorter than input with 2 dims",
-        ):
-            broadcast_to(at.zeros((3, 4)), (5,))
+def test_broadcast_to():
+    x = vector("x")
+    y1 = scalar(dtype="int64")
+    y2 = scalar(dtype="int64")
 
-    @pytest.mark.skipif(
-        not config.cxx, reason="G++ not available, so we need to skip this test."
+    assert equal_computations(
+        [broadcast_to(x, (y1, y2))],
+        [alloc(x, y1, y2)],
     )
-    @pytest.mark.parametrize("valid", (True, False))
-    def test_memory_leak(self, valid):
-        import gc
-        import tracemalloc
-
-        from pytensor.link.c.cvm import CVM
-
-        n = 100_000
-        x = pytensor.shared(np.ones((1, n), dtype=np.float64))
-        y = broadcast_to(x, (5, n))
-
-        f = pytensor.function([], y, mode=Mode(optimizer=None, linker="cvm"))
-        assert isinstance(f.vm, CVM)
-
-        assert len(f.maker.fgraph.apply_nodes) == 2
-        assert any(
-            isinstance(node.op, BroadcastTo) for node in f.maker.fgraph.apply_nodes
-        )
-
-        tracemalloc.start()
-
-        blocks_last = None
-        block_diffs = []
-        for i in range(1, 50):
-            if valid:
-                x.set_value(np.ones((1, n)))
-                _ = f()
-            else:
-                x.set_value(np.ones((2, n)))
-                try:
-                    _ = f()
-                except ValueError:
-                    pass
-                else:
-                    raise RuntimeError("Should have failed")
-            _ = gc.collect()
-            blocks_i, _ = tracemalloc.get_traced_memory()
-            if blocks_last is not None:
-                blocks_diff = (blocks_i - blocks_last) // 10**3
-                block_diffs.append(blocks_diff)
-            blocks_last = blocks_i
-
-        tracemalloc.stop()
-        assert np.all(np.array(block_diffs) <= (0 + 1e-8))
-
-    @pytest.mark.parametrize(
-        "fn,input_dims",
-        [
-            [lambda x: broadcast_to(x, (1,)), (1,)],
-            [lambda x: broadcast_to(x, (6, 2, 5, 3)), (1,)],
-            [lambda x: broadcast_to(x, (6, 2, 5, 3)), (5, 1)],
-            [lambda x: broadcast_to(x, (6, 2, 1, 3)), (2, 1, 3)],
-        ],
-    )
-    def test_gradient(self, fn, input_dims):
-        rng = np.random.default_rng(43)
-        utt.verify_grad(
-            fn,
-            [rng.random(input_dims).astype(config.floatX)],
-            n_tests=1,
-            rng=rng,
-        )
-
-    def test_infer_shape(self):
-        rng = np.random.default_rng(43)
-        a = tensor(dtype=config.floatX, shape=(None, 1, None))
-        shape = list(a.shape)
-        out = self.op(a, shape)
-
-        self._compile_and_check(
-            [a] + shape,
-            [out],
-            [rng.random((2, 1, 3)).astype(config.floatX), 2, 1, 3],
-            self.op_class,
-        )
-
-        a = tensor(dtype=config.floatX, shape=(None, 1, None))
-        shape = [iscalar() for i in range(4)]
-        self._compile_and_check(
-            [a] + shape,
-            [self.op(a, shape)],
-            [rng.random((2, 1, 3)).astype(config.floatX), 6, 2, 5, 3],
-            self.op_class,
-        )
-
-    def test_inplace(self):
-        """Make sure that in-place optimizations are *not* performed on the output of a ``BroadcastTo``."""
-        a = at.zeros((5,))
-        d = at.vector("d")
-        c = at.set_subtensor(a[np.r_[0, 1, 3]], d)
-        b = broadcast_to(c, (5,))
-        q = b[np.r_[0, 1, 3]]
-        e = at.set_subtensor(q, np.r_[0, 0, 0])
-
-        opts = RewriteDatabaseQuery(include=["inplace"])
-        py_mode = Mode("py", opts)
-        e_fn = function([d], e, mode=py_mode)
-
-        advincsub_node = e_fn.maker.fgraph.outputs[0].owner
-        assert isinstance(advincsub_node.op, AdvancedIncSubtensor)
-        assert isinstance(advincsub_node.inputs[0].owner.op, BroadcastTo)
-
-        assert advincsub_node.op.inplace is False
-
-    def test_rebuild(self):
-        x = vector(shape=(50,))
-        x_test = np.zeros((50,), dtype=config.floatX)
-        i = 0
-        y = broadcast_to(i, x.shape)
-        assert y.type.shape == (50,)
-        assert y.shape.eval({x: x_test}) == (50,)
-        assert y.eval({x: x_test}).shape == (50,)
-
-        x_new = vector(shape=(100,))
-        x_new_test = np.zeros((100,), dtype=config.floatX)
-        y_new = clone_replace(y, {x: x_new}, rebuild_strict=False)
-        assert y_new.type.shape == (100,)
-        assert y_new.shape.eval({x_new: x_new_test}) == (100,)
-        assert y_new.eval({x_new: x_new_test}).shape == (100,)
 
 
 def test_broadcast_arrays():
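
The snippet below is not part of the patch; it is a minimal usage sketch of the behaviour exercised by the new `test_broadcast_to` above, assuming a build of this branch where `broadcast_to` is implemented as `alloc(x, *shape)`. The variable names (`x`, `rows`, `f`, `val`) are illustrative only.

# Hypothetical illustration (not from the patch): with this change,
# `broadcast_to` builds an `Alloc` node instead of a `BroadcastTo` node.
import numpy as np

import pytensor
from pytensor.tensor import broadcast_to, lscalar, vector
from pytensor.tensor.basic import Alloc

x = vector("x")
rows = lscalar("rows")

# Graph structure: the output is now produced by `Alloc`.
y = broadcast_to(x, (rows, x.shape[0]))
assert isinstance(y.owner.op, Alloc)

# The computed values still match `np.broadcast_to`.
f = pytensor.function([x, rows], y)
val = np.array([1.0, 2.0], dtype=pytensor.config.floatX)
assert np.array_equal(f(val, 3), np.broadcast_to(val, (3, 2)))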