Skip to content

Commit

Permalink
Incorporate static shape of Alloc input
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Aug 7, 2023
1 parent 6898f74 commit c6b0858
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 28 deletions.
38 changes: 31 additions & 7 deletions pytensor/tensor/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1432,17 +1432,41 @@ class Alloc(COp):
__props__ = ()

def make_node(self, value, *shape):
    """Build an `Apply` node that allocates/broadcasts `value` to `shape`.

    Parameters
    ----------
    value
        The value to allocate; converted with `as_tensor_variable`. Must not
        have more dimensions than `shape` specifies.
    *shape
        Scalar integer variables/constants giving the target shape.

    Returns
    -------
    Apply
        Node whose output type carries the combined static shape
        information from both `value` and `shape`.

    Raises
    ------
    TypeError
        If `value` has more dimensions than `shape`.
    ValueError
        If a known static dimension of `value` conflicts with a known
        target dimension and `value` is not broadcastable there.
    """
    value = as_tensor_variable(value)
    shape, static_shape = infer_static_shape(shape)
    if value.ndim > len(shape):
        raise TypeError(
            "The Alloc value to use has more dimensions"
            " than the specified dimensions",
            value.ndim,
            len(shape),
        )

    # Combine static shape information from value and shape.
    # NOTE: `list(...)` already copies, no extra `.copy()` needed.
    combined_static_shape = list(static_shape)
    new_dims = len(shape) - value.type.ndim
    # Left-pad value's static info so it aligns with the target shape
    # (new leading dims are unknown-length and non-broadcastable).
    extended_value_static_shape = (None,) * new_dims + value.type.shape
    extended_value_broadcastable = (False,) * new_dims + value.type.broadcastable
    for i, (v_bc, v_st, sh_st) in enumerate(
        zip(
            extended_value_broadcastable,
            extended_value_static_shape,
            static_shape,
        )
    ):
        # If value is not broadcastable and we don't know the target static shape: use value static shape
        if (not v_bc) and (sh_st is None):
            combined_static_shape[i] = v_st
        # Otherwise check if static shapes are compatible
        elif (v_st is not None) and (sh_st is not None):
            # They must match or if not, the value must be broadcastable
            if v_st != sh_st and not v_bc:
                raise ValueError(
                    f"Alloc static input type and target shape are incompatible: {value.type} vs {static_shape}"
                )

    otype = TensorType(dtype=value.dtype, shape=combined_static_shape)
    return Apply(self, [value] + shape, [otype()])

def perform(self, node, inputs, out_):
(out,) = out_
Expand Down
21 changes: 0 additions & 21 deletions tests/tensor/rewriting/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,27 +272,6 @@ class TestLocalCanonicalizeAlloc:
def setup_method(self):
    """Create a deterministic RNG for each test using the shared test seed."""
    seed = utt.fetch_seed()
    self.rng = np.random.default_rng(seed)

def test_inconsistent_constant(self):
    """Rewrites on `Alloc` must reject conflicting static shapes and
    remove the `Alloc` entirely when the shapes agree."""
    x = at.as_tensor(self.rng.standard_normal((3, 7)))
    a = at.alloc(x, 6, 7)

    assert a.owner and isinstance(a.owner.op, Alloc)

    # `local_useless_alloc` should attempt to replace the `Alloc` with an
    # `Assert` and fail when the static shape information conflicts.
    with pytest.raises(TypeError):
        # NOTE: the call raises before returning, so no assignment is
        # needed (the original bound an unused `f` here).
        function([], a, mode=rewrite_mode)

    x = at.as_tensor(self.rng.standard_normal((6, 7)))
    a = at.alloc(x, 6, 7)

    f = function([], a, mode=rewrite_mode)

    # The rewrite should then be applied, and remove Alloc
    assert not any(
        isinstance(node.op, (Alloc, Assert)) for node in f.maker.fgraph.toposort()
    )

def test_inconsistent_shared(self):
# These shapes don't match!
x = shared(self.rng.standard_normal((3, 7)))
Expand Down
16 changes: 16 additions & 0 deletions tests/tensor/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -835,6 +835,22 @@ def test_rebuild(self, func):
assert y_new.shape.eval({x_new: x_new_test}) == (100,)
assert y_new.eval({x_new: x_new_test}).shape == (100,)

def test_static_shape(self):
    """`alloc` output static shape combines value and target shape info."""
    x = tensor(shape=(None, 1, 5))
    d0 = scalar("d0", dtype=int)
    d1 = scalar("d1", dtype=int)

    # (shape arguments passed to alloc, expected static output shape)
    cases = [
        ((3, 1, 5), (3, 1, 5)),
        ((3, 4, 5), (3, 4, 5)),
        ((d0, d1, 5), (None, None, 5)),
        ((d0, 1, d1), (None, 1, 5)),
    ]
    for shape_args, expected in cases:
        assert at.alloc(x, *shape_args).type.shape == expected

    # A known, non-broadcastable dim of `x` conflicting with the target
    # shape must be rejected.
    msg = "Alloc static input type and target shape are incompatible"
    with pytest.raises(ValueError, match=msg):
        at.alloc(x, 3, 1, 1)

    with pytest.raises(ValueError, match=msg):
        at.alloc(x, 3, 1, 6)


def test_infer_shape():
with pytest.raises(TypeError, match="^Shapes must be scalar integers.*"):
Expand Down

0 comments on commit c6b0858

Please sign in to comment.