Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement indexing operations in PyTorch #910

Merged
merged 1 commit into from
Sep 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pytensor/compile/mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
"BlasOpt",
"fusion",
"inplace",
"local_uint_constant_indices",
],
),
)
Expand Down
3 changes: 2 additions & 1 deletion pytensor/link/pytorch/dispatch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
import pytensor.link.pytorch.dispatch.elemwise
import pytensor.link.pytorch.dispatch.math
import pytensor.link.pytorch.dispatch.extra_ops
import pytensor.link.pytorch.dispatch.nlinalg
import pytensor.link.pytorch.dispatch.shape
import pytensor.link.pytorch.dispatch.sort
import pytensor.link.pytorch.dispatch.nlinalg
import pytensor.link.pytorch.dispatch.subtensor
# isort: on
34 changes: 29 additions & 5 deletions pytensor/link/pytorch/dispatch/basic.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,40 @@
from functools import singledispatch
from types import NoneType

import numpy as np
import torch

from pytensor.compile.ops import DeepCopyOp
from pytensor.graph.fg import FunctionGraph
from pytensor.link.utils import fgraph_to_python
from pytensor.raise_op import CheckAndRaise
from pytensor.tensor.basic import Alloc, AllocEmpty, ARange, Eye, Join, MakeVector
from pytensor.tensor.basic import (
Alloc,
AllocEmpty,
ARange,
Eye,
Join,
MakeVector,
TensorFromScalar,
)


@singledispatch
def pytorch_typify(data, dtype=None, **kwargs):
r"""Convert instances of PyTensor `Type`\s to PyTorch types."""
def pytorch_typify(data, **kwargs):
raise NotImplementedError(f"pytorch_typify is not implemented for {type(data)}")

Check warning on line 24 in pytensor/link/pytorch/dispatch/basic.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/pytorch/dispatch/basic.py#L24

Added line #L24 was not covered by tests


@pytorch_typify.register(np.ndarray)
@pytorch_typify.register(torch.Tensor)
def pytorch_typify_tensor(data, dtype=None, **kwargs):
return torch.as_tensor(data, dtype=dtype)


@pytorch_typify.register(slice)
@pytorch_typify.register(NoneType)
def pytorch_typify_None(data, **kwargs):
return None
@pytorch_typify.register(np.number)
def pytorch_typify_no_conversion_needed(data, **kwargs):
return data


@singledispatch
Expand Down Expand Up @@ -132,3 +148,11 @@
return torch.tensor(x, dtype=torch_dtype)

return makevector


@pytorch_funcify.register(TensorFromScalar)
def pytorch_funcify_TensorFromScalar(op, **kwargs):
def tensorfromscalar(x):
return torch.as_tensor(x)

Check warning on line 156 in pytensor/link/pytorch/dispatch/basic.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/pytorch/dispatch/basic.py#L156

Added line #L156 was not covered by tests

return tensorfromscalar
124 changes: 124 additions & 0 deletions pytensor/link/pytorch/dispatch/subtensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.tensor.subtensor import (
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
AdvancedSubtensor,
AdvancedSubtensor1,
IncSubtensor,
Subtensor,
indices_from_subtensor,
)
from pytensor.tensor.type_other import MakeSlice, SliceType


def check_negative_steps(indices):
for index in indices:
if isinstance(index, slice):
if index.step is not None and index.step < 0:
raise NotImplementedError(
"Negative step sizes are not supported in Pytorch"
)


@pytorch_funcify.register(Subtensor)
def pytorch_funcify_Subtensor(op, node, **kwargs):
idx_list = op.idx_list

def subtensor(x, *flattened_indices):
indices = indices_from_subtensor(flattened_indices, idx_list)
check_negative_steps(indices)
return x[indices]

return subtensor


@pytorch_funcify.register(MakeSlice)
def pytorch_funcify_makeslice(op, **kwargs):
def makeslice(*x):
return slice(x)

