Forbid runtime broadcasting in Elemwise #372

Merged
12 changes: 10 additions & 2 deletions pytensor/link/jax/dispatch/elemwise.py
@@ -7,9 +7,17 @@


@jax_funcify.register(Elemwise)
def jax_funcify_Elemwise(op, **kwargs):
def jax_funcify_Elemwise(op, node, **kwargs):
scalar_op = op.scalar_op
return jax_funcify(scalar_op, **kwargs)
base_fn = jax_funcify(scalar_op, node=node, **kwargs)

def elemwise_fn(*inputs):
# ScalarVariables in JAX are passed as int/float.
# We wrap them in arrays just for the broadcast check
Elemwise._check_runtime_broadcast(node, tuple(map(jnp.asarray, inputs)))
return base_fn(*inputs)

return elemwise_fn


@jax_funcify.register(CAReduce)
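For reference, a minimal sketch of how the new check surfaces through this JAX dispatch wrapper (assuming JAX is installed and the registered "JAX" mode; variable names are illustrative, not taken from this PR): a length-1 input that was not declared broadcastable in its type now raises a ValueError instead of being silently broadcast.

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")  # static shape (None,), not marked broadcastable
y = pt.vector("y")
fn = pytensor.function([x, y], x + y, mode="JAX")

fn(np.zeros(3), np.ones(3))        # OK: shapes match exactly
try:
    fn(np.zeros(3), np.ones(1))    # runtime broadcast is now rejected
except ValueError as err:
    print(err)                     # message points to `specify_broadcastable`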
1 change: 1 addition & 0 deletions pytensor/tensor/__init__.py
@@ -132,6 +132,7 @@ def _get_vector_length_Constant(op: Union[Op, Variable], var: Constant) -> int:
shape_padaxis,
shape_padleft,
shape_padright,
specify_broadcastable,
specify_shape,
)

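`specify_broadcastable` is re-exported at the `pytensor.tensor` level here because the new error messages point users to it. A minimal usage sketch, assuming the usual `specify_broadcastable(x, *axes)` signature (variable names are illustrative):

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")                      # static shape (None, None)
x_row = pt.specify_broadcastable(x, 0)  # declare dim 0 as length-1 / broadcastable
y = pt.matrix("y")

fn = pytensor.function([x, y], x_row + y)
# Broadcasting along dim 0 is now explicit in the type, so this is allowed:
fn(np.zeros((1, 3)), np.ones((4, 3)))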
36 changes: 21 additions & 15 deletions pytensor/tensor/elemwise.py
@@ -19,9 +19,9 @@
from pytensor.scalar.basic import bool as scalar_bool
from pytensor.scalar.basic import identity as scalar_identity
from pytensor.scalar.basic import transfer_type, upcast
from pytensor.tensor import _get_vector_length, as_tensor_variable
from pytensor.tensor import elemwise_cgen as cgen
from pytensor.tensor import get_vector_length
from pytensor.tensor.basic import _get_vector_length, as_tensor_variable
from pytensor.tensor.type import (
TensorType,
continuous_dtypes,
@@ -740,9 +740,7 @@ def perform(self, node, inputs, output_storage):
# FIXME: This no longer calls the C implementation!
super().perform(node, inputs, output_storage)

for d, dim_shapes in enumerate(zip(*(i.shape for i in inputs))):
if len(set(dim_shapes) - {1}) > 1:
raise ValueError(f"Shapes on dimension {d} do not match: {dim_shapes}")
self._check_runtime_broadcast(node, inputs)

ufunc_args = inputs
ufunc_kwargs = {}
@@ -818,18 +816,26 @@ def perform(self, node, inputs, output_storage):
else:
storage[0] = variable

def infer_shape(self, fgraph, node, i_shapes) -> List[Tuple[TensorVariable, ...]]:
if len(node.outputs) > 1:
from pytensor.tensor.exceptions import ShapeError

raise ShapeError(
"Multiple outputs are not supported by the default `Elemwise.infer_shape`"
)
@staticmethod
def _check_runtime_broadcast(node, inputs):
for dims_and_bcast in zip(
*[
zip(input.shape, sinput.type.broadcastable)
for input, sinput in zip(inputs, node.inputs)
]
):
if any(d != 1 for d, _ in dims_and_bcast) and (1, False) in dims_and_bcast:
raise ValueError(
"Runtime broadcasting not allowed. "
"At least one input has a distinct dimension length of 1, but was not marked as broadcastable.\n"
"If broadcasting was intended, use `specify_broadcastable` on the relevant input."
)

out_shape = pytensor.tensor.broadcast_shape(*i_shapes, arrays_are_shapes=True)
def infer_shape(self, fgraph, node, i_shapes) -> List[Tuple[TensorVariable, ...]]:
Review discussion on infer_shape:

Member:
I think this could just use this function: https://github.com/pymc-devs/pytensor/blob/main/pytensor/tensor/extra_ops.py#L1465

The make_node method doesn't seem to take the broadcastable flag into account properly either, though; maybe that needs an update as well?

ricardoV94 (Member, Author) commented on Jul 8, 2023:
I didn't want to introduce checks or comparisons between shapes, which that function does. This allows it to return a more optimized graph, as Theano used to, by assuming no invalid shapes were provided.

The question then is whether we want to refactor that helper to do the same when arrays_are_shapes=False?

ricardoV94 (Member, Author):
I think the make_node is correct insofar as it uses static shapes, and it is not possible to have broadcastable=False and shape=1.

That one still requires some thinking and would be tackled in a separate PR.

Member:
> I didn't want to introduce checks or comparisons between shapes, which that function does. This allows it to return a more optimized graph, as Theano used to, by assuming no invalid shapes were provided.

So we allow undefined behavior in the shapes and in rewrites? I'm not sure I see that much downside to having that check here...

But at least I think we shouldn't have this logic in both places. Maybe the function should take a flag for whether it should return the shape with or without checks?

ricardoV94 (Member, Author):
I am thinking we should add a config.assume_shapes_correct flag (defaulting to True) to toggle that behavior in both shape inference and in rewrites that can return simplified cases.

ricardoV94 (Member, Author) commented on Jul 11, 2023:
Actually, that helper works differently in that it expects either shapes or arrays, but here we are combining information from both shapes and arrays, so it would require some refactoring. We don't want to simply pass node.inputs, since infer_shape wants us to return a graph built from ishapes.

I don't know if that is the right place to implement this logic, since it is a user-facing function. WDYT?

ricardoV94 (Member, Author) commented on Jul 11, 2023:
Okay, I reverted to using the helper. Things are a bit weird in shape compilation, because it will just use the static type shape of the node if that is available. Because the Elemwise make_node assumes valid shapes, the check introduced by infer_shape is only triggered when all dims are None.

Member:
Not much we can do about that then, I think, without a major rewrite of the shape handling...

from pytensor.tensor.extra_ops import broadcast_shape

# The `as_tensor_variable` should convert `ScalarType`s to `TensorType`s
return [tuple(as_tensor_variable(s) for s in out_shape)]
out_shape = broadcast_shape(*i_shapes, arrays_are_shapes=True)
return [tuple(as_tensor_variable(s) for s in out_shape)] * len(node.outputs)

def _c_all(self, node, nodename, inames, onames, sub):
# Some `Op`s directly call `Elemwise._c_all` or `Elemwise.c_code`
@@ -1193,7 +1199,7 @@ def c_support_code_apply(self, node, nodename):
return support_code

def c_code_cache_version_apply(self, node):
version = [14] # the version corresponding to the c code in this Op
version = [15] # the version corresponding to the c code in this Op

# now we insert versions for the ops on which we depend...
scalar_node = Apply(
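To make the dense zip-of-zips in `_check_runtime_broadcast` easier to follow, here is a rough standalone restatement of the rule in plain Python. The function name and the shape/broadcastable tuples below are illustrative stand-ins for node inputs, not part of the PyTensor API:

def check_runtime_broadcast(runtime_shapes, broadcastable_flags):
    # For every dimension: if some input has length != 1 there, then any input
    # with runtime length 1 that was *not* declared broadcastable would need
    # runtime broadcasting, which is now rejected.
    for dim_info in zip(
        *[
            zip(shape, bcast)
            for shape, bcast in zip(runtime_shapes, broadcastable_flags)
        ]
    ):
        if any(length != 1 for length, _ in dim_info) and (1, False) in dim_info:
            raise ValueError(
                "Runtime broadcasting not allowed. "
                "Use `specify_broadcastable` on the relevant input."
            )

# Allowed: the length-1 dimension was declared broadcastable in the type
check_runtime_broadcast([(1, 3), (4, 3)], [(True, False), (False, False)])

# Rejected: length 1 at runtime, but the type says broadcastable=False
try:
    check_runtime_broadcast([(1, 3), (4, 3)], [(False, False), (False, False)])
except ValueError as err:
    print(err)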
101 changes: 42 additions & 59 deletions pytensor/tensor/elemwise_cgen.py
@@ -66,12 +66,10 @@ def make_checks(loop_orders, dtypes, sub):
if index != "x":
# Initialize the variables associated to the jth loop
# jump = stride - adjust
# If the variable has size 1 in that dim, we set the stride to zero to
# emulate broadcasting
jump = f"({var}_stride{index}) - ({adjust})"
init += f"""
{var}_n{index} = PyArray_DIMS({var})[{index}];
{var}_stride{index} = ({var}_n{index} == 1)? 0 : PyArray_STRIDES({var})[{index}] / sizeof({dtype});
{var}_stride{index} = PyArray_STRIDES({var})[{index}] / sizeof({dtype});
{var}_jump{index}_{j} = {jump};
"""
adjust = f"{var}_n{index}*{var}_stride{index}"
@@ -86,88 +84,73 @@
# This loop builds multiple if conditions to verify that the
# dimensions of the inputs match, and the first one that is true
# raises an informative error message

runtime_broadcast_error_msg = (
"Runtime broadcasting not allowed. "
"One input had a distinct dimension length of 1, but was not marked as broadcastable: "
"(input[%%i].shape[%%i] = %%lld, input[%%i].shape[%%i] = %%lld). "
"If broadcasting was intended, use `specify_broadcastable` on the relevant input."
)

for matches in zip(*loop_orders):
to_compare = [(j, x) for j, x in enumerate(matches) if x != "x"]

# elements of to_compare are pairs ( input_variable_idx, input_variable_dim_idx )
if len(to_compare) < 2:
continue

# Find first dimension size that is != 1
jl, xl = to_compare[-1]
non1size_dim_check = f"""
npy_intp non1size_dim{xl};
non1size_dim{xl} = """
for j, x in to_compare[:-1]:
non1size_dim_check += f"(%(lv{j})s_n{x} != 1) ? %(lv{j})s_n{x} : "
non1size_dim_check += f"%(lv{jl})s_n{xl};"
check += non1size_dim_check

# Check the nonsize1 dims match
# TODO: This is a bit inefficient because we are comparing one dimension against itself
check += f"""
if (non1size_dim{xl} != 1)
{{
"""
for j, x in to_compare:
j0, x0 = to_compare[0]
for j, x in to_compare[1:]:
check += f"""
if ((%(lv{j})s_n{x} != non1size_dim{x}) && (%(lv{j})s_n{x} != 1))
if (%(lv{j0})s_n{x0} != %(lv{j})s_n{x})
{{
if (%(lv{j0})s_n{x0} == 1 || %(lv{j})s_n{x} == 1)
{{
PyErr_Format(PyExc_ValueError, "Input dimension mismatch. One other input has shape[%%i] = %%lld, but input[%%i].shape[%%i] = %%lld.",
{x},
(long long int) non1size_dim{x},
PyErr_Format(PyExc_ValueError, "{runtime_broadcast_error_msg}",
{j0},
{x0},
(long long int) %(lv{j0})s_n{x0},
{j},
{x},
(long long int) %(lv{j})s_n{x}
);
}} else {{
PyErr_Format(PyExc_ValueError, "Input dimension mismatch: (input[%%i].shape[%%i] = %%lld, input[%%i].shape[%%i] = %%lld)",
{j0},
{x0},
(long long int) %(lv{j0})s_n{x0},
{j},
{x},
(long long int) %(lv{j})s_n{x}
);
%(fail)s
}}
"""
check += """
}
%(fail)s
}}
"""

return init % sub + check % sub


def compute_broadcast_dimensions(array_name: str, loop_orders, sub) -> str:
"""Create c_code to compute broadcasted dimensions of multiple arrays, arising from
Elemwise operations.
def compute_output_dims_lengths(array_name: str, loop_orders, sub) -> str:
"""Create c_code to compute the output dimensions of an Elemwise operation.

The code returned by this function populates the array `array_name`, but does not
initialize it.

TODO: We can decide to either specialize C code even further given the input types
or make it general, regardless of whether static broadcastable information is given
Note: We could specialize C code even further with the known static output shapes
"""
dims_c_code = ""
for i, candidates in enumerate(zip(*loop_orders)):
# TODO: Are candidates always either "x" or "i"? If that's the case we can
# simplify some logic here (e.g., we don't need to track the `idx`).
nonx_candidates = tuple(
(idx, c) for idx, c in enumerate(candidates) if c != "x"
)

# All inputs are known to be broadcastable
if not nonx_candidates:
# Borrow the length of the first non-broadcastable input dimension
for j, candidate in enumerate(candidates):
if candidate != "x":
var = sub[f"lv{int(j)}"]
dims_c_code += f"{array_name}[{i}] = {var}_n{candidate};\n"
break
# If none is non-broadcastable, the output dimension has a length of 1
else: # no-break
dims_c_code += f"{array_name}[{i}] = 1;\n"
continue

# There is only one informative source of size
if len(nonx_candidates) == 1:
idx, candidate = nonx_candidates[0]
var = sub[f"lv{int(idx)}"]
dims_c_code += f"{array_name}[{i}] = {var}_n{candidate};\n"
continue

# In this case any non-size 1 variable will define the right size
dims_c_code += f"{array_name}[{i}] = "
for idx, candidate in nonx_candidates[:-1]:
var = sub[f"lv{int(idx)}"]
dims_c_code += f"({var}_n{candidate} != 1)? {var}_n{candidate}: "
idx, candidate = nonx_candidates[-1]
var = sub[f"lv{idx}"]
dims_c_code += f"{var}_n{candidate};\n"
return dims_c_code


@@ -186,7 +169,7 @@ def make_alloc(loop_orders, dtype, sub, fortran="0"):
if type.startswith("PYTENSOR_COMPLEX"):
type = type.replace("PYTENSOR_COMPLEX", "NPY_COMPLEX")
nd = len(loop_orders[0])
init_dims = compute_broadcast_dimensions("dims", loop_orders, sub)
init_dims = compute_output_dims_lengths("dims", loop_orders, sub)

# TODO: it would be interesting to allocate the output in such a
# way that its contiguous dimensions match one of the input's
@@ -359,7 +342,7 @@ def make_reordered_loop(

# Get the (sorted) total number of iterations of each loop
declare_totals = f"int init_totals[{nnested}];\n"
declare_totals += compute_broadcast_dimensions("init_totals", init_loop_orders, sub)
declare_totals += compute_output_dims_lengths("init_totals", init_loop_orders, sub)

# Sort totals to match the new order that was computed by sorting
# the loop vector. One integer variable per loop is declared.
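The renamed compute_output_dims_lengths borrows each output dimension's length from the first input that is not statically broadcastable in that dimension (a non-"x" entry in its loop order), instead of scanning for a non-1 length at runtime. A rough Python restatement of that selection rule follows; the real function emits C code into string templates, and the helper and example data below are illustrative only:

def output_dim_lengths(loop_orders, input_shapes):
    # loop_orders: per input, a tuple with one entry per output dimension,
    # holding either the input's dimension index or "x" if the input is
    # statically broadcastable in that output dimension.
    out = []
    for candidates in zip(*loop_orders):
        for input_idx, candidate in enumerate(candidates):
            if candidate != "x":
                # Borrow the length of the first non-broadcastable input
                out.append(input_shapes[input_idx][candidate])
                break
        else:  # every input is broadcastable here, so the output length is 1
            out.append(1)
    return out

# Input 0 is broadcastable in the first output dimension, input 1 is not.
print(output_dim_lengths(
    loop_orders=[("x", 1), (0, 1)],
    input_shapes=[(1, 3), (4, 3)],
))  # -> [4, 3]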
2 changes: 1 addition & 1 deletion pytensor/tensor/extra_ops.py
@@ -1439,7 +1439,7 @@ def ravel_multi_index(multi_index, dims, mode="raise", order="C"):

_broadcast_assert = Assert(
"Could not broadcast dimensions. Broadcasting is only allowed along "
"axes that have a statically known length 1. Use `specify_shape` to "
"axes that have a statically known length 1. Use `specify_broadcastable` to "
"inform PyTensor of a known shape."
)

6 changes: 6 additions & 0 deletions tests/link/jax/test_elemwise.py
@@ -4,6 +4,7 @@

import pytensor
import pytensor.tensor as at
from pytensor.compile import get_mode
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import get_test_value
@@ -14,6 +15,11 @@
from pytensor.tensor.special import SoftmaxGrad, log_softmax, softmax
from pytensor.tensor.type import matrix, tensor, vector
from tests.link.jax.test_basic import compare_jax_and_py
from tests.tensor.test_elemwise import TestElemwise


def test_elemwise_runtime_shape_error():
TestElemwise.check_runtime_shapes_error(get_mode("JAX"))


def test_jax_Dimshuffle():
7 changes: 7 additions & 0 deletions tests/link/numba/test_elemwise.py
@@ -9,6 +9,7 @@
import pytensor.tensor.inplace as ati
import pytensor.tensor.math as aem
from pytensor import config, function
from pytensor.compile import get_mode
from pytensor.compile.ops import deep_copy_op
from pytensor.compile.sharedvalue import SharedVariable
from pytensor.gradient import grad
@@ -22,6 +23,7 @@
scalar_my_multi_out,
set_test_value,
)
from tests.tensor.test_elemwise import TestElemwise


rng = np.random.default_rng(42849)
@@ -119,6 +121,11 @@ def test_Elemwise(inputs, input_vals, output_fn, exc):
compare_numba_and_py(out_fg, input_vals)


@pytest.mark.xfail(reason="Logic had to be reversed due to surprising segfaults")
def test_elemwise_runtime_shape_error():
TestElemwise.check_runtime_shapes_error(get_mode("NUMBA"))


def test_elemwise_speed(benchmark):
x = at.dmatrix("y")
y = at.dvector("z")
7 changes: 1 addition & 6 deletions tests/tensor/rewriting/test_basic.py
@@ -1671,12 +1671,7 @@ def verify_op_count(f, count, cls):
(),
(),
),
pytest.param(
lambda x, y: at.mul(y, at.alloc(1, x)),
(),
(),
marks=pytest.mark.xfail(reason="Not implemented"),
),
(lambda x, y: at.mul(y, at.alloc(1, x)), (), ()),
(lambda x, y: at.mul(at.alloc(x, 15, 1), y), (15, 1), (15, 1)),
(lambda x, y: at.mul(at.alloc(x, 15, 2), y), (15, 2), (15, 2)),
(
3 changes: 1 addition & 2 deletions tests/tensor/rewriting/test_math.py
@@ -607,8 +607,7 @@ def test_mul_div_cases(self):
((fx / fy) / fx, [fx, fy], [fxv, fyv], 1, "float32"),
((dv / dy) / dv, [dv, dy], [dvv, dyv], 1, "float64"),
((fv / fy) / fv, [fv, fy], [fvv, fyv], 1, "float32"),
# must broadcast as their is a dimshuffle in the computation
# The broadcast leads to an extra elemwise to check compatibility
# must broadcast as there is a dimshuffle in the computation
((dx / dv) / dx, [dx, dv], [dxv, dvv], 2, "float64"),
# topo: [Shape_i, Shape_i, Elemwise{reciprocal,no_inplace}(<TensorType(float64, row)>), Alloc]
((fx / fv) / fx, [fx, fv], [fxv, fvv], 2, "float32"),
3 changes: 1 addition & 2 deletions tests/tensor/rewriting/test_shape.py
@@ -428,12 +428,11 @@ def test_no_static_shapes(self):
# could have shapes `(1,)` and/or `(n,)`, where `n != 1`, or any
# combination of the two.
assert not shape_feature.same_shape(x, o)
# The following case isn't implemented
assert not shape_feature.same_shape(y, o)

@pytest.mark.parametrize(
"y_dim_0",
[2, pytest.param(None, marks=pytest.mark.xfail(reason="Not implemented"))],
[2, None],
)
def test_vector_dim(self, y_dim_0):
x = at.tensor(dtype="floatX", shape=(2, None))
Expand Down