Skip to content

Commit

Permalink
Incorporate static shape of Alloc input
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Jul 14, 2023
1 parent 6a33e7c commit 2a3adbe
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 28 deletions.
30 changes: 23 additions & 7 deletions pytensor/tensor/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1432,17 +1432,33 @@ class Alloc(COp):
__props__ = ()

def make_node(self, value, *shape):
v = as_tensor_variable(value)
sh, static_shape = infer_static_shape(shape)
if v.ndim > len(sh):
value = as_tensor_variable(value)
shape, static_shape = infer_static_shape(shape)
if value.ndim > len(shape):
raise TypeError(
"The Alloc value to use has more dimensions"
" than the specified dimensions",
v.ndim,
len(sh),
value.ndim,
len(shape),
)
otype = TensorType(dtype=v.dtype, shape=static_shape)
return Apply(self, [v] + sh, [otype()])

# Combine static shape information from value and shape
combined_static_shape = list(static_shape).copy()
new_dims = len(shape) - value.type.ndim
extended_value_static_shape = (None,) * new_dims + (value.type.shape)
for i, (v_st, sh_st) in enumerate(
zip(extended_value_static_shape, static_shape)
):
if (v_st not in (1, None)) and (sh_st is None):
combined_static_shape[i] = v_st
elif (v_st is not None) and (sh_st is not None):
if v_st != sh_st and v_st != 1:
raise ValueError(
f"Alloc static input shape and target shape are incompatible: {value.type.shape} vs {static_shape}"
)

otype = TensorType(dtype=value.dtype, shape=combined_static_shape)
return Apply(self, [value] + shape, [otype()])

def perform(self, node, inputs, out_):
(out,) = out_
Expand Down
21 changes: 0 additions & 21 deletions tests/tensor/rewriting/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,27 +272,6 @@ class TestLocalCanonicalizeAlloc:
def setup_method(self):
self.rng = np.random.default_rng(utt.fetch_seed())

def test_inconsistent_constant(self):
x = at.as_tensor(self.rng.standard_normal((3, 7)))
a = at.alloc(x, 6, 7)

assert a.owner and isinstance(a.owner.op, Alloc)

# `local_useless_alloc` should attempt to replace the `Alloc` with an
# `Assert` and fail when the static shape information conflicts.
with pytest.raises(TypeError):
f = function([], a, mode=rewrite_mode)

x = at.as_tensor(self.rng.standard_normal((6, 7)))
a = at.alloc(x, 6, 7)

f = function([], a, mode=rewrite_mode)

# The rewrite should then be applied, and remove Alloc
assert not any(
isinstance(node.op, (Alloc, Assert)) for node in f.maker.fgraph.toposort()
)

def test_inconsistent_shared(self):
# These shapes don't match!
x = shared(self.rng.standard_normal((3, 7)))
Expand Down
16 changes: 16 additions & 0 deletions tests/tensor/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -835,6 +835,22 @@ def test_rebuild(self, func):
assert y_new.shape.eval({x_new: x_new_test}) == (100,)
assert y_new.eval({x_new: x_new_test}).shape == (100,)

def test_static_shape(self):
x = tensor(shape=(None, 1, 5))
d0 = scalar("d0", dtype=int)
d1 = scalar("d1", dtype=int)
assert at.alloc(x, 3, 1, 5).type.shape == (3, 1, 5)
assert at.alloc(x, 3, 4, 5).type.shape == (3, 4, 5)
assert at.alloc(x, d0, d1, 5).type.shape == (None, None, 5)
assert at.alloc(x, d0, 1, d1).type.shape == (None, 1, 5)

msg = "Alloc static input shape and target shape are incompatible"
with pytest.raises(ValueError, match=msg):
at.alloc(x, 3, 1, 1)

with pytest.raises(ValueError, match=msg):
at.alloc(x, 3, 1, 6)


def test_infer_shape():
with pytest.raises(TypeError, match="^Shapes must be scalar integers.*"):
Expand Down

0 comments on commit 2a3adbe

Please sign in to comment.