From 8cbe18e9d281d00e236db1c006dd72a578da3de0 Mon Sep 17 00:00:00 2001
From: Ricardo
Date: Mon, 25 Apr 2022 14:48:10 +0200
Subject: [PATCH] Allow broadcasting in Elemwise c_code

This removes an inconsistency between NumPy and Aesara broadcasting rules,
where a variable dimension with unknown shape was always assumed to be
non-broadcastable (i.e., different from 1).
---
 aesara/scalar/basic.py         |   2 +
 aesara/tensor/elemwise.py      |  23 ++++---
 aesara/tensor/elemwise_cgen.py | 122 ++++++++++++++++++++++++---------
 tests/tensor/test_elemwise.py  |  86 +++++++++++++----------
 4 files changed, 155 insertions(+), 78 deletions(-)

diff --git a/aesara/scalar/basic.py b/aesara/scalar/basic.py
index 0bf12a70a3..36a05c3871 100644
--- a/aesara/scalar/basic.py
+++ b/aesara/scalar/basic.py
@@ -2307,6 +2307,8 @@ def c_code_contiguous(self, node, name, inputs, outputs, sub):
         if (
             node.inputs[0].type == node.outputs[0].type
             and node.inputs[1].type == node.outputs[0].type
+            and None not in node.inputs[0].type.shape
+            and None not in node.inputs[1].type.shape
             and
             # amdlibm 3.0 do not have a float64 version of this SIMD function
             node.inputs[0].dtype == "float32"
diff --git a/aesara/tensor/elemwise.py b/aesara/tensor/elemwise.py
index 8228c43a28..10fea56e4c 100644
--- a/aesara/tensor/elemwise.py
+++ b/aesara/tensor/elemwise.py
@@ -913,7 +913,7 @@ def _c_all(self, node, nodename, inames, onames, sub):
         checks = cgen.make_checks(orders, idtypes, sub)

         # Check if all inputs (except broadcasted scalar) are fortran.
-        # In that case, create an fortran output ndarray.
+        # In that case, create a fortran output ndarray.
         z = list(zip(inames, inputs))
         alloc_fortran = " && ".join(
             [
@@ -1071,7 +1071,7 @@ def _c_all(self, node, nodename, inames, onames, sub):

         # If all inputs and outputs are contiguous
         # and the scalar op define optimized code for that case
-        # use it! The scalar_op need to check the broadcast flag himself.
+        # use it! The scalar_op needs to check the type shapes itself.
         if (
             all(o.ndim >= 1 for o in node.outputs)
             and
@@ -1088,11 +1088,18 @@ def _c_all(self, node, nodename, inames, onames, sub):
             # compiler to vectorize the code as their won't be as
             # many ptr and the stride will be hard coded.
             if all(
-                [
-                    io.broadcastable == node.outputs[0].broadcastable
-                    or all(io.broadcastable)
-                    for io in node.inputs + node.outputs
-                ]
+                # io.type.shape == node.outputs[1].type.shape
+                # Elemwise does not specify non-broadcastable shape for the outputs yet
+                node.outputs[0].type.is_super(io.type)
+                for io in node.inputs + node.outputs
+            ) and (
+                len(node.inputs) <= 1
+                # If either one of the inputs has a `None` shape, we cannot
+                # assume they will have the same size
+                or all(
+                    len(set(inp_shape)) == 1 and None not in inp_shape
+                    for inp_shape in zip(*(inp.type.shape for inp in node.inputs))
+                )
             ):
                 z = onames[0]
                 contig = f"""
@@ -1188,7 +1195,7 @@ def c_support_code_apply(self, node, nodename):
         return support_code

     def c_code_cache_version_apply(self, node):
-        version = [13]  # the version corresponding to the c code in this Op
+        version = [14]  # the version corresponding to the c code in this Op

         # now we insert versions for the ops on which we depend...
         scalar_node = Apply(
diff --git a/aesara/tensor/elemwise_cgen.py b/aesara/tensor/elemwise_cgen.py
index cfc5abde8c..e4e6c35f70 100644
--- a/aesara/tensor/elemwise_cgen.py
+++ b/aesara/tensor/elemwise_cgen.py
@@ -66,10 +66,12 @@ def make_checks(loop_orders, dtypes, sub):
             if index != "x":
                 # Initialize the variables associated to the jth loop
                 # jump = stride - adjust
+                # If the variable has size 1 in that dim, we set the stride to zero to
+                # emulate broadcasting
                 jump = f"({var}_stride{index}) - ({adjust})"
                 init += f"""
                 {var}_n{index} = PyArray_DIMS({var})[{index}];
-                {var}_stride{index} = PyArray_STRIDES({var})[{index}] / sizeof({dtype});
+                {var}_stride{index} = ({var}_n{index} == 1)? 0 : PyArray_STRIDES({var})[{index}] / sizeof({dtype});
                 {var}_jump{index}_{j} = {jump};
                 """
                 adjust = f"{var}_n{index}*{var}_stride{index}"
@@ -90,22 +92,40 @@ def make_checks(loop_orders, dtypes, sub):
        # elements of to_compare are pairs ( input_variable_idx, input_variable_dim_idx )
        if len(to_compare) < 2:
            continue

-        j0, x0 = to_compare[0]
-        for (j, x) in to_compare[1:]:
-            check += f"""
-            if (%(lv{j0})s_n{x0} != %(lv{j})s_n{x})
+
+        # Find first dimension size that is != 1
+        jl, xl = to_compare[-1]
+        non1size_dim_check = f"""
+        npy_intp non1size_dim{xl};
+        non1size_dim{xl} = """
+        for (j, x) in to_compare[:-1]:
+            non1size_dim_check += f"(%(lv{j})s_n{x} != 1) ? %(lv{j})s_n{x} : "
+        non1size_dim_check += f"%(lv{jl})s_n{xl};"
+        check += non1size_dim_check
+
+        # Check that the non-size-1 dims match
+        # TODO: This is a bit inefficient because we are comparing one dimension against itself
+        check += f"""
+        if (non1size_dim{xl} != 1)
        {{
-            PyErr_Format(PyExc_ValueError, "Input dimension mismatch. (input[%%i].shape[%%i] = %%lld, input[%%i].shape[%%i] = %%lld)",
-                {j0},
-                {x0},
-                (long long int) %(lv{j0})s_n{x0},
-                {j},
-                {x},
-                (long long int) %(lv{j})s_n{x}
-            );
-            %(fail)s
+        """
+        for (j, x) in to_compare:
+            check += f"""
+            if ((%(lv{j})s_n{x} != non1size_dim{x}) && (%(lv{j})s_n{x} != 1))
+            {{
+                PyErr_Format(PyExc_ValueError, "Input dimension mismatch. One other input has shape[%%i] = %%lld, but input[%%i].shape[%%i] = %%lld.",
+                    {x},
+                    (long long int) non1size_dim{x},
+                    {j},
+                    {x},
+                    (long long int) %(lv{j})s_n{x}
+                );
+                %(fail)s
+            }}
+            """
+        check += f"""
        }}
-        """
+        """  # noqa: F541

    return init % sub + check % sub
@@ -125,20 +145,41 @@ def make_alloc(loop_orders, dtype, sub, fortran="0"):
     if type.startswith("AESARA_COMPLEX"):
         type = type.replace("AESARA_COMPLEX", "NPY_COMPLEX")
     nd = len(loop_orders[0])
-    init_dims = ""
     # For each dimension, the tensors are either all broadcasted, in
     # which case the output will also be broadcastable (dimension =
     # 1), or one or more are not broadcasted, in which case the number
     # of elements of the output in that dimension will be equal to the
     # number of elements of any of them.
+    # TODO: We can decide to either specialize C code even further given the input types
+    # Or make it general, regardless of whether static broadcastable information is given
+    init_dims = ""
     for i, candidates in enumerate(zip(*loop_orders)):
-        for j, candidate in enumerate(candidates):
-            if candidate != "x":
-                var = sub[f"lv{int(j)}"]
-                init_dims += f"dims[{i}] = {var}_n{candidate};\n"
-                break
-        else:
+        # TODO: Are candidates always either "x" or "i"? If that's the case we can
+        # simplify some logic here. We don't need to track the `idx`.
+        nonx_candidates = tuple(
+            (idx, c) for idx, c in enumerate(candidates) if c != "x"
+        )
+
+        # All inputs are known to be broadcastable
+        if not nonx_candidates:
             init_dims += f"dims[{i}] = 1;\n"
+            continue
+
+        # There is only one informative source of size
+        if len(nonx_candidates) == 1:
+            idx, candidate = nonx_candidates[0]
+            var = sub[f"lv{int(idx)}"]
+            init_dims += f"dims[{i}] = {var}_n{candidate};\n"
+            continue
+
+        # In this case any non-size-1 variable will define the right size
+        init_dims += f"dims[{i}] = "
+        for (idx, candidate) in nonx_candidates[:-1]:
+            var = sub[f"lv{int(idx)}"]
+            init_dims += f"({var}_n{candidate} != 1)? {var}_n{candidate}: "
+        idx, candidate = nonx_candidates[-1]
+        var = sub[f"lv{idx}"]
+        init_dims += f"{var}_n{candidate};\n"

     # TODO: it would be interesting to allocate the output in such a
     # way that its contiguous dimensions match one of the input's
@@ -316,20 +357,33 @@ def make_reordered_loop(
     # more are not broadcasted, in which case the number of elements
     # of any of them will be equal to the number of iterations we have
     # to do.
-    totals = []
+    # TODO: This considers the outputs' dimensions; should those ever matter?
+    declare_totals = f"int init_totals[{nnested}];\n"
     for i, candidates in enumerate(zip(*init_loop_orders)):
-        for j, candidate in enumerate(candidates):
-            if candidate != "x":
-                var = sub[f"lv{int(j)}"]
-                total = f"{var}_n{candidate}"
-                break
-        else:
-            total = "1"
-        totals.append(total)
+        nonx_candidates = tuple(
+            (idx, c) for idx, c in enumerate(candidates) if c != "x"
+        )

-    declare_totals = f"""
-    int init_totals[{nnested}] = {{{", ".join(totals)}}};
-    """
+        # All inputs are known to be broadcastable
+        if not nonx_candidates:
+            declare_totals += f"init_totals[{i}] = 1;\n"
+            continue
+
+        # There is only one informative source of size
+        if len(nonx_candidates) == 1:
+            idx, candidate = nonx_candidates[0]
+            var = sub[f"lv{int(idx)}"]
+            declare_totals += f"init_totals[{i}] = {var}_n{candidate};\n"
+            continue
+
+        # In this case any non-size-1 variable will define the right size
+        declare_totals += f"init_totals[{i}] = "
+        for (idx, candidate) in nonx_candidates[:-1]:
+            var = sub[f"lv{int(idx)}"]
+            declare_totals += f"({var}_n{candidate} != 1)? {var}_n{candidate}: "
+        idx, candidate = nonx_candidates[-1]
+        var = sub[f"lv{idx}"]
+        declare_totals += f"{var}_n{candidate};\n"

     # Sort totals to match the new order that was computed by sorting
     # the loop vector. One integer variable per loop is declared.
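
Note on the stride trick used in make_checks above: giving a size-1 dimension a
stride of 0 makes every index along that dimension read back the same element,
which is exactly how NumPy implements broadcasting without copying data. A
minimal NumPy sketch of the idea (illustrative only, not part of this patch;
the shapes and values are made up):

    import numpy as np
    from numpy.lib.stride_tricks import as_strided

    x = np.array([[1.0, 2.0, 3.0]])  # shape (1, 3)

    # Zero stride on the size-1 dimension, as the generated C code does when
    # it sees {var}_n{index} == 1: all 4 rows alias the same 3 elements.
    xb = as_strided(x, shape=(4, 3), strides=(0, x.strides[1]))

    assert np.array_equal(xb, np.broadcast_to(x, (4, 3)))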
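The dimension-resolution chains emitted by make_alloc and make_reordered_loop
both reduce to the same rule: take the first candidate size that is not 1, and
fall back to the last candidate when all are 1. A hypothetical Python analogue
(the helper resolve_dim is invented for illustration; mismatch detection is
handled separately by make_checks):

    def resolve_dim(sizes):
        # Mirrors the generated C chain: (n0 != 1)? n0 : (n1 != 1)? n1 : ... : n_last
        for n in sizes[:-1]:
            if n != 1:
                return n
        return sizes[-1]

    assert resolve_dim([1, 5, 1]) == 5  # a non-size-1 candidate wins
    assert resolve_dim([1, 1, 1]) == 1  # all broadcastable -> output dim is 1
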
diff --git a/tests/tensor/test_elemwise.py b/tests/tensor/test_elemwise.py
index 56a7beafee..a4b73f100f 100644
--- a/tests/tensor/test_elemwise.py
+++ b/tests/tensor/test_elemwise.py
@@ -200,40 +200,58 @@ def rand_cval(self, shp):
         return np.asarray(np.random.random(shp), dtype=aesara.config.floatX)

     def with_linker(self, linker, op, type, rand_val):
-        for xsh, ysh in [
-            ((3, 5), (3, 5)),
-            ((3, 5), (1, 5)),
-            ((3, 5), (3, 1)),
-            ((1, 5), (5, 1)),
-            ((1, 1), (1, 1)),
-            ((self.openmp_minsize,), (self.openmp_minsize,)),
-            (
-                (self.openmp_minsize_sqrt, self.openmp_minsize_sqrt),
-                (self.openmp_minsize_sqrt, self.openmp_minsize_sqrt),
-            ),
-            ((2, 3, 4, 5), (2, 3, 4, 5)),
-            ((2, 3, 4, 5), (1, 3, 1, 5)),
-            ((2, 3, 4, 5), (1, 1, 1, 1)),
-            ((), ()),
-        ]:
-            x = type(aesara.config.floatX, [(entry == 1) for entry in xsh])("x")
-            y = type(aesara.config.floatX, [(entry == 1) for entry in ysh])("y")
-            e = op(aes.add)(x, y)
-            f = make_function(copy(linker).accept(FunctionGraph([x, y], [e])))
-            xv = rand_val(xsh)
-            yv = rand_val(ysh)
-            zv = xv + yv
+        for shape_info in ("complete", "only_broadcastable", "none"):
+            for xsh, ysh in [
+                ((3, 5), (3, 5)),
+                ((3, 5), (1, 5)),
+                ((3, 5), (3, 1)),
+                ((1, 5), (5, 1)),
+                ((1, 1), (1, 1)),
+                ((self.openmp_minsize,), (self.openmp_minsize,)),
+                (
+                    (self.openmp_minsize_sqrt, self.openmp_minsize_sqrt),
+                    (self.openmp_minsize_sqrt, self.openmp_minsize_sqrt),
+                ),
+                ((2, 3, 4, 5), (2, 3, 4, 5)),
+                ((2, 3, 4, 5), (1, 3, 1, 5)),
+                ((2, 3, 4, 5), (1, 1, 1, 1)),
+                ((), ()),
+            ]:
+                if shape_info == "complete":
+                    x = type(aesara.config.floatX, shape=xsh)("x")
+                    y = type(aesara.config.floatX, shape=ysh)("y")
+                elif shape_info == "only_broadcastable":
+                    # This condition is here for backwards compatibility, from when the
+                    # only type shape provided by Aesara was broadcastable/non-broadcastable
+                    x = type(
+                        aesara.config.floatX,
+                        broadcastable=[(entry == 1) for entry in xsh],
+                    )("x")
+                    y = type(
+                        aesara.config.floatX,
+                        broadcastable=[(entry == 1) for entry in ysh],
+                    )("y")
+                else:
+                    x = type(aesara.config.floatX, shape=[None for _ in xsh])("x")
+                    y = type(aesara.config.floatX, shape=[None for _ in ysh])("y")
+                e = op(aes.add)(x, y)
+                f = make_function(copy(linker).accept(FunctionGraph([x, y], [e])))
+                xv = rand_val(xsh)
+                yv = rand_val(ysh)
+                zv = xv + yv

-            unittest_tools.assert_allclose(f(xv, yv), zv)
+                unittest_tools.assert_allclose(f(xv, yv), zv)

-            # test Elemwise.infer_shape
-            # the Shape op don't implement c_code!
-            if isinstance(linker, PerformLinker):
-                x = type(aesara.config.floatX, [(entry == 1) for entry in xsh])("x")
-                y = type(aesara.config.floatX, [(entry == 1) for entry in ysh])("y")
-                e = op(aes.add)(x, y)
-                f = make_function(copy(linker).accept(FunctionGraph([x, y], [e.shape])))
-                assert tuple(f(xv, yv)) == tuple(zv.shape)
+                # test Elemwise.infer_shape
+                # the Shape op doesn't implement c_code!
+                if isinstance(linker, PerformLinker):
+                    x = type(aesara.config.floatX, [(entry == 1) for entry in xsh])("x")
+                    y = type(aesara.config.floatX, [(entry == 1) for entry in ysh])("y")
+                    e = op(aes.add)(x, y)
+                    f = make_function(
+                        copy(linker).accept(FunctionGraph([x, y], [e.shape]))
+                    )
+                    assert tuple(f(xv, yv)) == tuple(zv.shape)

     def with_linker_inplace(self, linker, op, type, rand_val):
         for xsh, ysh in [
@@ -740,10 +758,6 @@ def check_input_dimensions_match(self, mode):
     def test_input_dimensions_match_python(self):
         self.check_input_dimensions_match(Mode(linker="py"))

-    @pytest.mark.xfail(
-        reason="Elemwise C implementation does not broadcast parameters",
-        exception=ValueError,
-    )
     @pytest.mark.skipif(
         not aesara.config.cxx, reason="G++ not available, so we need to skip this test."
     )
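
For context, the user-visible effect of the patch: with the C backend, an
Elemwise between variables whose static shapes are unknown now broadcasts at
runtime instead of raising, matching the Python implementation and NumPy. A
hedged sketch of the new behavior (assumes a build with this patch applied;
at.matrix creates a variable with static shape (None, None)):

    import aesara
    import aesara.tensor as at
    import numpy as np

    x = at.matrix("x")  # static shape (None, None)
    y = at.matrix("y")
    f = aesara.function([x, y], x + y)

    # Before this patch the C Elemwise code assumed an unknown dim was never 1
    # and raised a ValueError here; now it follows NumPy broadcasting rules.
    out = f(np.ones((1, 5)), np.ones((3, 1)))
    assert out.shape == (3, 5)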