From 9de5a78af1b7de446b2170b11af6f77bafd54c82 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 20 Jun 2024 18:33:39 +0200 Subject: [PATCH] Implement basic Alloc Ops in PyTorch --- pytensor/link/pytorch/dispatch/basic.py | 31 ++++++++++++++++++++++ tests/link/pytorch/test_basic.py | 34 ++++++++++++++++++++++++- 2 files changed, 64 insertions(+), 1 deletion(-) diff --git a/pytensor/link/pytorch/dispatch/basic.py b/pytensor/link/pytorch/dispatch/basic.py index c74df67b5b..a9521dc3cd 100644 --- a/pytensor/link/pytorch/dispatch/basic.py +++ b/pytensor/link/pytorch/dispatch/basic.py @@ -6,6 +6,7 @@ from pytensor.graph.fg import FunctionGraph from pytensor.link.utils import fgraph_to_python from pytensor.raise_op import CheckAndRaise +from pytensor.tensor.basic import Alloc, AllocEmpty, ARange @singledispatch @@ -58,3 +59,33 @@ def deepcopyop(x): return x.clone() return deepcopyop + + +@pytorch_funcify.register(AllocEmpty) +def pytorch_funcify_AllocEmpty(op, **kwargs): + dtype = getattr(torch, op.dtype) + + def alloc_empty(*shape): + return torch.empty(shape, dtype=dtype) + + return alloc_empty + + +@pytorch_funcify.register(Alloc) +def pytorch_funcify_alloc(op, **kwargs): + def alloc(value, *shape): + out = torch.empty(shape, dtype=value.dtype) + out[...] = value # broadcast value to shape of out + return out + + return alloc + + +@pytorch_funcify.register(ARange) +def pytorch_funcify_arange(op, **kwargs): + dtype = getattr(torch, op.dtype) + + def arange(start, stop, step): + return torch.arange(start, stop, step, dtype=dtype) + + return arange diff --git a/tests/link/pytorch/test_basic.py b/tests/link/pytorch/test_basic.py index 68d937fce8..cb6e652e23 100644 --- a/tests/link/pytorch/test_basic.py +++ b/tests/link/pytorch/test_basic.py @@ -12,6 +12,7 @@ from pytensor.graph.fg import FunctionGraph from pytensor.graph.op import Op from pytensor.raise_op import CheckAndRaise +from pytensor.tensor import alloc, arange, as_tensor, empty from pytensor.tensor.type import scalar, vector @@ -191,7 +192,7 @@ def test_shared_updates(device): assert isinstance(a.get_value(), np.ndarray) -def test_pytorch_checkandraise(): +def test_checkandraise(): check_and_raise = CheckAndRaise(AssertionError, "testing") x = scalar("x") @@ -203,3 +204,34 @@ def test_pytorch_checkandraise(): with pytest.raises(AssertionError, match="testing"): y_fn(0.0) assert y_fn(4).item() == 4 + + +def test_alloc_and_empty(): + dim0 = as_tensor(5, dtype="int64") + dim1 = scalar("dim1", dtype="int64") + + out = empty((dim0, dim1, 3), dtype="float32") + fn = function([dim1], out, mode=pytorch_mode) + res = fn(7) + assert res.shape == (5, 7, 3) + assert res.dtype == torch.float32 + + v = vector("v", shape=(3,), dtype="float64") + out = alloc(v, (dim0, dim1, 3)) + compare_pytorch_and_py( + FunctionGraph([v, dim1], [out]), + [np.array([1, 2, 3]), np.array(7)], + ) + + +def test_arange(): + start = scalar("start", dtype="int64") + stop = scalar("stop", dtype="int64") + step = scalar("step", dtype="int64") + + out = arange(start, stop, step, dtype="int16") + + compare_pytorch_and_py( + FunctionGraph([start, stop, step], [out]), + [np.array(1), np.array(10), np.array(2)], + )