From 3c1d876079ffe3f17a0d17a669141ed8a6419110 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 4 Jul 2023 17:02:40 +0200 Subject: [PATCH] Forbid runtime broadcasting in Elemwise --- pytensor/link/jax/dispatch/elemwise.py | 10 ++- pytensor/tensor/__init__.py | 1 + pytensor/tensor/elemwise.py | 53 +++++++++---- pytensor/tensor/elemwise_cgen.py | 101 ++++++++++--------------- tests/link/jax/test_elemwise.py | 6 ++ tests/tensor/test_elemwise.py | 45 ++++++----- 6 files changed, 121 insertions(+), 95 deletions(-) diff --git a/pytensor/link/jax/dispatch/elemwise.py b/pytensor/link/jax/dispatch/elemwise.py index 39ef836b6a..8d2936608d 100644 --- a/pytensor/link/jax/dispatch/elemwise.py +++ b/pytensor/link/jax/dispatch/elemwise.py @@ -7,9 +7,15 @@ @jax_funcify.register(Elemwise) -def jax_funcify_Elemwise(op, **kwargs): +def jax_funcify_Elemwise(op, node, **kwargs): scalar_op = op.scalar_op - return jax_funcify(scalar_op, **kwargs) + base_fn = jax_funcify(scalar_op, node, **kwargs) + + def elemwise_fn(*inputs): + Elemwise._check_runtime_broadcast(node, inputs) + return base_fn(*inputs) + + return elemwise_fn @jax_funcify.register(CAReduce) diff --git a/pytensor/tensor/__init__.py b/pytensor/tensor/__init__.py index 2e84db13e9..dfe74b5b2f 100644 --- a/pytensor/tensor/__init__.py +++ b/pytensor/tensor/__init__.py @@ -132,6 +132,7 @@ def _get_vector_length_Constant(op: Union[Op, Variable], var: Constant) -> int: shape_padaxis, shape_padleft, shape_padright, + specify_broadcastable, specify_shape, ) diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py index 6d19579030..e05e5fb687 100644 --- a/pytensor/tensor/elemwise.py +++ b/pytensor/tensor/elemwise.py @@ -6,7 +6,7 @@ import pytensor.tensor.basic from pytensor.configdefaults import config from pytensor.gradient import DisconnectedType -from pytensor.graph.basic import Apply +from pytensor.graph.basic import Apply, Constant from pytensor.graph.null_type import NullType from pytensor.graph.utils import MethodNotDefined from pytensor.link.c.basic import failure_code @@ -19,9 +19,9 @@ from pytensor.scalar.basic import bool as scalar_bool from pytensor.scalar.basic import identity as scalar_identity from pytensor.scalar.basic import transfer_type, upcast -from pytensor.tensor import _get_vector_length, as_tensor_variable from pytensor.tensor import elemwise_cgen as cgen from pytensor.tensor import get_vector_length +from pytensor.tensor.basic import _get_vector_length, as_tensor_variable from pytensor.tensor.type import ( TensorType, continuous_dtypes, @@ -737,9 +737,7 @@ def perform(self, node, inputs, output_storage): # FIXME: This no longer calls the C implementation! 
         super().perform(node, inputs, output_storage)
 
-        for d, dim_shapes in enumerate(zip(*(i.shape for i in inputs))):
-            if len(set(dim_shapes) - {1}) > 1:
-                raise ValueError(f"Shapes on dimension {d} do not match: {dim_shapes}")
+        self._check_runtime_broadcast(node, inputs)
 
         ufunc_args = inputs
         ufunc_kwargs = {}
@@ -815,18 +813,40 @@ def perform(self, node, inputs, output_storage):
             else:
                 storage[0] = variable
 
-    def infer_shape(self, fgraph, node, i_shapes) -> List[Tuple[TensorVariable, ...]]:
-        if len(node.outputs) > 1:
-            from pytensor.tensor.exceptions import ShapeError
-
-            raise ShapeError(
-                "Multiple outputs are not supported by the default `Elemwise.infer_shape`"
-            )
+    @staticmethod
+    def _check_runtime_broadcast(node, inputs):
+        for dims_and_bcast in zip(
+            *[
+                zip(input.shape, sinput.type.broadcastable)
+                for input, sinput in zip(inputs, node.inputs)
+            ]
+        ):
+            if any(d != 1 for d, _ in dims_and_bcast) and (1, False) in dims_and_bcast:
+                raise ValueError(
+                    "Runtime broadcasting not allowed. "
+                    "At least one input has a distinct dimension length of 1, but was not marked as broadcastable.\n"
+                    "If broadcasting was intended, use `specify_broadcastable` on the relevant input."
+                )
 
-        out_shape = pytensor.tensor.broadcast_shape(*i_shapes, arrays_are_shapes=True)
+    def infer_shape(self, fgraph, node, i_shapes) -> List[Tuple[TensorVariable, ...]]:
+        one = pytensor.tensor.basic.constant(1, dtype="int64")
+        output_shape = []
+        for dim, broadcastable in enumerate(node.outputs[0].type.broadcastable):
+            out_dim_length = one
+            if not broadcastable:
+                # There must be some input that is not broadcastable in this dim
+                for inp_shape, inp_var in zip(i_shapes, node.inputs):
+                    if not inp_var.type.broadcastable[dim]:
+                        # Give preference to constant dims
+                        if isinstance(inp_shape[dim], Constant):
+                            out_dim_length = inp_shape[dim]
+                            break
+                        # If we haven't yet seen a non-broadcastable dim, use this one
+                        if out_dim_length is one:
+                            out_dim_length = inp_shape[dim]
+            output_shape.append(as_tensor_variable(out_dim_length, dtype="int64"))
 
-        # The `as_tensor_variable` should convert `ScalarType`s to `TensorType`s
-        return [tuple(as_tensor_variable(s) for s in out_shape)]
+        return [tuple(output_shape)] * len(node.outputs)
 
     def _c_all(self, node, nodename, inames, onames, sub):
         # Some `Op`s directly call `Elemwise._c_all` or `Elemwise.c_code`
@@ -1190,7 +1211,7 @@ def c_support_code_apply(self, node, nodename):
         return support_code
 
     def c_code_cache_version_apply(self, node):
-        version = [14]  # the version corresponding to the c code in this Op
+        version = [15]  # the version corresponding to the c code in this Op
 
         # now we insert versions for the ops on which we depend...
         scalar_node = Apply(
diff --git a/pytensor/tensor/elemwise_cgen.py b/pytensor/tensor/elemwise_cgen.py
index 18106f082d..5edfa884e8 100644
--- a/pytensor/tensor/elemwise_cgen.py
+++ b/pytensor/tensor/elemwise_cgen.py
@@ -66,12 +66,10 @@ def make_checks(loop_orders, dtypes, sub):
         if index != "x":
             # Initialize the variables associated to the jth loop
             # jump = stride - adjust
-            # If the variable has size 1 in that dim, we set the stride to zero to
-            # emulate broadcasting
             jump = f"({var}_stride{index}) - ({adjust})"
             init += f"""
             {var}_n{index} = PyArray_DIMS({var})[{index}];
-            {var}_stride{index} = ({var}_n{index} == 1)? 
0 : PyArray_STRIDES({var})[{index}] / sizeof({dtype}); + {var}_stride{index} = PyArray_STRIDES({var})[{index}] / sizeof({dtype}); {var}_jump{index}_{j} = {jump}; """ adjust = f"{var}_n{index}*{var}_stride{index}" @@ -86,6 +84,14 @@ def make_checks(loop_orders, dtypes, sub): # This loop builds multiple if conditions to verify that the # dimensions of the inputs match, and the first one that is true # raises an informative error message + + runtime_broadcast_error_msg = ( + "Runtime broadcasting not allowed. " + "One input had a distinct dimension length of 1, but was not marked as broadcastable: " + "(input[%%i].shape[%%i] = %%lld, input[%%i].shape[%%i] = %%lld). " + "If broadcasting was intended, use `specify_broadcastable` on the relevant input." + ) + for matches in zip(*loop_orders): to_compare = [(j, x) for j, x in enumerate(matches) if x != "x"] @@ -93,81 +99,58 @@ def make_checks(loop_orders, dtypes, sub): if len(to_compare) < 2: continue - # Find first dimension size that is != 1 - jl, xl = to_compare[-1] - non1size_dim_check = f""" - npy_intp non1size_dim{xl}; - non1size_dim{xl} = """ - for j, x in to_compare[:-1]: - non1size_dim_check += f"(%(lv{j})s_n{x} != 1) ? %(lv{j})s_n{x} : " - non1size_dim_check += f"%(lv{jl})s_n{xl};" - check += non1size_dim_check - - # Check the nonsize1 dims match - # TODO: This is a bit inefficient because we are comparing one dimension against itself - check += f""" - if (non1size_dim{xl} != 1) - {{ - """ - for j, x in to_compare: + j0, x0 = to_compare[0] + for j, x in to_compare[1:]: check += f""" - if ((%(lv{j})s_n{x} != non1size_dim{x}) && (%(lv{j})s_n{x} != 1)) + if (%(lv{j0})s_n{x0} != %(lv{j})s_n{x}) + {{ + if (%(lv{j0})s_n{x0} == 1 || %(lv{j})s_n{x} == 1) {{ - PyErr_Format(PyExc_ValueError, "Input dimension mismatch. One other input has shape[%%i] = %%lld, but input[%%i].shape[%%i] = %%lld.", - {x}, - (long long int) non1size_dim{x}, + PyErr_Format(PyExc_ValueError, "{runtime_broadcast_error_msg}", + {j0}, + {x0}, + (long long int) %(lv{j0})s_n{x0}, + {j}, + {x}, + (long long int) %(lv{j})s_n{x} + ); + }} else {{ + PyErr_Format(PyExc_ValueError, "Input dimension mismatch: (input[%%i].shape[%%i] = %%lld, input[%%i].shape[%%i] = %%lld)", + {j0}, + {x0}, + (long long int) %(lv{j0})s_n{x0}, {j}, {x}, (long long int) %(lv{j})s_n{x} ); - %(fail)s }} - """ - check += """ - } + %(fail)s + }} """ return init % sub + check % sub -def compute_broadcast_dimensions(array_name: str, loop_orders, sub) -> str: - """Create c_code to compute broadcasted dimensions of multiple arrays, arising from - Elemwise operations. +def compute_outputs_dims(array_name: str, loop_orders, sub) -> str: + """Create c_code to compute the output dimensions of an Elemwise operation. The code returned by this function populates the array `array_name`, but does not initialize it. - TODO: We can decide to either specialize C code even further given the input types - or make it general, regardless of whether static broadcastable information is given + Note: We could specialize C code even further with the known static output shapes """ dims_c_code = "" for i, candidates in enumerate(zip(*loop_orders)): - # TODO: Are candidates always either "x" or "i"? If that's the case we can - # simplify some logic here (e.g., we don't need to track the `idx`). 
- nonx_candidates = tuple( - (idx, c) for idx, c in enumerate(candidates) if c != "x" - ) - - # All inputs are known to be broadcastable - if not nonx_candidates: + # Borrow the length of the first non-broadcastable input dimension + for j, candidate in enumerate(candidates): + if candidate != "x": + var = sub[f"lv{int(j)}"] + dims_c_code += f"{array_name}[{i}] = {var}_n{candidate};\n" + break + # If none is non-broadcastable, the output dimension has a length of 1 + else: # no-break dims_c_code += f"{array_name}[{i}] = 1;\n" - continue - - # There is only one informative source of size - if len(nonx_candidates) == 1: - idx, candidate = nonx_candidates[0] - var = sub[f"lv{int(idx)}"] - dims_c_code += f"{array_name}[{i}] = {var}_n{candidate};\n" - continue - # In this case any non-size 1 variable will define the right size - dims_c_code += f"{array_name}[{i}] = " - for idx, candidate in nonx_candidates[:-1]: - var = sub[f"lv{int(idx)}"] - dims_c_code += f"({var}_n{candidate} != 1)? {var}_n{candidate}: " - idx, candidate = nonx_candidates[-1] - var = sub[f"lv{idx}"] - dims_c_code += f"{var}_n{candidate};\n" return dims_c_code @@ -186,7 +169,7 @@ def make_alloc(loop_orders, dtype, sub, fortran="0"): if type.startswith("PYTENSOR_COMPLEX"): type = type.replace("PYTENSOR_COMPLEX", "NPY_COMPLEX") nd = len(loop_orders[0]) - init_dims = compute_broadcast_dimensions("dims", loop_orders, sub) + init_dims = compute_outputs_dims("dims", loop_orders, sub) # TODO: it would be interesting to allocate the output in such a # way that its contiguous dimensions match one of the input's @@ -359,7 +342,7 @@ def make_reordered_loop( # Get the (sorted) total number of iterations of each loop declare_totals = f"int init_totals[{nnested}];\n" - declare_totals += compute_broadcast_dimensions("init_totals", init_loop_orders, sub) + declare_totals += compute_outputs_dims("init_totals", init_loop_orders, sub) # Sort totals to match the new order that was computed by sorting # the loop vector. One integer variable per loop is declared. 
diff --git a/tests/link/jax/test_elemwise.py b/tests/link/jax/test_elemwise.py
index 0f903a33b2..89dc04d6d4 100644
--- a/tests/link/jax/test_elemwise.py
+++ b/tests/link/jax/test_elemwise.py
@@ -4,6 +4,7 @@
 
 import pytensor
 import pytensor.tensor as at
+from pytensor.compile import get_mode
 from pytensor.configdefaults import config
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import get_test_value
@@ -14,6 +15,11 @@
 from pytensor.tensor.special import SoftmaxGrad, log_softmax, softmax
 from pytensor.tensor.type import matrix, tensor, vector
 from tests.link.jax.test_basic import compare_jax_and_py
+from tests.tensor.test_elemwise import TestElemwise
+
+
+def test_elemwise_runtime_broadcast_error():
+    TestElemwise.check_runtime_broadcast_error(get_mode("JAX"))
 
 
 def test_jax_Dimshuffle():
diff --git a/tests/tensor/test_elemwise.py b/tests/tensor/test_elemwise.py
index 40e7db879c..9ef822dfaf 100644
--- a/tests/tensor/test_elemwise.py
+++ b/tests/tensor/test_elemwise.py
@@ -18,7 +18,6 @@
 from pytensor.tensor import as_tensor_variable
 from pytensor.tensor.basic import second
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
-from pytensor.tensor.exceptions import ShapeError
 from pytensor.tensor.math import all as at_all
 from pytensor.tensor.math import any as at_any
 from pytensor.tensor.math import exp
@@ -769,10 +768,9 @@ def test_input_dimensions_overflow(self):
         g = pytensor.function([a, b, c, d, e, f], s, mode=Mode(linker="py"))
         g(*[np.zeros(2**11, config.floatX) for i in range(6)])
 
-    def check_input_dimensions_match(self, mode):
-        """Make sure that our input validation works correctly and doesn't
-        throw erroneous broadcast-based errors.
-        """
+    @staticmethod
+    def check_runtime_broadcast_error(mode):
+        """Check that we emit a clear error when runtime broadcasting would occur according to NumPy rules."""
         x_v = matrix("x")
         m_v = vector("m")
 
@@ -782,19 +780,18 @@ def check_input_dimensions_match(self, mode):
         z_v = x_v - m_v
         f = pytensor.function([x_v, m_v], z_v, mode=mode)
 
-        res = f(x, m)
+        with pytest.raises(ValueError, match="Runtime broadcasting not allowed"):
+            f(x, m)
 
-        assert np.array_equal(res, x - m)
-
-    def test_input_dimensions_match_python(self):
-        self.check_input_dimensions_match(Mode(linker="py"))
+    def test_runtime_broadcast_error_python(self):
+        self.check_runtime_broadcast_error(Mode(linker="py"))
 
     @pytest.mark.skipif(
         not pytensor.config.cxx,
         reason="G++ not available, so we need to skip this test.",
     )
-    def test_input_dimensions_match_c(self):
-        self.check_input_dimensions_match(Mode(linker="c"))
+    def test_runtime_broadcast_error_c(self):
+        self.check_runtime_broadcast_error(Mode(linker="c"))
 
     def test_str(self):
         op = Elemwise(aes.add, inplace_pattern={0: 0}, name=None)
@@ -819,7 +816,7 @@ def test_partial_static_shape_info(self):
         assert pytensor.get_underlying_scalar_constant(res_shape[0][0]) == 1
         assert pytensor.get_underlying_scalar_constant(res_shape[0][1]) == 1
 
-    def test_multi_output(self):
+    def test_infer_shape_multi_output(self):
         class CustomElemwise(Elemwise):
             def make_node(self, *args):
                 res = super().make_node(*args)
@@ -833,14 +830,26 @@ def make_node(self, *args):
             ],
         )
 
-        z_1, z_2 = CustomElemwise(aes.add)(
-            as_tensor_variable(np.eye(1)), as_tensor_variable(np.eye(1))
-        )
+        custom_elemwise = CustomElemwise(aes.add)
+        z_1, z_2 = custom_elemwise(
+            as_tensor_variable(np.eye(1)),
+            as_tensor_variable(np.eye(1)),
+        )
         in_1_shape = (aes.constant(1), aes.constant(1))
+        outs = z_1.owner.op.infer_shape(None, z_1.owner, [in_1_shape, 
in_1_shape]) + for out in outs: + assert out[0].eval() == 1 + assert out[1].eval() == 1 - with pytest.raises(ShapeError): - z_1.owner.op.infer_shape(None, z_1.owner, [in_1_shape, in_1_shape]) + z_1, z_2 = custom_elemwise( + as_tensor_variable(np.eye(1)), as_tensor_variable(np.eye(3)) + ) + in_2_shape = (aes.constant(3), aes.constant(3)) + outs = z_1.owner.op.infer_shape(None, z_1.owner, [in_1_shape, in_2_shape]) + for out in outs: + assert out[0].eval() == 3 + assert out[1].eval() == 3 def test_shape_types(self): x = tensor(dtype=np.float64, shape=(None, 1))
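
Usage note (not part of the diff above): a minimal sketch of the behavioral change, using only the public `pytensor.function`, `pytensor.tensor.matrix`/`vector`, and `specify_broadcastable` APIs; the variable names are illustrative. After this patch, Elemwise only broadcasts dimensions that are statically typed as broadcastable; a runtime length-1 dimension on an input typed as non-broadcastable raises instead of silently broadcasting.

import numpy as np
import pytensor
import pytensor.tensor as at

x = at.matrix("x")  # static shape (None, None): columns not known to broadcast
m = at.vector("m")  # static shape (None,): length not known to broadcast

f = pytensor.function([x, m], x - m)

# Previously this silently broadcast the length-1 vector across the 3 columns of x.
# With this patch it raises:
#   ValueError: Runtime broadcasting not allowed. ...
try:
    f(np.ones((2, 3)), np.ones((1,)))
except ValueError as exc:
    print(exc)

# Broadcasting still works when it is encoded in the type,
# e.g. via specify_broadcastable, as the error message suggests.
m1 = at.specify_broadcastable(m, 0)
g = pytensor.function([x, m], x - m1)
print(g(np.ones((2, 3)), np.ones((1,))))  # fine: the broadcast is static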