Parametrize tests
HarshvirSandhu committed Aug 28, 2024
1 parent b68f493 commit d454052
Showing 2 changed files with 53 additions and 97 deletions.
pytensor/link/pytorch/dispatch/subtensor.py (38 additions, 29 deletions)
@@ -37,6 +37,14 @@ def subtensor(x, *ilists):
     return subtensor
 
 
+@pytorch_funcify.register(MakeSlice)
+def pytorch_funcify_makeslice(op, **kwargs):
+    def makeslice(*x):
+        return slice(x)
+
+    return makeslice
+
+
 @pytorch_funcify.register(AdvancedSubtensor1)
 @pytorch_funcify.register(AdvancedSubtensor)
 def pytorch_funcify_AdvSubtensor(op, node, **kwargs):
@@ -47,38 +55,35 @@ def advsubtensor(x, *indices):
     return advsubtensor
 
 
-@pytorch_funcify.register(MakeSlice)
-def pytorch_funcify_makeslice(op, **kwargs):
-    def makeslice(*x):
-        return slice(x)
-
-    return makeslice
-
-
 @pytorch_funcify.register(IncSubtensor)
 def pytorch_funcify_IncSubtensor(op, node, **kwargs):
     idx_list = getattr(op, "idx_list", None)
 
     if getattr(op, "set_instead_of_inc", False):
 
         def torch_fn(x, indices, y):
-            check_negative_steps(indices)
-            x[indices] = y
-            return x
+            if op.inplace:
+                x[tuple(indices)] = y
+                return x
+            x1 = x.clone()
+            x1[tuple(indices)] = y
+            return x1
 
     else:
 
         def torch_fn(x, indices, y):
-            check_negative_steps(indices)
-            x1 = x.clone()
-            x1[indices] += y
-            return x1
-
-    def incsubtensor(x, y, *ilist, torch_fn=torch_fn, idx_list=idx_list):
+            if op.inplace:
+                x[tuple(indices)] += y
+                return x
+            else:
+                x1 = x.clone()
+                x1[tuple(indices)] += y
+                return x1
+
+    def incsubtensor(x, y, *ilist):
         indices = indices_from_subtensor(ilist, idx_list)
         if len(indices) == 1:
             indices = indices[0]
 
+        check_negative_steps(indices)
         return torch_fn(x, indices, y)
 
     return incsubtensor
@@ -90,22 +95,26 @@ def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs):
     if getattr(op, "set_instead_of_inc", False):
 
         def torch_fn(x, indices, y):
-            check_negative_steps(indices)
-            x[indices] = y
-            return x
+            if op.inplace:
+                x[tuple(indices)] = y
+                return x
+            x1 = x.clone()
+            x1[tuple(indices)] = y
+            return x1
 
     else:
 
         def torch_fn(x, indices, y):
-            check_negative_steps(indices)
-            x1 = x.clone()
-            x1[indices] += y
-            return x1
+            if op.inplace:
+                x[tuple(indices)] += y
+                return x
+            else:
+                x1 = x.clone()
+                x1[tuple(indices)] += y
+                return x1
 
     def incsubtensor(x, y, *indices, torch_fn=torch_fn):
         if len(indices) == 1:
             indices = indices[0]
 
+        check_negative_steps(indices)
         return torch_fn(x, indices, y)
 
     return incsubtensor
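
Both dispatchers above follow the same PyTorch update pattern: write through to the input tensor when the op is flagged in-place, otherwise clone it first so the caller's buffer is left untouched. A minimal self-contained sketch of that pattern (an illustration only, not part of the commit; the set_instead_of_inc and inplace parameters stand in for the op attributes of the same names):

import torch

def set_or_inc(x, indices, y, set_instead_of_inc=True, inplace=False):
    # In-place: mutate the caller's tensor directly.
    # Otherwise: clone first, so the original buffer is preserved.
    out = x if inplace else x.clone()
    if set_instead_of_inc:
        out[tuple(indices)] = y
    else:
        out[tuple(indices)] += y
    return out

x = torch.zeros(3, 4)
print(set_or_inc(x, (slice(0, 2), 1), torch.tensor(5.0)))  # modified clone
print(x)  # unchanged, because inplace=False

Cloning preserves the functional semantics expected when inplace is False, at the cost of one extra allocation per call.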
tests/link/pytorch/test_subtensor.py (15 additions, 68 deletions)
@@ -109,83 +109,36 @@ def test_pytorch_AdvSubtensor():
     compare_pytorch_and_py(out_fg, [x_np])
 
 
-def test_pytorch_SetSubtensor():
+@pytest.mark.parametrize(
+    "subtensor_op", [pt_subtensor.set_subtensor, pt_subtensor.inc_subtensor]
+)
+def test_pytorch_SetSubtensor(subtensor_op):
     x_pt = pt.tensor3("x")
     x_test = np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(config.floatX)
 
     # "Set" basic indices
     st_pt = pt.as_tensor_variable(np.array(-10.0, dtype=config.floatX))
-    out_pt = pt_subtensor.set_subtensor(x_pt[1, 2, 3], st_pt)
+    out_pt = subtensor_op(x_pt[1, 2, 3], st_pt)
     assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor)
     out_fg = FunctionGraph([x_pt], [out_pt])
     compare_pytorch_and_py(out_fg, [x_test])
 
     st_pt = pt.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))
-    out_pt = pt_subtensor.set_subtensor(x_pt[:2, 0, 0], st_pt)
+    out_pt = subtensor_op(x_pt[:2, 0, 0], st_pt)
     assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor)
     out_fg = FunctionGraph([x_pt], [out_pt])
     compare_pytorch_and_py(out_fg, [x_test])
 
-    out_pt = pt_subtensor.set_subtensor(x_pt[0, 1:3, 0], st_pt)
+    out_pt = subtensor_op(x_pt[0, 1:3, 0], st_pt)
     assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor)
     out_fg = FunctionGraph([x_pt], [out_pt])
     compare_pytorch_and_py(out_fg, [x_test])
 
-    out_pt = pt_subtensor.set_subtensor(x_pt[0, 1:3, 0], st_pt)
-    assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor)
-    out_fg = FunctionGraph([x_pt], [out_pt])
-    compare_pytorch_and_py(out_fg, [x_test])
 
 
-def test_pytorch_AdvSetSubtensor():
-    rng = np.random.default_rng(42)
-
-    x_np = rng.uniform(-1, 1, size=(3, 4, 5)).astype(config.floatX)
-    x_pt = pt.tensor3("x")
-    x_test = np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(config.floatX)
-
-    # "Set" advanced indices
-    st_pt = pt.as_tensor_variable(
-        rng.uniform(-1, 1, size=(2, 4, 5)).astype(config.floatX)
-    )
-    out_pt = pt_subtensor.set_subtensor(x_pt[np.r_[0, 2]], st_pt)
-    assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
-    out_fg = FunctionGraph([x_pt], [out_pt])
-    compare_pytorch_and_py(out_fg, [x_test])
-
-    st_pt = pt.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))
-    out_pt = pt_subtensor.set_subtensor(x_pt[[0, 2], 0, 0], st_pt)
-    assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
-    out_fg = FunctionGraph([x_pt], [out_pt])
-    compare_pytorch_and_py(out_fg, [x_test])
-
-    # "Set" boolean indices
-    mask_pt = pt.constant(x_np > 0)
-    out_pt = pt_subtensor.set_subtensor(x_pt[mask_pt], 0.0)
-    assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
-    out_fg = FunctionGraph([x_pt], [out_pt])
-    compare_pytorch_and_py(out_fg, [x_test])
-
-
-def test_pytorch_IncSubtensor():
-    x_pt = pt.tensor3("x")
-    x_test = np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(config.floatX)
-
-    # "Increment" basic indices
-    st_pt = pt.as_tensor_variable(np.array(-10.0, dtype=config.floatX))
-    out_pt = pt_subtensor.inc_subtensor(x_pt[1, 2, 3], st_pt)
-    assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor)
-    out_fg = FunctionGraph([x_pt], [out_pt])
-    compare_pytorch_and_py(out_fg, [x_test])
-
-    st_pt = pt.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))
-    out_pt = pt_subtensor.inc_subtensor(x_pt[:2, 0, 0], st_pt)
-    assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor)
-    out_fg = FunctionGraph([x_pt], [out_pt])
-    compare_pytorch_and_py(out_fg, [x_test])
-
-
-def test_pytorch_AvdancedIncSubtensor():
+@pytest.mark.parametrize(
+    "advsubtensor_op", [pt_subtensor.set_subtensor, pt_subtensor.inc_subtensor]
+)
+def test_pytorch_AvdancedIncSubtensor(advsubtensor_op):
     rng = np.random.default_rng(42)
 
     x_np = rng.uniform(-1, 1, size=(3, 4, 5)).astype(config.floatX)
@@ -196,32 +149,26 @@ def test_pytorch_AvdancedIncSubtensor():
     st_pt = pt.as_tensor_variable(
         rng.uniform(-1, 1, size=(2, 4, 5)).astype(config.floatX)
    )
-    out_pt = pt_subtensor.inc_subtensor(x_pt[np.r_[0, 2]], st_pt)
+    out_pt = advsubtensor_op(x_pt[np.r_[0, 2]], st_pt)
     assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
     out_fg = FunctionGraph([x_pt], [out_pt])
     compare_pytorch_and_py(out_fg, [x_test])
 
     st_pt = pt.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))
-    out_pt = pt_subtensor.inc_subtensor(x_pt[[0, 2], 0, 0], st_pt)
+    out_pt = advsubtensor_op(x_pt[[0, 2], 0, 0], st_pt)
     assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
     out_fg = FunctionGraph([x_pt], [out_pt])
     compare_pytorch_and_py(out_fg, [x_test])
 
     # "Increment" boolean indices
     mask_pt = pt.constant(x_np > 0)
-    out_pt = pt_subtensor.set_subtensor(x_pt[mask_pt], 1.0)
-    assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
-    out_fg = FunctionGraph([x_pt], [out_pt])
-    compare_pytorch_and_py(out_fg, [x_test])
-
-    st_pt = pt.as_tensor_variable(x_np[[0, 2], 0, :3])
-    out_pt = pt_subtensor.set_subtensor(x_pt[[0, 2], 0, :3], st_pt)
+    out_pt = advsubtensor_op(x_pt[mask_pt], 1.0)
     assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
     out_fg = FunctionGraph([x_pt], [out_pt])
     compare_pytorch_and_py(out_fg, [x_test])
 
     st_pt = pt.as_tensor_variable(x_np[[0, 2], 0, :3])
-    out_pt = pt_subtensor.inc_subtensor(x_pt[[0, 2], 0, :3], st_pt)
+    out_pt = advsubtensor_op(x_pt[[0, 2], 0, :3], st_pt)
     assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor)
     out_fg = FunctionGraph([x_pt], [out_pt])
     compare_pytorch_and_py(out_fg, [x_test])
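
The test change is a straightforward pytest.mark.parametrize refactor: four near-duplicate set/inc tests collapse into two bodies that receive the update callable as an argument. A minimal standalone illustration of the pattern (hypothetical set_op/inc_op helpers, not from the commit):

import pytest

def set_op(xs, i, v):
    xs[i] = v
    return xs

def inc_op(xs, i, v):
    xs[i] += v
    return xs

# One test body is collected once per parametrized callable.
@pytest.mark.parametrize("update_op", [set_op, inc_op])
def test_update(update_op):
    out = update_op([1.0, 2.0, 3.0], 1, 10.0)
    assert out[1] in (10.0, 12.0)

Each parameter yields its own collected test case, so a failing variant is reported individually.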
