Implement ScalarLoop in torch backend #958

Open · wants to merge 18 commits into base: main
38 changes: 35 additions & 3 deletions pytensor/link/pytorch/dispatch/elemwise.py
@@ -1,6 +1,9 @@
from itertools import chain

import torch

from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.scalar import ScalarLoop
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.math import All, Any, Max, Min, Prod, Sum
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
@@ -11,9 +14,38 @@
scalar_op = op.scalar_op
base_fn = pytorch_funcify(scalar_op, node=node, **kwargs)

def elemwise_fn(*inputs):
Elemwise._check_runtime_broadcast(node, inputs)
return base_fn(*inputs)
if isinstance(scalar_op, ScalarLoop):
# Note: ScalarLoop wrapped in Elemwise is too common a pattern
# to leave unsupported, but vmap won't allow it (see #1031).
# Instead, we successively unbind the inputs down to scalars.
def elemwise_fn(*inputs):
Elemwise._check_runtime_broadcast(node, inputs)
shaped_inputs = torch.broadcast_tensors(*inputs)
expected_size = shaped_inputs[0].numel()
final_inputs = [s.clone() for s in shaped_inputs]
for _ in range(shaped_inputs[0].dim() - 1):
for i, _ in enumerate(shaped_inputs):
layer = chain.from_iterable([s.unbind(0) for s in final_inputs[i]])
final_inputs[i] = list(layer)

# make sure we still have the same number of things
assert len(final_inputs) == len(shaped_inputs)

# make sure each group of things are the expected size
assert all(len(x) == expected_size for x in final_inputs)

# make sure they are all single elements
assert all(len(x.shape) == 0 for tensor in final_inputs for x in tensor)
res = [base_fn(*args) for args in zip(*final_inputs)]
states = torch.stack(tuple(out[0] for out in res))
done = torch.stack(tuple(out[1] for out in res))
return states, done

else:

def elemwise_fn(*inputs):
Elemwise._check_runtime_broadcast(node, inputs)
return base_fn(*inputs)

return elemwise_fn

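For context, a minimal standalone sketch of the unbind strategy used above, in plain PyTorch. The helper names (scalar_step, apply_elementwise) are illustrative only and not part of this PR; the PR peels dimensions one at a time with unbind, while this sketch simply flattens with reshape.

import torch

def scalar_step(n_steps, x0):
    # Toy stand-in for a compiled ScalarLoop body: operates on 0-d tensors.
    x = x0
    for _ in range(int(n_steps)):
        x = x * 2
    return x

def apply_elementwise(base_fn, *inputs):
    # Broadcast, split every input into 0-d tensors, apply the scalar
    # function element by element, then stack the results back together.
    shaped = torch.broadcast_tensors(*inputs)
    scalars = [list(t.reshape(-1).unbind(0)) for t in shaped]
    results = [base_fn(*args) for args in zip(*scalars)]
    return torch.stack(results).reshape(shaped[0].shape)

print(apply_elementwise(scalar_step, torch.tensor(3), torch.arange(4.0)))
# tensor([ 0.,  8., 16., 24.])
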
39 changes: 39 additions & 0 deletions pytensor/link/pytorch/dispatch/scalar.py
@@ -1,10 +1,12 @@
import torch
import torch.compiler

from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.scalar.basic import (
Cast,
ScalarOp,
)
from pytensor.scalar.loop import ScalarLoop


@pytorch_funcify.register(ScalarOp)
@@ -41,6 +43,43 @@
return pytorch_func


@pytorch_funcify.register(ScalarLoop)
def pytorch_funicify_ScalarLoop(op, node, **kwargs):
update = pytorch_funcify(op.fgraph)
state_length = op.nout
if op.is_while:

def scalar_loop(steps, *start_and_constants):
carry, constants = (
start_and_constants[:state_length],
start_and_constants[state_length:],
)
done = True
for _ in range(steps):
*carry, done = update(*carry, *constants)
if torch.any(done):
break
if len(node.outputs) == 2:
return carry[0], done
else:
return carry, done

else:

def scalar_loop(steps, *start_and_constants):
carry, constants = (
start_and_constants[:state_length],
start_and_constants[state_length:],
)
for _ in range(steps):
carry = update(*carry, *constants)
if len(node.outputs) == 1:
return carry[0]
else:
return carry

return torch.compiler.disable(scalar_loop, recursive=False)


@pytorch_funcify.register(Cast)
def pytorch_funcify_Cast(op: Cast, node, **kwargs):
dtype = getattr(torch, op.o_type.dtype)
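To make the is_while semantics above concrete, here is a rough pure-PyTorch analogue for a single carried state. run_scalar_while and step are hypothetical names for illustration, not part of the dispatch code.

import torch

def run_scalar_while(update, steps, x0, *constants):
    # Mirrors the is_while branch: carry the state forward and stop early
    # once the until-condition returned by `update` is satisfied.
    carry, done = x0, torch.tensor(True)
    for _ in range(int(steps)):
        carry, done = update(carry, *constants)
        if torch.any(done):
            break
    return carry, done

# x <- x + 1 until x >= 10, starting at 0 with a budget of 20 steps.
step = lambda x: (x + 1, x + 1 >= 10)
state, done = run_scalar_while(step, torch.tensor(20), torch.tensor(0.0))
print(state, done)  # tensor(10.) tensor(True)
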
7 changes: 6 additions & 1 deletion pytensor/link/pytorch/linker.py
@@ -13,7 +13,12 @@
return pytorch_typify(inp)

def output_filter(self, var: Variable, out: Any) -> Any:
return out.cpu()
from torch import is_tensor

if is_tensor(out):
return out.cpu()
Ch0ronomato (author): This will probably create a conflict when one of my other PRs gets merged, as an FYI.

else:
return out

def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
from pytensor.link.pytorch.dispatch import pytorch_funcify
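A small sketch of the guarded output filter behaviour (standalone, assuming only that some ops can now return non-tensor values such as plain Python booleans):

import torch

def output_filter(out):
    # Only tensors are moved back to the CPU; anything else passes through.
    return out.cpu() if torch.is_tensor(out) else out

print(output_filter(torch.ones(2)))  # tensor([1., 1.])
print(output_filter(True))           # True
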
59 changes: 59 additions & 0 deletions tests/link/pytorch/test_basic.py
@@ -4,6 +4,7 @@
import numpy as np
import pytest

import pytensor.tensor as pt
import pytensor.tensor.basic as ptb
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.function import function
@@ -17,7 +18,10 @@
from pytensor.ifelse import ifelse
from pytensor.link.pytorch.linker import PytorchLinker
from pytensor.raise_op import CheckAndRaise
from pytensor.scalar import float64, int64
from pytensor.scalar.loop import ScalarLoop
from pytensor.tensor import alloc, arange, as_tensor, empty, eye
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.type import matrices, matrix, scalar, vector


@@ -312,6 +316,37 @@ def test_pytorch_MakeVector():
compare_pytorch_and_py(x_fg, [])


def test_ScalarLoop():
n_steps = int64("n_steps")
x0 = float64("x0")
const = float64("const")
x = x0 + const

op = ScalarLoop(init=[x0], constant=[const], update=[x])
x = op(n_steps, x0, const)

fn = function([n_steps, x0, const], x, mode=pytorch_mode)
np.testing.assert_allclose(fn(5, 0, 1), 5)
np.testing.assert_allclose(fn(5, 0, 2), 10)
np.testing.assert_allclose(fn(4, 3, -1), -1)


def test_ScalarLoop_while():
n_steps = int64("n_steps")
x0 = float64("x0")
x = x0 + 1
until = x >= 10

op = ScalarLoop(init=[x0], update=[x], until=until)
fn = function([n_steps, x0], op(n_steps, x0), mode=pytorch_mode)
for res, expected in zip(
[fn(n_steps=20, x0=0), fn(n_steps=20, x0=1), fn(n_steps=5, x0=1)],
[[10, True], [10, True], [6, False]],
):
np.testing.assert_allclose(res[0], np.array(expected[0]))
np.testing.assert_allclose(res[1], np.array(expected[1]))


def test_pytorch_ifelse():
p1_vals = np.r_[1, 2, 3]
p2_vals = np.r_[-1, -2, -3]
@@ -343,3 +378,27 @@ def test_pytorch_OpFromGraph():

f = FunctionGraph([x, y, z], [out])
compare_pytorch_and_py(f, [xv, yv, zv])


def test_ScalarLoop_Elemwise():
n_steps = int64("n_steps")
x0 = float64("x0")
x = x0 * 2
until = x >= 10

scalarop = ScalarLoop(init=[x0], update=[x], until=until)
op = Elemwise(scalarop)

n_steps = pt.scalar("n_steps", dtype="int32")
x0 = pt.vector("x0", dtype="float32")
state, done = op(n_steps, x0)

fn = function([n_steps, x0], [state, done], mode=pytorch_mode)
py_fn = function([n_steps, x0], [state, done])

args = [np.array(10).astype("int32"), np.arange(0, 5).astype("float32")]
torch_states, torch_dones = fn(*args)
py_states, py_dones = py_fn(*args)

np.testing.assert_allclose(torch_states, py_states)
np.testing.assert_allclose(torch_dones, py_dones)