Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Some added tests for TokenClassificationArgumentHandler #8366

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions src/transformers/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -1333,10 +1333,7 @@ class TokenClassificationArgumentHandler(ArgumentHandler):
def __call__(self, *args, **kwargs):

if args is not None and len(args) > 0:
if isinstance(args, str):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can never be call starargs is always a tuple.

inputs = [args]
else:
inputs = args
inputs = list(args)
batch_size = len(inputs)
else:
raise ValueError("At least one input is required.")
Expand Down
83 changes: 60 additions & 23 deletions tests/test_pipelines_ner.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import unittest

from transformers import AutoTokenizer, pipeline
from transformers.pipelines import Pipeline
from transformers.pipelines import Pipeline, TokenClassificationArgumentHandler
from transformers.testing_utils import require_tf, require_torch

from .test_pipelines_common import CustomInputPipelineCommonMixin
Expand Down Expand Up @@ -107,13 +107,9 @@ def _test_pipeline(self, nlp: Pipeline):
def test_tf_only(self):
model_name = "Narsil/small" # This model only has a TensorFlow version
# We test that if we don't specificy framework='tf', it gets detected automatically
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
nlp = pipeline(task="ner", model=model_name, tokenizer=tokenizer)
nlp = pipeline(task="ner", model=model_name)
self._test_pipeline(nlp)

# offset=tokenizer(VALID_INPUTS[0],return_offsets_mapping=True)['offset_mapping']
# pipeline_running_kwargs = {"offset_mapping"} # Additional kwargs to run the pipeline with

@require_tf
def test_tf_defaults(self):
for model_name in self.small_models:
Expand All @@ -122,9 +118,8 @@ def test_tf_defaults(self):
self._test_pipeline(nlp)

@require_tf
def test_tf_small(self):
def test_tf_small_ignore_subwords_available_for_fast_tokenizers(self):
for model_name in self.small_models:
print(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
nlp = pipeline(
task="ner",
Expand All @@ -136,20 +131,20 @@ def test_tf_small(self):
)
self._test_pipeline(nlp)

for model_name in self.small_models:
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
nlp = pipeline(
task="ner",
model=model_name,
tokenizer=tokenizer,
framework="tf",
grouped_entities=True,
ignore_subwords=False,
)
self._test_pipeline(nlp)
for model_name in self.small_models:
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
nlp = pipeline(
task="ner",
model=model_name,
tokenizer=tokenizer,
framework="tf",
grouped_entities=True,
ignore_subwords=False,
)
self._test_pipeline(nlp)

@require_torch
def test_pt_defaults_slow_tokenizer_raises(self):
def test_pt_ignore_subwords_slow_tokenizer_raises(self):
for model_name in self.small_models:
tokenizer = AutoTokenizer.from_pretrained(model_name)

Expand All @@ -166,12 +161,11 @@ def test_pt_defaults_slow_tokenizer(self):
@require_torch
def test_pt_defaults(self):
for model_name in self.small_models:
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
nlp = pipeline(task="ner", model=model_name, tokenizer=tokenizer)
nlp = pipeline(task="ner", model=model_name)
self._test_pipeline(nlp)

@require_torch
def test_torch_small(self):
def test_pt_small_ignore_subwords_available_for_fast_tokenizers(self):
for model_name in self.small_models:
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
nlp = pipeline(
Expand All @@ -185,3 +179,46 @@ def test_torch_small(self):
task="ner", model=model_name, tokenizer=tokenizer, grouped_entities=True, ignore_subwords=False
)
self._test_pipeline(nlp)


class TokenClassificationArgumentHandlerTestCase(unittest.TestCase):
def setUp(self):
self.args_parser = TokenClassificationArgumentHandler()

def test_simple(self):
string = "This is a simple input"

inputs, offset_mapping = self.args_parser(string)
self.assertEqual(inputs, [string])
self.assertEqual(offset_mapping, None)

inputs, offset_mapping = self.args_parser(string, string)
self.assertEqual(inputs, [string, string])
self.assertEqual(offset_mapping, None)

inputs, offset_mapping = self.args_parser(string, offset_mapping=[(0, 1), (1, 2)])
self.assertEqual(inputs, [string])
self.assertEqual(offset_mapping, [[(0, 1), (1, 2)]])

inputs, offset_mapping = self.args_parser(string, string, offset_mapping=[[(0, 1), (1, 2)], [(0, 2), (2, 3)]])
self.assertEqual(inputs, [string, string])
self.assertEqual(offset_mapping, [[(0, 1), (1, 2)], [(0, 2), (2, 3)]])

def test_errors(self):
string = "This is a simple input"

# 2 sentences, 1 offset_mapping
with self.assertRaises(ValueError):
self.args_parser(string, string, offset_mapping=[[(0, 1), (1, 2)]])

# 2 sentences, 1 offset_mapping
with self.assertRaises(ValueError):
self.args_parser(string, string, offset_mapping=[(0, 1), (1, 2)])

# 1 sentences, 2 offset_mapping
with self.assertRaises(ValueError):
self.args_parser(string, offset_mapping=[[(0, 1), (1, 2)], [(0, 2), (2, 3)]])

# 0 sentences, 1 offset_mapping
with self.assertRaises(ValueError):
self.args_parser(offset_mapping=[[(0, 1), (1, 2)]])