diff --git a/src/transformers/pipelines.py b/src/transformers/pipelines.py
index b2399146a3b601..d9afaa1c12022f 100755
--- a/src/transformers/pipelines.py
+++ b/src/transformers/pipelines.py
@@ -1333,10 +1333,7 @@ class TokenClassificationArgumentHandler(ArgumentHandler):
     def __call__(self, *args, **kwargs):
 
         if args is not None and len(args) > 0:
-            if isinstance(args, str):
-                inputs = [args]
-            else:
-                inputs = args
+            inputs = list(args)
             batch_size = len(inputs)
         else:
             raise ValueError("At least one input is required.")
diff --git a/tests/test_pipelines_ner.py b/tests/test_pipelines_ner.py
index 3eb15e544d0ce4..bc12900d8422c3 100644
--- a/tests/test_pipelines_ner.py
+++ b/tests/test_pipelines_ner.py
@@ -1,7 +1,7 @@
 import unittest
 
 from transformers import AutoTokenizer, pipeline
-from transformers.pipelines import Pipeline
+from transformers.pipelines import Pipeline, TokenClassificationArgumentHandler
 from transformers.testing_utils import require_tf, require_torch
 
 from .test_pipelines_common import CustomInputPipelineCommonMixin
@@ -107,13 +107,9 @@ def _test_pipeline(self, nlp: Pipeline):
     def test_tf_only(self):
         model_name = "Narsil/small"  # This model only has a TensorFlow version
         # We test that if we don't specificy framework='tf', it gets detected automatically
-        tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
-        nlp = pipeline(task="ner", model=model_name, tokenizer=tokenizer)
+        nlp = pipeline(task="ner", model=model_name)
         self._test_pipeline(nlp)
 
-        # offset=tokenizer(VALID_INPUTS[0],return_offsets_mapping=True)['offset_mapping']
-        # pipeline_running_kwargs = {"offset_mapping"}  # Additional kwargs to run the pipeline with
-
     @require_tf
     def test_tf_defaults(self):
         for model_name in self.small_models:
@@ -122,9 +118,8 @@ def test_tf_defaults(self):
             self._test_pipeline(nlp)
 
     @require_tf
-    def test_tf_small(self):
+    def test_tf_small_ignore_subwords_available_for_fast_tokenizers(self):
         for model_name in self.small_models:
-            print(model_name)
             tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
             nlp = pipeline(
                 task="ner",
@@ -136,20 +131,20 @@ def test_tf_small(self):
             )
             self._test_pipeline(nlp)
 
-            for model_name in self.small_models:
-                tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
-                nlp = pipeline(
-                    task="ner",
-                    model=model_name,
-                    tokenizer=tokenizer,
-                    framework="tf",
-                    grouped_entities=True,
-                    ignore_subwords=False,
-                )
-                self._test_pipeline(nlp)
+        for model_name in self.small_models:
+            tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
+            nlp = pipeline(
+                task="ner",
+                model=model_name,
+                tokenizer=tokenizer,
+                framework="tf",
+                grouped_entities=True,
+                ignore_subwords=False,
+            )
+            self._test_pipeline(nlp)
 
     @require_torch
-    def test_pt_defaults_slow_tokenizer_raises(self):
+    def test_pt_ignore_subwords_slow_tokenizer_raises(self):
         for model_name in self.small_models:
             tokenizer = AutoTokenizer.from_pretrained(model_name)
 
@@ -166,12 +161,11 @@ def test_pt_defaults_slow_tokenizer(self):
     @require_torch
     def test_pt_defaults(self):
         for model_name in self.small_models:
-            tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
-            nlp = pipeline(task="ner", model=model_name, tokenizer=tokenizer)
+            nlp = pipeline(task="ner", model=model_name)
             self._test_pipeline(nlp)
 
     @require_torch
-    def test_torch_small(self):
+    def test_pt_small_ignore_subwords_available_for_fast_tokenizers(self):
         for model_name in self.small_models:
             tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
             nlp = pipeline(
@@ -185,3 +179,46 @@ def test_torch_small(self):
                 task="ner", model=model_name, tokenizer=tokenizer, grouped_entities=True, ignore_subwords=False
             )
             self._test_pipeline(nlp)
+
+
+class TokenClassificationArgumentHandlerTestCase(unittest.TestCase):
+    def setUp(self):
+        self.args_parser = TokenClassificationArgumentHandler()
+
+    def test_simple(self):
+        string = "This is a simple input"
+
+        inputs, offset_mapping = self.args_parser(string)
+        self.assertEqual(inputs, [string])
+        self.assertEqual(offset_mapping, None)
+
+        inputs, offset_mapping = self.args_parser(string, string)
+        self.assertEqual(inputs, [string, string])
+        self.assertEqual(offset_mapping, None)
+
+        inputs, offset_mapping = self.args_parser(string, offset_mapping=[(0, 1), (1, 2)])
+        self.assertEqual(inputs, [string])
+        self.assertEqual(offset_mapping, [[(0, 1), (1, 2)]])
+
+        inputs, offset_mapping = self.args_parser(string, string, offset_mapping=[[(0, 1), (1, 2)], [(0, 2), (2, 3)]])
+        self.assertEqual(inputs, [string, string])
+        self.assertEqual(offset_mapping, [[(0, 1), (1, 2)], [(0, 2), (2, 3)]])
+
+    def test_errors(self):
+        string = "This is a simple input"
+
+        # 2 sentences, 1 offset_mapping
+        with self.assertRaises(ValueError):
+            self.args_parser(string, string, offset_mapping=[[(0, 1), (1, 2)]])
+
+        # 2 sentences, 1 offset_mapping (not nested per sentence)
+        with self.assertRaises(ValueError):
+            self.args_parser(string, string, offset_mapping=[(0, 1), (1, 2)])
+
+        # 1 sentence, 2 offset_mappings
+        with self.assertRaises(ValueError):
+            self.args_parser(string, offset_mapping=[[(0, 1), (1, 2)], [(0, 2), (2, 3)]])
+
+        # 0 sentences, 1 offset_mapping
+        with self.assertRaises(ValueError):
+            self.args_parser(offset_mapping=[[(0, 1), (1, 2)]])
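
For reviewers, a minimal usage sketch (not part of the patch) of how the updated `TokenClassificationArgumentHandler` is expected to behave once `inputs = list(args)` is in place. It simply mirrors the new test cases above; the offset values below are arbitrary illustrative numbers, not real tokenizer output.

```python
from transformers.pipelines import TokenClassificationArgumentHandler

handler = TokenClassificationArgumentHandler()

# A single string and several strings are both normalized to a list of inputs.
inputs, offset_mapping = handler("This is a simple input")
assert inputs == ["This is a simple input"] and offset_mapping is None

inputs, offset_mapping = handler("first sentence", "second sentence")
assert inputs == ["first sentence", "second sentence"]

# Pre-computed offset mappings can be passed alongside the sentences,
# one mapping per sentence; per the tests, a count mismatch raises ValueError.
inputs, offset_mapping = handler(
    "first sentence",
    "second sentence",
    offset_mapping=[[(0, 5), (6, 14)], [(0, 6), (7, 15)]],
)
assert offset_mapping == [[(0, 5), (6, 14)], [(0, 6), (7, 15)]]
```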