diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py
index 51408df74158c5..b0d82a92b8489f 100644
--- a/src/transformers/file_utils.py
+++ b/src/transformers/file_utils.py
@@ -1619,7 +1619,8 @@ def _get_module(self, module_name: str) -> ModuleType:
 
 class AggregationStrategy(ExplicitEnum):
     """
-    Possible values for the ``aggregation_strategy`` argument in :meth:`TokenClassificationPipeline.__init__`. Useful for tab-completion in an IDE.
+    Possible values for the ``aggregation_strategy`` argument in :meth:`TokenClassificationPipeline.__init__`. Useful
+    for tab-completion in an IDE.
     """
 
     FIRST = "first"
diff --git a/src/transformers/pipelines/token_classification.py b/src/transformers/pipelines/token_classification.py
index 23a67f13bb1f3e..f4254dc6a808a0 100644
--- a/src/transformers/pipelines/token_classification.py
+++ b/src/transformers/pipelines/token_classification.py
@@ -2,7 +2,7 @@
 
 import numpy as np
 
-from ..file_utils import add_end_docstrings, is_tf_available, is_torch_available, AggregationStrategy
+from ..file_utils import AggregationStrategy, add_end_docstrings, is_tf_available, is_torch_available
 from ..modelcard import ModelCard
 from ..models.bert.tokenization_bert import BasicTokenizer
 from ..tokenization_utils import PreTrainedTokenizer
@@ -22,7 +22,7 @@
     from ..models.auto.modeling_auto import MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
 
 
-
+
 class TokenClassificationArgumentHandler(ArgumentHandler):
     """
     Handles arguments for token classification.
@@ -191,23 +191,23 @@ def __call__(self, inputs: Union[str, List[str]], **kwargs):
                 entity = {}
                 if offset_mapping is not None:
                     start_ind, end_ind = offset_mapping[idx]
-                    entity['start'], entity['end'] = (start_ind, end_ind)
+                    entity["start"], entity["end"] = (start_ind, end_ind)
                     word_ref = sentence[start_ind:end_ind]
                     word = self.tokenizer.convert_ids_to_tokens([int(input_ids[idx])])[0]
-                    entity['word'] = word
-                    entity['is_subword'] = len(word_ref) != len(word)
+                    entity["word"] = word
+                    entity["is_subword"] = len(word_ref) != len(word)
 
                     if int(input_ids[idx]) == self.tokenizer.unk_token_id:
-                        entity['word'] = word_ref
-                        entity['is_subword'] = False
+                        entity["word"] = word_ref
+                        entity["is_subword"] = False
                 else:
-                    entity['word'] = self.tokenizer.convert_ids_to_tokens(int(input_ids[idx]))
+                    entity["word"] = self.tokenizer.convert_ids_to_tokens(int(input_ids[idx]))
 
-                    entity['start'] = None
-                    entity['end'] = None
+                    entity["start"] = None
+                    entity["end"] = None
 
-                entity['score'] = score[idx]
-                entity['index'] = idx
+                entity["score"] = score[idx]
+                entity["index"] = idx
 
                 entities += [entity]
 
diff --git a/tests/test_pipelines_ner.py b/tests/test_pipelines_ner.py
index f9025b540ee739..67dc6117014350 100644
--- a/tests/test_pipelines_ner.py
+++ b/tests/test_pipelines_ner.py
@@ -12,10 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import unittest
-
 import json
 import os
+import unittest
+
 import numpy as np
 
 from transformers import AutoTokenizer, pipeline
@@ -27,6 +27,7 @@
 
 VALID_INPUTS = ["A simple string", ["list of strings", "A simple string that is quite a bit longer"]]
 
+
 class NerPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
     pipeline_task = "ner"
     small_models = [
@@ -562,16 +563,15 @@ def _test_pipeline(self, nlp: Pipeline):
             ],
         ]
 
-
         expected_aligned_results_filepath = os.path.join(
-            os.path.dirname(os.path.abspath(__file__)),
-            "fixtures/ner_pipeline_aligned.json")
+            os.path.dirname(os.path.abspath(__file__)), "fixtures/ner_pipeline_aligned.json"
+        )
         with open(expected_aligned_results_filepath) as expected_aligned_results_file:
             expected_aligned_results = json.load(expected_aligned_results_file)
 
         expected_aligned_results_w_subword_filepath = os.path.join(
-            os.path.dirname(os.path.abspath(__file__)),
-            "fixtures/ner_pipeline_aligned_w_subwords.json")
+            os.path.dirname(os.path.abspath(__file__)), "fixtures/ner_pipeline_aligned_w_subwords.json"
+        )
         with open(expected_aligned_results_w_subword_filepath) as expected_aligned_results_w_subword_file:
             expected_aligned_results_w_subword = json.load(expected_aligned_results_w_subword_file)
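For context on what the `token_classification.py` hunk above is building: each per-token `entity` dict records character offsets (`start`/`end`), the token string (`word`), and an `is_subword` flag obtained by comparing the token string with the span it covers in the original sentence. Below is a minimal standalone sketch of that comparison, not part of the patch; the checkpoint, example sentence, and variable names are illustrative assumptions, and the real pipeline additionally attaches model scores and label information.

```python
# Minimal sketch only: mirrors the subword-detection idea from the diff above, outside
# the pipeline. The checkpoint and sentence are illustrative assumptions, not from the PR.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased", use_fast=True)
sentence = "Hugging Face is based in New York City"

# A fast tokenizer can return character offsets for every token via return_offsets_mapping.
encoded = tokenizer(sentence, return_offsets_mapping=True, add_special_tokens=False)

entities = []
for idx, (input_id, (start_ind, end_ind)) in enumerate(
    zip(encoded["input_ids"], encoded["offset_mapping"])
):
    entity = {"start": start_ind, "end": end_ind, "index": idx}
    word_ref = sentence[start_ind:end_ind]  # surface form taken from the original text
    word = tokenizer.convert_ids_to_tokens([input_id])[0]
    entity["word"] = word
    # A token counts as a subword when its string (e.g. a "##"-prefixed WordPiece) does
    # not have the same length as the span it covers in the sentence.
    entity["is_subword"] = len(word_ref) != len(word)
    if input_id == tokenizer.unk_token_id:
        # Unknown tokens fall back to the raw surface form and are never flagged as subwords.
        entity["word"] = word_ref
        entity["is_subword"] = False
    entities.append(entity)

print(entities[:4])
```

Running this against a fast tokenizer makes it easy to eyeball which tokens the alignment logic would treat as continuations of the previous word before any model scores enter the picture.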