Make Split a view_op
This allows the outputs to be views of the inputs. The Python and Numba implementations do that, but the C implementation still performs a copy.
ricardoV94 committed Jun 15, 2023
1 parent 91966e8 commit f4536c3
Showing 4 changed files with 108 additions and 48 deletions.
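
As context for the change to `perform` below: `np.split` already returns views of its input, so dropping the `.copy()` is enough to make the Python implementation alias its input. A NumPy-only illustration (not part of the commit):

import numpy as np

x = np.arange(5)
# Mirror Split.perform: split points are the cumulative sum of the sizes
pieces = np.split(x, np.cumsum([2, 3])[:-1])
assert all(p.base is x for p in pieces)  # each piece is a view of x
pieces[0][0] = 99
assert x[0] == 99  # mutating a view is visible through the original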
4 changes: 3 additions & 1 deletion pytensor/tensor/basic.py
@@ -1903,6 +1903,7 @@ class Split(COp):
     b == [3, 4]
     c == [5]
+    TODO: Don't make a copy in C impl
     """

     len_splits = None
@@ -1913,6 +1914,7 @@ class Split(COp):

     def __init__(self, len_splits):
         self.len_splits = int(len_splits)
+        self.view_map = {i: [0] for i in range(self.len_splits)}

     def __str__(self):
         return f"{self.__class__.__name__ }{{{self.len_splits}}}"
@@ -1949,7 +1951,7 @@ def perform(self, node, inputs, outputs):

         split_outs = np.split(x, np.cumsum(splits[:-1]), axis=axis)
         for i, out in enumerate(split_outs):
-            outputs[i][0] = out.copy()
+            outputs[i][0] = out
def infer_shape(self, fgraph, node, in_shapes):
axis = node.inputs[1]
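The `view_map` set in `__init__` above declares, for the rewriter and the memory-management machinery, that every output aliases input 0 (the tensor being split); this is what prevents unsafe in-place rewrites on those outputs. A minimal sketch of the declaration, mirroring the assertion in the new `test_split_view` below:

from pytensor.tensor.basic import Split

op = Split(len_splits=3)
# Each output i is declared as a view of input 0; inputs 1 (axis) and
# 2 (split sizes) are never aliased.
assert op.view_map == {0: [0], 1: [0], 2: [0]}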
27 changes: 26 additions & 1 deletion tests/link/numba/test_tensor_basic.py
@@ -4,10 +4,11 @@
 import pytensor.scalar as aes
 import pytensor.tensor as at
 import pytensor.tensor.basic as atb
-from pytensor import config
+from pytensor import config, function
 from pytensor.compile.sharedvalue import SharedVariable
 from pytensor.graph.basic import Constant
 from pytensor.graph.fg import FunctionGraph
+from pytensor.scalar import Add
 from pytensor.tensor.shape import Unbroadcast
 from tests.link.numba.test_basic import (
     compare_numba_and_py,
@@ -332,6 +333,30 @@ def test_Split(n_splits, axis, values, sizes):
     )


+def test_Split_view():
+    # https://github.com/pymc-devs/pytensor/issues/343
+    x1 = at.matrix("x1")
+    x2 = at.matrix("x2", shape=(None, 1))
+    v = at.vector("v", shape=(2,), dtype=int)
+    out = at.split(x1, v, n_splits=2, axis=1)[0] + x2
+
+    fn = function([x1, x2, v], out, mode="NUMBA")
+    # Check that the addition of split[0] and x2 is not in place
+    add_op = fn.maker.fgraph.outputs[0].owner.op
+    assert isinstance(add_op.scalar_op, Add)
+    assert not add_op.inplace_pattern
+
+    rng = np.random.default_rng(123)
+    test_x1 = rng.normal(size=(2, 2))
+    test_x2 = rng.normal(size=(2, 1))
+    test_v = np.array([1, 1])
+
+    np.testing.assert_allclose(
+        fn(test_x1, test_x2, test_v).copy(),
+        fn(test_x1, test_x2, test_v).copy(),
+    )
+
+
 @pytest.mark.parametrize(
     "val, offset",
     [
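Two details of `test_Split_view` are worth noting: the results are `.copy()`-ed before comparison because, with view outputs, a second call may reuse the same buffers; and `inplace_pattern` is the `Elemwise` attribute that records which inputs an op overwrites. A sketch of that inspection, assuming `fn` is compiled as in the test above:

# Inspecting a compiled function `fn` (as built in test_Split_view above)
add_node = fn.maker.fgraph.outputs[0].owner
print(type(add_node.op.scalar_op))  # the scalar op, e.g. Add
print(add_node.op.inplace_pattern)  # {} means no input is overwritten in place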
25 changes: 19 additions & 6 deletions tests/tensor/rewriting/test_basic.py
@@ -1372,15 +1372,28 @@ def test_local_useless_split():

     f_rewritten(np.random.random((4, 4)).astype(config.floatX), [4])
     f_not_rewritten(np.random.random((4, 4)).astype(config.floatX), [1, 2, 1])
-    graph_rewritten = f_rewritten.maker.fgraph.toposort()
-    graph_not_rewritten = f_not_rewritten.maker.fgraph.toposort()
+    graph_rewritten = f_rewritten.maker.fgraph
+    graph_not_rewritten = f_not_rewritten.maker.fgraph

-    assert isinstance(graph_rewritten[-1].op, DeepCopyOp)
-    assert len(graph_not_rewritten) == 1
-    assert isinstance(graph_not_rewritten[0].op, Split)
+    assert all(
+        isinstance(out.owner.op, DeepCopyOp) for out in graph_not_rewritten.outputs
+    )
+    assert all(isinstance(out.owner.op, DeepCopyOp) for out in graph_rewritten.outputs)
+
+    assert sum(isinstance(node.op, Split) for node in graph_rewritten.apply_nodes) == 0
+    assert (
+        sum(isinstance(node.op, Split) for node in graph_not_rewritten.apply_nodes) == 1
+    )
+
+    assert sum(isinstance(node.op, Assert) for node in graph_rewritten.apply_nodes) == 2
+    assert (
+        sum(isinstance(node.op, Assert) for node in graph_not_rewritten.apply_nodes)
+        == 0
+    )

-    # The DeepCopy Ops don't have traces, so we can't check "all"
     assert check_stack_trace(f_rewritten, ops_to_check=[Assert])
-    assert check_stack_trace(f_not_rewritten, ops_to_check="all")
+    assert check_stack_trace(f_not_rewritten, ops_to_check=[Split])


 @pytest.mark.parametrize("i", list(range(1, 4)))
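The rewritten test inspects the `FunctionGraph` directly instead of a `toposort()` list; counting nodes of a given `Op` type through `fgraph.apply_nodes` is a reusable pattern. A small sketch (the helper name is illustrative, not from the commit):

def count_ops(fn, op_type):
    # Count how many applies of `op_type` remain in a compiled function's graph
    return sum(isinstance(node.op, op_type) for node in fn.maker.fgraph.apply_nodes)

# e.g. count_ops(f_rewritten, Split) == 0 once the useless-split rewrite has fired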
100 changes: 60 additions & 40 deletions tests/tensor/test_basic.py
@@ -11,7 +11,7 @@
 import pytensor.tensor.math as tm
 from pytensor import compile, config, function, shared
 from pytensor.compile.io import In, Out
-from pytensor.compile.mode import get_default_mode
+from pytensor.compile.mode import Mode, get_default_mode
 from pytensor.compile.ops import DeepCopyOp
 from pytensor.gradient import grad, hessian
 from pytensor.graph.basic import Apply
@@ -2002,45 +2002,65 @@ def test_split_static_shape(self):
         y = Split(2)(x, 0, [s, 5 - s])[0]
         assert y.type.shape == (None,)

-
-def test_join_inplace():
-    # Test join to work inplace.
-    #
-    # This function tests the case when several elements are passed to the
-    # join function but all except one of them are empty. In this case join
-    # should work inplace and the output should be the view of the non-empty
-    # element.
-    s = lscalar()
-    x = vector("x")
-    z = at.zeros((s,))
-
-    join = Join(view=0)
-    c = join(0, x, z, z)
-
-    f = pytensor.function([In(x, borrow=True), s], Out(c, borrow=True))
-
-    data = np.array([3, 4, 5], dtype=config.floatX)
-
-    if config.mode not in ["DebugMode", "DEBUG_MODE"]:
-        assert f(data, 0) is data
-    assert np.allclose(f(data, 0), [3, 4, 5])
-
-
-def test_join_oneInput():
-    # Test join when only 1 input is given.
-    #
-    # This functions tests the case when concatenate is called
-    # on an array of tensors but the array has only one element.
-    # In this case, we would like to avoid the computational
-    # overhead of concatenation of one element.
-    x_0 = fmatrix()
-    x_1 = fmatrix()
-    x_2 = fvector()
-    join_0 = at.concatenate([x_0], axis=1)
-    join_1 = at.concatenate([x_0, x_1, shape_padright(x_2)], axis=1)
-
-    assert join_0 is x_0
-    assert join_1 is not x_0
+    def test_join_inplace(self):
+        # Test join to work inplace.
+        #
+        # This function tests the case when several elements are passed to the
+        # join function but all except one of them are empty. In this case join
+        # should work inplace and the output should be the view of the non-empty
+        # element.
+        s = lscalar()
+        x = vector("x")
+        z = at.zeros((s,))
+
+        join = Join(view=0)
+        c = join(0, x, z, z)
+
+        f = pytensor.function([In(x, borrow=True), s], Out(c, borrow=True))
+
+        data = np.array([3, 4, 5], dtype=config.floatX)
+
+        if config.mode not in ["DebugMode", "DEBUG_MODE"]:
+            assert f(data, 0) is data
+        assert np.allclose(f(data, 0), [3, 4, 5])
+
+    def test_join_oneInput(self):
+        # Test join when only 1 input is given.
+        #
+        # This functions tests the case when concatenate is called
+        # on an array of tensors but the array has only one element.
+        # In this case, we would like to avoid the computational
+        # overhead of concatenation of one element.
+        x_0 = fmatrix()
+        x_1 = fmatrix()
+        x_2 = fvector()
+        join_0 = at.concatenate([x_0], axis=1)
+        join_1 = at.concatenate([x_0, x_1, shape_padright(x_2)], axis=1)
+
+        assert join_0 is x_0
+        assert join_1 is not x_0
+
+    @pytest.mark.parametrize("linker", ("py", "c"))
+    def test_split_view(self, linker):
+        x = vector("x")
+        axis = 0
+        op = Split(len_splits=3)
+        assert op.view_map == {0: [0], 1: [0], 2: [0]}
+        splits = op(x, axis, [0, 3, 2])
+
+        mode = Mode(linker)
+        f = pytensor.function(
+            [In(x, borrow=True)], [Out(s, borrow=True) for s in splits], mode=mode
+        )
+        x_test = np.arange(5, dtype=config.floatX)
+        res = f(x_test)
+        for r, expected in zip(res, ([], [0, 1, 2], [3, 4])):
+            assert np.allclose(r, expected)
+            if linker == "py":
+                assert r.base is x_test
+            else:
+                # C impl always makes a copy
+                assert r.base is not x_test


 def test_TensorFromScalar():
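The `borrow=True` flags in `test_split_view` are what make the aliasing observable: without them the compiled function would copy its inputs and outputs, and `r.base is x_test` could never hold. A standalone sketch of the same check with the Python linker (assumes this commit is applied):

import numpy as np
import pytensor
from pytensor.compile.io import In, Out
from pytensor.compile.mode import Mode
from pytensor.tensor import vector
from pytensor.tensor.basic import Split

x = vector("x")
a, b = Split(len_splits=2)(x, 0, [2, 3])
f = pytensor.function(
    [In(x, borrow=True)], [Out(a, borrow=True), Out(b, borrow=True)], mode=Mode("py")
)
x_test = np.arange(5, dtype=x.dtype)
res_a, res_b = f(x_test)
assert res_a.base is x_test and res_b.base is x_test  # outputs view the input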
