Use static-only broadcasting rules to compute shape of broadcasting #345
Changes from 1 commit
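For context: the "static-only broadcasting rules" in the title are NumPy's shape rules applied to statically known lengths only, i.e. two lengths in the same dimension are compatible when they are equal or one of them is 1. A minimal sketch of those rules in plain Python (illustrative only, not code from this diff; the helper name is hypothetical):

    # Hypothetical illustration of NumPy-style static broadcasting of shapes:
    # two lengths are compatible when they are equal or one of them is 1.
    from itertools import zip_longest


    def static_broadcast_shape(*shapes):
        # Align shapes on their trailing dimensions by left-padding with 1s.
        result = []
        for dims in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
            non_bcast = {d for d in dims if d != 1}
            if len(non_bcast) > 1:
                raise ValueError(f"Could not broadcast dimensions: {shapes}")
            result.append(non_bcast.pop() if non_bcast else 1)
        return tuple(reversed(result))


    assert static_broadcast_shape((3, 1), (1, 4)) == (3, 4)
    assert static_broadcast_shape((5,), (2, 1)) == (2, 5)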
@@ -1,9 +1,7 @@
 from collections.abc import Collection
 from functools import reduce
-from typing import Iterable, Set, Tuple, Union

 import numpy as np
-import numpy.core.numeric
 from numpy.core.multiarray import normalize_axis_index

 import pytensor
@@ -14,7 +12,7 @@
     disconnected_type,
     grad_undefined,
 )
-from pytensor.graph.basic import Apply, Constant, Variable, equal_computations
+from pytensor.graph.basic import Apply, Constant, Variable
 from pytensor.graph.op import Op
 from pytensor.link.c.op import COp
 from pytensor.link.c.params_type import ParamsType
@@ -23,12 +21,12 @@
 from pytensor.raise_op import Assert
 from pytensor.scalar import int32 as int_t
 from pytensor.scalar import upcast
-from pytensor.scalar.basic import Composite
 from pytensor.tensor import basic as at
 from pytensor.tensor import get_vector_length
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.math import abs as at_abs
-from pytensor.tensor.math import all as at_all
+from pytensor.tensor.math import all as pt_all
+from pytensor.tensor.math import eq as pt_eq
 from pytensor.tensor.math import ge, lt, maximum, minimum, prod
 from pytensor.tensor.math import sum as at_sum
 from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor
@@ -536,7 +534,7 @@ def bincount(x, weights=None, minlength=None, assert_nonneg=False):

     if assert_nonneg:
         assert_op = Assert("Input to bincount has negative values!")
-        x = assert_op(x, at_all(x >= 0))
+        x = assert_op(x, pt_all(x >= 0))

     max_value = at.cast(x.max() + 1, "int64")

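For reference, the renamed pt_all guards bincount's optional non-negativity check. A hedged usage sketch (assuming the public extra_ops API; not part of the diff):

    import numpy as np

    import pytensor
    from pytensor.tensor import vector
    from pytensor.tensor.extra_ops import bincount

    x = vector("x", dtype="int64")
    # With assert_nonneg=True, the graph raises at runtime if x has negatives.
    counts = bincount(x, assert_nonneg=True)
    f = pytensor.function([x], counts)
    print(f(np.array([0, 1, 1, 3])))  # -> [1 2 0 1]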
@@ -1510,8 +1508,8 @@ def broadcast_shape_iter(
     result_dims = []

     for dim_shapes in zip(*array_shapes):
-        # Get the shapes in this dimension that are not definitively
-        # broadcastable (i.e. not symbolically known to be broadcastable)
+        # Get the shapes in this dimension that are not broadcastable
+        # (i.e. not symbolically known to be broadcastable)
         maybe_non_bcast_shapes = [shape for shape in dim_shapes if shape != one_at]

         if len(maybe_non_bcast_shapes) == 0:
@@ -1532,97 +1530,40 @@ def broadcast_shape_iter(
                 nonconst_nb_shapes.add(shape)

         if len(const_nb_shapes) > 1:
-            raise ValueError("Could not broadcast dimensions")
-        elif len(const_nb_shapes) == 1:
-            (const_nb_shape,) = const_nb_shapes
-
-            assert const_nb_shape != 1
-
-            const_nt_shape_var = pytensor.scalar.ScalarConstant(
-                pytensor.scalar.int64, const_nb_shape
-            )
-
-            if len(nonconst_nb_shapes) > 0:
-                # All the potential non-broadcast shapes need to either
-                # be broadcastable or equal to the one non-broadcastable
-                # constant `const_nt_shape_var`.
-                assert_dim = Assert("Could not broadcast dimensions")
-
-                scalar_nonconst_nb_shapes = [
-                    at.scalar_from_tensor(s)
-                    if isinstance(s.type, TensorType)
-                    else s
-                    for s in nonconst_nb_shapes
-                ]
-
-                dummy_nonconst_nb_shapes = [
-                    aes.get_scalar_type(dtype=v.dtype)()
-                    for v in scalar_nonconst_nb_shapes
-                ]
-                assert_cond = reduce(
-                    aes.and_,
-                    (
-                        aes.or_(
-                            aes.eq(nbv, one_at), aes.eq(nbv, const_nt_shape_var)
-                        )
-                        for nbv in dummy_nonconst_nb_shapes
-                    ),
-                )
-                assert_cond_op = Composite(dummy_nonconst_nb_shapes, [assert_cond])
-
-                bcast_dim = assert_dim(
-                    const_nt_shape_var, assert_cond_op(*scalar_nonconst_nb_shapes)
-                )
-            else:
-                bcast_dim = const_nt_shape_var
-        else:
-            # There are no constant, non-broadcastable shapes in this
-            # dimension.
-
-            all_dims_equal = all(
-                # TODO FIXME: This is a largely deficient, and expensive, means
-                # of comparing graphs (and especially shapes)
-                equal_computations([maybe_non_bcast_shapes[0]], [dim])
-                for dim in maybe_non_bcast_shapes[1:]
-            )
-
-            if all_dims_equal:
-                result_dims.append(maybe_non_bcast_shapes[0])
-                continue
-
-            scalar_maybe_non_bcast_shapes = [
-                at.scalar_from_tensor(s) if isinstance(s.type, TensorType) else s
-                for s in maybe_non_bcast_shapes
-            ]
-            dummy_maybe_non_bcast_shapes = [
-                aes.get_scalar_type(dtype=v.dtype)()
-                for v in scalar_maybe_non_bcast_shapes
-            ]
-            non_bcast_vec = [
-                aes.switch(aes.eq(nbv, 1), -one_at, nbv)
-                for nbv in dummy_maybe_non_bcast_shapes
-            ]
-            dim_max = aes.abs(reduce(aes.scalar_maximum, non_bcast_vec))
-            dim_max_op = Composite(dummy_maybe_non_bcast_shapes, [dim_max])
-
-            dummy_dim_max = dim_max_op(*dummy_maybe_non_bcast_shapes)
-
-            assert_dim = Assert("Could not broadcast dimensions")
-            assert_cond = reduce(
-                aes.and_,
-                (
-                    aes.or_(aes.eq(nbv, -one_at), aes.eq(nbv, dummy_dim_max))
-                    for nbv in non_bcast_vec
-                ),
-            )
-            assert_cond_op = Composite(dummy_maybe_non_bcast_shapes, [assert_cond])
-
-            bcast_dim = assert_dim(
-                dim_max_op(*scalar_maybe_non_bcast_shapes),
-                assert_cond_op(*scalar_maybe_non_bcast_shapes),
-            )
-
-        result_dims.append(bcast_dim)
+            raise ValueError(
+                f"Could not broadcast dimensions. Incompatible shapes were {array_shapes}."
+            )
+
+        assert_op = Assert("Could not dynamically broadcast dimensions.")

Review comment: Maybe create this Op once at the module level?
Reply: I made it even more specific:

+
+        if len(const_nb_shapes) == 1:
+            (first_length,) = const_nb_shapes
+            other_lengths = nonconst_nb_shapes
+            first_length = aes.as_scalar(first_length)
+        else:
+            first_length, *other_lengths = nonconst_nb_shapes
+
+        if len(other_lengths) == 0:
+            result_dims.append(first_length)
+            continue
+
+        # Add assert that all remaining shapes are equal
+        use_scalars = False

Review comment: Let's remove the …

+        if use_scalars:
+            condition = None
+            for other in other_lengths:
+                cond = aes.eq(first_length, other)
+                if condition is None:
+                    condition = cond
+                else:
+                    condition = aes.and_(condition, cond)
+        else:
+            condition = pt_all(
+                [pt_eq(first_length, other) for other in other_lengths]
+            )
+
+        if condition is None:
+            result_dims.append(first_length)
+        else:
+            result_dims.append(assert_op(first_length, condition))

     return tuple(result_dims)

Review comment: Remove the "maybe"?
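In summary, the rewritten loop keeps one candidate length per dimension and, when other symbolic lengths remain, guards it with a runtime Assert instead of building Composite comparison graphs. A hedged pure-Python model of that decision logic (ints stand in for constant lengths, strings for symbolic ones; the function name is illustrative, and the real code returns symbolic graph variables, not tuples):

    def broadcast_dim(dim_shapes):
        # Drop entries statically known to be broadcastable (length 1).
        maybe_non_bcast = [s for s in dim_shapes if s != 1]
        if not maybe_non_bcast:
            return 1

        # Partition into constant lengths and unknown (symbolic) lengths.
        const = {s for s in maybe_non_bcast if isinstance(s, int)}
        nonconst = [s for s in maybe_non_bcast if not isinstance(s, int)]

        # Two distinct constants can never broadcast: fail statically.
        if len(const) > 1:
            raise ValueError(f"Could not broadcast dimensions: {dim_shapes}")

        if const:
            (first_length,) = const
            other_lengths = nonconst
        else:
            first_length, *other_lengths = nonconst

        if not other_lengths:
            return first_length
        # The real implementation returns assert_op(first_length, pt_all(...)),
        # deferring the equality check to runtime.
        return ("assert_equal", first_length, other_lengths)


    print(broadcast_dim([1, 5, 1]))    # 5
    print(broadcast_dim([1, "n"]))     # n
    print(broadcast_dim([3, "n", 1]))  # ('assert_equal', 3, ['n'])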