Check warning on line 38 in pytensor/link/pytorch/dispatch/subtensor.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/pytorch/dispatch/subtensor.py#L37-L38

Added lines #L37 - L38 were not covered by tests

return makeslice

Check warning on line 40 in pytensor/link/pytorch/dispatch/subtensor.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/pytorch/dispatch/subtensor.py#L40

Added line #L40 was not covered by tests


@pytorch_funcify.register(AdvancedSubtensor1)
@pytorch_funcify.register(AdvancedSubtensor)
def pytorch_funcify_AdvSubtensor(op, node, **kwargs):
def advsubtensor(x, *indices):
check_negative_steps(indices)
return x[indices]

return advsubtensor


@pytorch_funcify.register(IncSubtensor)
def pytorch_funcify_IncSubtensor(op, node, **kwargs):
idx_list = op.idx_list
inplace = op.inplace
if op.set_instead_of_inc:

def set_subtensor(x, y, *flattened_indices):
indices = indices_from_subtensor(flattened_indices, idx_list)
check_negative_steps(indices)
if not inplace:
x = x.clone()
x[indices] = y
ricardoV94 marked this conversation as resolved.
Show resolved Hide resolved
return x
Comment on lines +62 to +65
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This works for me locally, are you sure it wasn't working on your end?


return set_subtensor

else:

def inc_subtensor(x, y, *flattened_indices):
indices = indices_from_subtensor(flattened_indices, idx_list)
check_negative_steps(indices)
if not inplace:
x = x.clone()
x[indices] += y
return x

return inc_subtensor


@pytorch_funcify.register(AdvancedIncSubtensor)
@pytorch_funcify.register(AdvancedIncSubtensor1)
def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs):
ricardoV94 marked this conversation as resolved.
Show resolved Hide resolved
inplace = op.inplace
ignore_duplicates = getattr(op, "ignore_duplicates", False)

if op.set_instead_of_inc:

def adv_set_subtensor(x, y, *indices):
check_negative_steps(indices)
if not inplace:
x = x.clone()
x[indices] = y.type_as(x)
return x

return adv_set_subtensor

elif ignore_duplicates:

def adv_inc_subtensor_no_duplicates(x, y, *indices):
check_negative_steps(indices)
if not inplace:
x = x.clone()

Check warning on line 104 in pytensor/link/pytorch/dispatch/subtensor.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/pytorch/dispatch/subtensor.py#L104

Added line #L104 was not covered by tests
x[indices] += y.type_as(x)
return x

return adv_inc_subtensor_no_duplicates

else:
if any(isinstance(idx.type, SliceType) for idx in node.inputs[2:]):
raise NotImplementedError(
"IncSubtensor with potential duplicates indexes and slice indexing not implemented in PyTorch"
)

def adv_inc_subtensor(x, y, *indices):
# Not needed because slices aren't supported
# check_negative_steps(indices)
if not inplace:
x = x.clone()
x.index_put_(indices, y.type_as(x), accumulate=True)
return x

Check warning on line 122 in pytensor/link/pytorch/dispatch/subtensor.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/pytorch/dispatch/subtensor.py#L120-L122

Added lines #L120 - L122 were not covered by tests

return adv_inc_subtensor
6 changes: 3 additions & 3 deletions tests/link/pytorch/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,10 @@ def compare_pytorch_and_py(
py_res = pytensor_py_fn(*test_inputs)

if len(fgraph.outputs) > 1:
for j, p in zip(pytorch_res, py_res):
assert_fn(j.cpu(), p)
for pytorch_res_i, py_res_i in zip(pytorch_res, py_res):
assert_fn(pytorch_res_i.detach().cpu().numpy(), py_res_i)
else:
assert_fn([pytorch_res[0].cpu()], py_res)
assert_fn(pytorch_res[0].detach().cpu().numpy(), py_res[0])

return pytensor_torch_fn, pytorch_res

Expand Down
Loading
Loading