Skip to content

Commit

Permalink
port tests for transforms.Lambda (#8011)
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier authored Oct 3, 2023
1 parent b6189a8 commit 0040fe7
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 45 deletions.
35 changes: 0 additions & 35 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,38 +574,3 @@ def test_sanitize_bounding_boxes_errors():
with pytest.raises(ValueError, match="Number of boxes"):
different_sizes = {"bbox": good_bbox, "labels": torch.arange(good_bbox.shape[0] + 3)}
transforms.SanitizeBoundingBoxes()(different_sizes)


class TestLambda:
inputs = pytest.mark.parametrize("input", [object(), torch.empty(()), np.empty(()), "string", 1, 0.0])

@inputs
def test_default(self, input):
was_applied = False

def was_applied_fn(input):
nonlocal was_applied
was_applied = True
return input

transform = transforms.Lambda(was_applied_fn)

transform(input)

assert was_applied

@inputs
def test_with_types(self, input):
was_applied = False

def was_applied_fn(input):
nonlocal was_applied
was_applied = True
return input

types = (torch.Tensor, np.ndarray)
transform = transforms.Lambda(was_applied_fn, *types)

transform(input)

assert was_applied is isinstance(input, types)
10 changes: 0 additions & 10 deletions test/test_transforms_v2_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,16 +72,6 @@ def __init__(
LINEAR_TRANSFORMATION_MATRIX = torch.rand([LINEAR_TRANSFORMATION_MEAN.numel()] * 2)

CONSISTENCY_CONFIGS = [
ConsistencyConfig(
v2_transforms.Lambda,
legacy_transforms.Lambda,
[
NotScriptableArgsKwargs(lambda image: image / 2),
],
# Technically, this also supports PIL, but it is overkill to write a function here that supports tensor and PIL
# images given that the transform does nothing but call it anyway.
supports_pil=False,
),
ConsistencyConfig(
v2_transforms.Compose,
legacy_transforms.Compose,
Expand Down
18 changes: 18 additions & 0 deletions test/test_transforms_v2_refactored.py
Original file line number Diff line number Diff line change
Expand Up @@ -5126,3 +5126,21 @@ def test_functional_and_transform(self, color_space, fn):
def test_functional_error(self):
with pytest.raises(TypeError, match="pic should be PIL Image"):
F.pil_to_tensor(object())


class TestLambda:
@pytest.mark.parametrize("input", [object(), torch.empty(()), np.empty(()), "string", 1, 0.0])
@pytest.mark.parametrize("types", [(), (torch.Tensor, np.ndarray)])
def test_transform(self, input, types):
was_applied = False

def was_applied_fn(input):
nonlocal was_applied
was_applied = True
return input

transform = transforms.Lambda(was_applied_fn, *types)
output = transform(input)

assert output is input
assert was_applied is (not types or isinstance(input, types))

0 comments on commit 0040fe7

Please sign in to comment.