Implement ScalarLoop in torch backend #958

Open: wants to merge 18 commits into main
66 changes: 63 additions & 3 deletions pytensor/link/pytorch/dispatch/elemwise.py
@@ -1,6 +1,9 @@
from itertools import chain

import torch

from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.scalar import ScalarLoop
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.math import All, Any, Max, Min, Prod, Sum
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
@@ -9,11 +12,15 @@
@pytorch_funcify.register(Elemwise)
def pytorch_funcify_Elemwise(op, node, **kwargs):
scalar_op = op.scalar_op

base_fn = pytorch_funcify(scalar_op, node=node, **kwargs)
if isinstance(scalar_op, ScalarLoop):
return elemwise_scalar_loop(base_fn, op, node, **kwargs)
else:

def elemwise_fn(*inputs):
Elemwise._check_runtime_broadcast(node, inputs)
return base_fn(*inputs)

return elemwise_fn

@@ -148,3 +155,56 @@ def softmax_grad(dy, sm):
return dy_times_sm - torch.sum(dy_times_sm, dim=axis, keepdim=True) * sm

return softmax_grad


def elemwise_scalar_loop(base_fn, op, node, **kwargs):
"""
ScalarLoop + Elemwise is too common
to not work, but @1031, vmap won't allow it.
Instead, we can do the following strategy
1. `.unbind(dim)` will return a list of tensors
representing `dim` but "unwrapped". e.x.
```
t = torch.ones(3, 4, 2)
len(t.unbind(0)) == 3
t[0].shape == torch.Size[4, 2]
2. If we successfully apply, the length of the list will grow
by the next dimension in the tensor if we flatten the previous
dimension result
```
inputs = [torch.ones(3, 4, 2)]
level_1 = chain.from_iterable(t.unbind(0) for t in inputs)
level_2 = chain.from_iterable(t.unbind(0) for t in level_1)
len(level_2) == 3 * 4
```
3. Eventually we'll reach single dimension tensors. At that point
we can iterate over each input in an element by element manner
and call some function

For scalar loop, we need to broadcast the tensors so all
the necessary values are repeated, and we "evenly" iterate through everything
"""

def elemwise_fn(*inputs):
Elemwise._check_runtime_broadcast(node, inputs)
shaped_inputs = torch.broadcast_tensors(*inputs)
expected_size = shaped_inputs[0].numel()
final_inputs = [s.clone() for s in shaped_inputs]
for _ in range(shaped_inputs[0].dim() - 1):
for i, _ in enumerate(shaped_inputs):
layer = chain.from_iterable([s.unbind(0) for s in final_inputs[i]])
final_inputs[i] = list(layer)

# make sure we still have the same number of things
assert len(final_inputs) == len(shaped_inputs)
Review comment (Ch0ronomato): I can put these into the unit test if that's preferred now.
# make sure each group of things are the expected size
assert all(len(x) == expected_size for x in final_inputs)

# make sure they are all single elements
assert all(len(x.shape) == 0 for tensor in final_inputs for x in tensor)
res = [base_fn(*args) for args in zip(*final_inputs)]

return [torch.stack(tuple(out[i] for out in res)) for i in range(len(res[0]))]

return elemwise_fn
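
For reference, below is a minimal, self-contained sketch (plain PyTorch, not part of the diff) of the broadcast/unbind flattening strategy the docstring describes; `flatten_to_scalars` and the sample shapes are illustrative, and the implementation above differs in its details.

```
from itertools import chain

import torch


def flatten_to_scalars(*inputs):
    """Broadcast the inputs, then repeatedly unbind until only 0-d tensors remain."""
    shaped = torch.broadcast_tensors(*inputs)
    # First level of unwrapping: iterating a tensor yields its slices along dim 0.
    flattened = [list(t) for t in shaped]
    for _ in range(shaped[0].dim() - 1):
        flattened = [
            list(chain.from_iterable(s.unbind(0) for s in group))
            for group in flattened
        ]
    return flattened


a = torch.arange(6.0).reshape(3, 2, 1)
b = torch.tensor([10.0, 20.0])  # broadcasts against (3, 2, 1) to give (3, 2, 2)
flat_a, flat_b = flatten_to_scalars(a, b)
assert all(x.dim() == 0 for x in flat_a + flat_b)

# Apply a stand-in "scalar function" element by element and restack the result.
out = torch.stack([x + y for x, y in zip(flat_a, flat_b)])
out = out.reshape(torch.broadcast_shapes(a.shape, b.shape))
assert torch.equal(out, a + b)
```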
39 changes: 39 additions & 0 deletions pytensor/link/pytorch/dispatch/scalar.py
@@ -1,10 +1,12 @@
import torch
import torch.compiler

from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.scalar.basic import (
Cast,
ScalarOp,
)
from pytensor.scalar.loop import ScalarLoop


@pytorch_funcify.register(ScalarOp)
@@ -41,6 +43,43 @@
return pytorch_func


@pytorch_funcify.register(ScalarLoop)
def pytorch_funcify_ScalarLoop(op, node, **kwargs):
update = pytorch_funcify(op.fgraph)
state_length = op.nout
if op.is_while:

def scalar_loop(steps, *start_and_constants):
carry, constants = (
start_and_constants[:state_length],
start_and_constants[state_length:],
)
done = True
for _ in range(steps):
*carry, done = update(*carry, *constants)
if torch.any(done):
break
if len(node.outputs) == 2:
return carry[0], done
else:
return carry, done

else:

def scalar_loop(steps, *start_and_constants):
carry, constants = (
start_and_constants[:state_length],
start_and_constants[state_length:],
)
for _ in range(steps):
carry = update(*carry, *constants)
if len(node.outputs) == 1:
return carry[0]
else:
return carry


return torch.compiler.disable(scalar_loop, recursive=False)
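
For intuition, here is a small self-contained sketch (not part of the diff) of what the non-while branch produces for a loop such as `x = x0 + const`; `demo_update` and `demo_scalar_loop` are hypothetical names standing in for the compiled fgraph and the generated closure.

```
import torch


def demo_update(x, const):
    # Stand-in for the compiled fgraph: one carried state, one constant.
    return (x + const,)


def demo_scalar_loop(steps, *start_and_constants, state_length=1):
    carry, constants = (
        start_and_constants[:state_length],
        start_and_constants[state_length:],
    )
    for _ in range(int(steps)):
        carry = demo_update(*carry, *constants)
    # Mirror the single-output convenience unpacking above.
    return carry[0] if state_length == 1 else carry


print(demo_scalar_loop(5, torch.tensor(0.0), torch.tensor(2.0)))  # tensor(10.)
```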


@pytorch_funcify.register(Cast)
def pytorch_funcify_Cast(op: Cast, node, **kwargs):
dtype = getattr(torch, op.o_type.dtype)
7 changes: 6 additions & 1 deletion pytensor/link/pytorch/linker.py
@@ -13,7 +13,12 @@
return pytorch_typify(inp)

def output_filter(self, var: Variable, out: Any) -> Any:
from torch import is_tensor

if is_tensor(out):
return out.cpu()
Review comment (Ch0ronomato): This will probably create a conflict when one of my other PRs gets merged, as an FYI.
else:
return out


def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
from pytensor.link.pytorch.dispatch import pytorch_funcify
78 changes: 78 additions & 0 deletions tests/link/pytorch/test_basic.py
@@ -1,9 +1,12 @@
from collections.abc import Callable, Iterable
from functools import partial
from itertools import repeat, starmap
from unittest.mock import MagicMock, call, patch

import numpy as np
import pytest

import pytensor.tensor as pt
import pytensor.tensor.basic as ptb
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.function import function
@@ -17,7 +20,10 @@
from pytensor.ifelse import ifelse
from pytensor.link.pytorch.linker import PytorchLinker
from pytensor.raise_op import CheckAndRaise
from pytensor.scalar import float64, int64
from pytensor.scalar.loop import ScalarLoop
from pytensor.tensor import alloc, arange, as_tensor, empty, eye
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.type import matrices, matrix, scalar, vector


@@ -312,6 +318,37 @@ def test_pytorch_MakeVector():
compare_pytorch_and_py(x_fg, [])


def test_ScalarLoop():
n_steps = int64("n_steps")
x0 = float64("x0")
const = float64("const")
x = x0 + const

op = ScalarLoop(init=[x0], constant=[const], update=[x])
x = op(n_steps, x0, const)

fn = function([n_steps, x0, const], x, mode=pytorch_mode)
np.testing.assert_allclose(fn(5, 0, 1), 5)
np.testing.assert_allclose(fn(5, 0, 2), 10)
np.testing.assert_allclose(fn(4, 3, -1), -1)


def test_ScalarLoop_while():
n_steps = int64("n_steps")
x0 = float64("x0")
x = x0 + 1
until = x >= 10

op = ScalarLoop(init=[x0], update=[x], until=until)
fn = function([n_steps, x0], op(n_steps, x0), mode=pytorch_mode)
for res, expected in zip(
[fn(n_steps=20, x0=0), fn(n_steps=20, x0=1), fn(n_steps=5, x0=1)],
[[10, True], [10, True], [6, False]],
):
np.testing.assert_allclose(res[0], np.array(expected[0]))
np.testing.assert_allclose(res[1], np.array(expected[1]))


def test_pytorch_ifelse():
p1_vals = np.r_[1, 2, 3]
p2_vals = np.r_[-1, -2, -3]
@@ -343,3 +380,44 @@ def test_pytorch_OpFromGraph():

f = FunctionGraph([x, y, z], [out])
compare_pytorch_and_py(f, [xv, yv, zv])


def test_ScalarLoop_Elemwise():
n_steps = int64("n_steps")
x0 = float64("x0")
x = x0 * 2
until = x >= 10

scalarop = ScalarLoop(init=[x0], update=[x], until=until)
op = Elemwise(scalarop)

n_steps = pt.scalar("n_steps", dtype="int32")
x0 = pt.vector("x0", dtype="float32")
state, done = op(n_steps, x0)

f = FunctionGraph([n_steps, x0], [state, done])
args = [np.array(10).astype("int32"), np.arange(0, 5).astype("float32")]
compare_pytorch_and_py(f, args)


torch_elemwise = pytest.importorskip("pytensor.link.pytorch.dispatch.elemwise")


@pytest.mark.parametrize("input_shapes", [[(5, 1, 1, 8), (3, 1, 1), (8,)]])
Review comment (Ch0ronomato): I set this up so we can try different shapes, but I stuck this one in to get started. If you think we should add more, let me know.
@patch("pytensor.link.pytorch.dispatch.elemwise.Elemwise")
def test_ScalarLoop_Elemwise_iteration_logic(_, input_shapes):
args = [torch.ones(*s) for s in input_shapes[:-1]] + [
torch.zeros(*input_shapes[-1])
]
mock_inner_func = MagicMock()
ret_value = torch.rand(2, 2).unbind(0)
Review comment (Ch0ronomato): Maybe rename ret_value to expected.
mock_inner_func.f.return_value = ret_value
elemwise_fn = torch_elemwise.elemwise_scalar_loop(mock_inner_func.f, None, None)
result = elemwise_fn(*args)
for actual, expected in zip(ret_value, result):
Review comment (Ch0ronomato): These are backwards, FYI (the names actual and expected are swapped in the zip).
assert torch.all(torch.eq(*torch.broadcast_tensors(actual, expected)))
np.testing.assert_equal(mock_inner_func.f.call_count, len(result[0]))

expected_args = torch.FloatTensor([1.0] * (len(input_shapes) - 1) + [0.0]).unbind(0)
expected_calls = starmap(call, repeat(expected_args, mock_inner_func.f.call_count))
Review comment (Ch0ronomato): I'm bullish on the itertools approach, but I think I saw earlier that list comprehensions are preferred; I can refactor it if so.
mock_inner_func.f.assert_has_calls(expected_calls)
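
On the itertools question above, a tiny self-contained sketch (illustrative only; `mock_fn` and the argument values are made up) showing that the `starmap`/`repeat` spelling and a list comprehension build the same expected calls:

```
from itertools import repeat, starmap
from unittest.mock import MagicMock, call

mock_fn = MagicMock()
mock_fn(1.0, 0.0)
mock_fn(1.0, 0.0)

expected_args = (1.0, 0.0)
# Spelling used in the test above.
itertools_calls = list(starmap(call, repeat(expected_args, mock_fn.call_count)))
# List-comprehension equivalent some reviewers may prefer.
comprehension_calls = [call(*expected_args) for _ in range(mock_fn.call_count)]

assert itertools_calls == comprehension_calls
mock_fn.assert_has_calls(comprehension_calls)
```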