Skip to content

Commit

Permalink
Add test
Browse files Browse the repository at this point in the history
  • Loading branch information
qubvel committed Jul 17, 2024
1 parent 658d912 commit 7047811
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/transformers/utils/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,6 +816,9 @@ def decorator(func):
is_instance_method = "self" in function_named_args
is_class_method = "cls" in function_named_args

# Mark function as decorated
func._filter_out_non_signature_kwargs = True

@wraps(func)
def wrapper(*args, **kwargs):
valid_kwargs = {}
Expand Down
17 changes: 17 additions & 0 deletions tests/models/vitmatte/test_image_processing_vitmatte.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@


import unittest
import warnings

import numpy as np

Expand Down Expand Up @@ -198,3 +199,19 @@ def test_padding(self):
image = np.random.randn(3, 249, 512)
images = image_processing.pad_image(image)
assert images.shape == (3, 256, 512)

def test_image_processor_preprocess_arguments(self):
# vitmatte require additional trimap input for image_processor
# that is why we override original common test

for image_processing_class in self.image_processor_list:
image_processor = image_processing_class(**self.image_processor_dict)
image = self.image_processor_tester.prepare_image_inputs()[0]
trimap = np.random.randint(0, 3, size=image.size[::-1])

with warnings.catch_warnings(record=True) as raised_warnings:
warnings.simplefilter("always")
image_processor(image, trimaps=trimap, extra_argument=True)

self.assertEqual(len(raised_warnings), 1)
self.assertIn("extra_argument", str(raised_warnings[0].message))
26 changes: 26 additions & 0 deletions tests/test_image_processing_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import pathlib
import tempfile
import time
import warnings

import numpy as np
import requests
Expand Down Expand Up @@ -425,15 +426,40 @@ def test_call_numpy_4_channels(self):
)

def test_image_processor_preprocess_arguments(self):
is_tested = False

for image_processing_class in self.image_processor_list:
image_processor = image_processing_class(**self.image_processor_dict)

# validation done by _valid_processor_keys attribute
if hasattr(image_processor, "_valid_processor_keys") and hasattr(image_processor, "preprocess"):
preprocess_parameter_names = inspect.getfullargspec(image_processor.preprocess).args
preprocess_parameter_names.remove("self")
preprocess_parameter_names.sort()
valid_processor_keys = image_processor._valid_processor_keys
valid_processor_keys.sort()
self.assertEqual(preprocess_parameter_names, valid_processor_keys)
is_tested = True

# validation done by @filter_out_non_signature_kwargs decorator
if hasattr(image_processor.preprocess, "_filter_out_non_signature_kwargs"):
if hasattr(self.image_processor_tester, "prepare_image_inputs"):
inputs = self.image_processor_tester.prepare_image_inputs()
elif hasattr(self.image_processor_tester, "prepare_video_inputs"):
inputs = self.image_processor_tester.prepare_video_inputs()
else:
self.skipTest(reason="No valid input preparation method found")

with warnings.catch_warnings(record=True) as raised_warnings:
warnings.simplefilter("always")
image_processor(inputs, extra_argument=True)

self.assertEqual(len(raised_warnings), 1)
self.assertIn("extra_argument", str(raised_warnings[0].message))
is_tested = True

if not is_tested:
self.skipTest(reason="No validation found for `preprocess` method")


class AnnotationFormatTestMixin:
Expand Down

0 comments on commit 7047811

Please sign in to comment.