From d454052c1277b2186e0d3a70c0ff2d99ae6d07a8 Mon Sep 17 00:00:00 2001
From: HarshvirSandhu
Date: Wed, 28 Aug 2024 13:00:11 +0530
Subject: [PATCH] Parametrize tests

---
 pytensor/link/pytorch/dispatch/subtensor.py | 67 ++++++++++-------
 tests/link/pytorch/test_subtensor.py        | 83 ++++-----------------
 2 files changed, 53 insertions(+), 97 deletions(-)

diff --git a/pytensor/link/pytorch/dispatch/subtensor.py b/pytensor/link/pytorch/dispatch/subtensor.py
index 160ed046b2..c2211d3932 100644
--- a/pytensor/link/pytorch/dispatch/subtensor.py
+++ b/pytensor/link/pytorch/dispatch/subtensor.py
@@ -37,6 +37,14 @@ def subtensor(x, *ilists):
     return subtensor
 
 
+@pytorch_funcify.register(MakeSlice)
+def pytorch_funcify_makeslice(op, **kwargs):
+    def makeslice(*x):
+        return slice(*x)
+
+    return makeslice
+
+
 @pytorch_funcify.register(AdvancedSubtensor1)
 @pytorch_funcify.register(AdvancedSubtensor)
 def pytorch_funcify_AdvSubtensor(op, node, **kwargs):
@@ -47,14 +55,6 @@ def advsubtensor(x, *indices):
     return advsubtensor
 
 
-@pytorch_funcify.register(MakeSlice)
-def pytorch_funcify_makeslice(op, **kwargs):
-    def makeslice(*x):
-        return slice(x)
-
-    return makeslice
-
-
 @pytorch_funcify.register(IncSubtensor)
 def pytorch_funcify_IncSubtensor(op, node, **kwargs):
     idx_list = getattr(op, "idx_list", None)
@@ -62,23 +62,28 @@ def pytorch_funcify_IncSubtensor(op, node, **kwargs):
     if getattr(op, "set_instead_of_inc", False):
 
         def torch_fn(x, indices, y):
-            check_negative_steps(indices)
-            x[indices] = y
-            return x
+            if op.inplace:
+                x[tuple(indices)] = y
+                return x
+            x1 = x.clone()
+            x1[tuple(indices)] = y
+            return x1
 
     else:
 
         def torch_fn(x, indices, y):
-            check_negative_steps(indices)
-            x1 = x.clone()
-            x1[indices] += y
-            return x1
-
-    def incsubtensor(x, y, *ilist, torch_fn=torch_fn, idx_list=idx_list):
+            if op.inplace:
+                x[tuple(indices)] += y
+                return x
+            else:
+                x1 = x.clone()
+                x1[tuple(indices)] += y
+                return x1
+
+    def incsubtensor(x, y, *ilist):
         indices = indices_from_subtensor(ilist, idx_list)
-        if len(indices) == 1:
-            indices = indices[0]
+        check_negative_steps(indices)
 
         return torch_fn(x, indices, y)
 
     return incsubtensor
@@ -90,22 +95,26 @@ def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs):
     if getattr(op, "set_instead_of_inc", False):
 
         def torch_fn(x, indices, y):
-            check_negative_steps(indices)
-            x[indices] = y
-            return x
+            if op.inplace:
+                x[tuple(indices)] = y
+                return x
+            x1 = x.clone()
+            x1[tuple(indices)] = y
+            return x1
 
     else:
 
         def torch_fn(x, indices, y):
-            check_negative_steps(indices)
-            x1 = x.clone()
-            x1[indices] += y
-            return x1
+            if op.inplace:
+                x[tuple(indices)] += y
+                return x
+            else:
+                x1 = x.clone()
+                x1[tuple(indices)] += y
+                return x1
 
     def incsubtensor(x, y, *indices, torch_fn=torch_fn):
-        if len(indices) == 1:
-            indices = indices[0]
-
+        check_negative_steps(indices)
         return torch_fn(x, indices, y)
 
     return incsubtensor
diff --git a/tests/link/pytorch/test_subtensor.py b/tests/link/pytorch/test_subtensor.py
index e6c7620fb3..661b97128b 100644
--- a/tests/link/pytorch/test_subtensor.py
+++ b/tests/link/pytorch/test_subtensor.py
@@ -109,83 +109,36 @@ def test_pytorch_AdvSubtensor():
     compare_pytorch_and_py(out_fg, [x_np])
 
 
-def test_pytorch_SetSubtensor():
+@pytest.mark.parametrize(
+    "subtensor_op", [pt_subtensor.set_subtensor, pt_subtensor.inc_subtensor]
+)
+def test_pytorch_SetSubtensor(subtensor_op):
     x_pt = pt.tensor3("x")
     x_test = np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(config.floatX)
 
     # "Set" basic indices
     st_pt = pt.as_tensor_variable(np.array(-10.0, dtype=config.floatX))
-    out_pt = pt_subtensor.set_subtensor(x_pt[1, 2, 3], st_pt)
+    out_pt = subtensor_op(x_pt[1, 2, 3], st_pt)
     assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor)
     out_fg = FunctionGraph([x_pt], [out_pt])
     compare_pytorch_and_py(out_fg, [x_test])
 
     st_pt = pt.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))
-    out_pt = pt_subtensor.set_subtensor(x_pt[:2, 0, 0], st_pt)
+    out_pt = subtensor_op(x_pt[:2, 0, 0], st_pt)
     assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor)
     out_fg = FunctionGraph([x_pt], [out_pt])
     compare_pytorch_and_py(out_fg, [x_test])
 
-    out_pt = pt_subtensor.set_subtensor(x_pt[0, 1:3, 0], st_pt)
+    out_pt = subtensor_op(x_pt[0, 1:3, 0], st_pt)
     assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor)
     out_fg = FunctionGraph([x_pt], [out_pt])
     compare_pytorch_and_py(out_fg, [x_test])
 
-    out_pt = pt_subtensor.set_subtensor(x_pt[0, 1:3, 0], st_pt)
-    assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor)
-    out_fg = FunctionGraph([x_pt], [out_pt])
-    compare_pytorch_and_py(out_fg, [x_test])
-
 
-def test_pytorch_AdvSetSubtensor():
-    rng = np.random.default_rng(42)
-
-    x_np = rng.uniform(-1, 1, size=(3, 4, 5)).astype(config.floatX)
-    x_pt = pt.tensor3("x")
-    x_test = np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(config.floatX)
-
-    # "Set" advanced indices
-    st_pt = pt.as_tensor_variable(
-        rng.uniform(-1, 1, size=(2, 4, 5)).astype(config.floatX)
-    )
-    out_pt = pt_subtensor.set_subtensor(x_pt[np.r_[0, 2]], st_pt)
-    assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
-    out_fg = FunctionGraph([x_pt], [out_pt])
-    compare_pytorch_and_py(out_fg, [x_test])
-
-    st_pt = pt.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))
-    out_pt = pt_subtensor.set_subtensor(x_pt[[0, 2], 0, 0], st_pt)
-    assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
-    out_fg = FunctionGraph([x_pt], [out_pt])
-    compare_pytorch_and_py(out_fg, [x_test])
-
-    # "Set" boolean indices
-    mask_pt = pt.constant(x_np > 0)
-    out_pt = pt_subtensor.set_subtensor(x_pt[mask_pt], 0.0)
-    assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
-    out_fg = FunctionGraph([x_pt], [out_pt])
-    compare_pytorch_and_py(out_fg, [x_test])
-
-
-def test_pytorch_IncSubtensor():
-    x_pt = pt.tensor3("x")
-    x_test = np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(config.floatX)
-
-    # "Increment" basic indices
-    st_pt = pt.as_tensor_variable(np.array(-10.0, dtype=config.floatX))
-    out_pt = pt_subtensor.inc_subtensor(x_pt[1, 2, 3], st_pt)
-    assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor)
-    out_fg = FunctionGraph([x_pt], [out_pt])
-    compare_pytorch_and_py(out_fg, [x_test])
-
-    st_pt = pt.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))
-    out_pt = pt_subtensor.inc_subtensor(x_pt[:2, 0, 0], st_pt)
-    assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor)
-    out_fg = FunctionGraph([x_pt], [out_pt])
-    compare_pytorch_and_py(out_fg, [x_test])
-
-
-def test_pytorch_AvdancedIncSubtensor():
+@pytest.mark.parametrize(
+    "advsubtensor_op", [pt_subtensor.set_subtensor, pt_subtensor.inc_subtensor]
+)
+def test_pytorch_AdvancedIncSubtensor(advsubtensor_op):
     rng = np.random.default_rng(42)
 
     x_np = rng.uniform(-1, 1, size=(3, 4, 5)).astype(config.floatX)
@@ -196,32 +149,26 @@ def test_pytorch_AvdancedIncSubtensor():
     st_pt = pt.as_tensor_variable(
         rng.uniform(-1, 1, size=(2, 4, 5)).astype(config.floatX)
     )
-    out_pt = pt_subtensor.inc_subtensor(x_pt[np.r_[0, 2]], st_pt)
+    out_pt = advsubtensor_op(x_pt[np.r_[0, 2]], st_pt)
     assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
     out_fg = FunctionGraph([x_pt], [out_pt])
     compare_pytorch_and_py(out_fg, [x_test])
 
     st_pt = pt.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))
-    out_pt = pt_subtensor.inc_subtensor(x_pt[[0, 2], 0, 0], st_pt)
+    out_pt = advsubtensor_op(x_pt[[0, 2], 0, 0], st_pt)
     assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
     out_fg = FunctionGraph([x_pt], [out_pt])
     compare_pytorch_and_py(out_fg, [x_test])
 
     # "Increment" boolean indices
     mask_pt = pt.constant(x_np > 0)
-    out_pt = pt_subtensor.set_subtensor(x_pt[mask_pt], 1.0)
-    assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
-    out_fg = FunctionGraph([x_pt], [out_pt])
-    compare_pytorch_and_py(out_fg, [x_test])
-
-    st_pt = pt.as_tensor_variable(x_np[[0, 2], 0, :3])
-    out_pt = pt_subtensor.set_subtensor(x_pt[[0, 2], 0, :3], st_pt)
+    out_pt = advsubtensor_op(x_pt[mask_pt], 1.0)
     assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
     out_fg = FunctionGraph([x_pt], [out_pt])
     compare_pytorch_and_py(out_fg, [x_test])
 
     st_pt = pt.as_tensor_variable(x_np[[0, 2], 0, :3])
-    out_pt = pt_subtensor.inc_subtensor(x_pt[[0, 2], 0, :3], st_pt)
+    out_pt = advsubtensor_op(x_pt[[0, 2], 0, :3], st_pt)
     assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
     out_fg = FunctionGraph([x_pt], [out_pt])
     compare_pytorch_and_py(out_fg, [x_test])
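
For readers skimming the dispatch change above: both IncSubtensor and AdvancedIncSubtensor now branch on op.inplace, writing into the input tensor directly when in-place updates are allowed and cloning it first otherwise. Below is a minimal standalone sketch of that behaviour in plain PyTorch; the set_or_inc helper, tensor values, and indices are illustrative only and are not part of the patch.

import torch


def set_or_inc(x, indices, y, *, set_instead_of_inc=True, inplace=False):
    # Mirrors the torch_fn logic in the patch: write into x directly when
    # in-place is allowed, otherwise clone so the original stays untouched.
    out = x if inplace else x.clone()
    if set_instead_of_inc:
        out[tuple(indices)] = y
    else:
        out[tuple(indices)] += y
    return out


x = torch.arange(12.0).reshape(3, 4)
y = torch.tensor([-1.0, -2.0])
print(set_or_inc(x, (1, slice(0, 2)), y))                            # set x[1, :2] = y on a copy
print(set_or_inc(x, (1, slice(0, 2)), y, set_instead_of_inc=False))  # increment x[1, :2] += y on a copy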