Style fixes
francescorubbo committed Apr 25, 2021
1 parent 324f641 commit 7fcfc4e
Showing 3 changed files with 21 additions and 20 deletions.
3 changes: 2 additions & 1 deletion src/transformers/file_utils.py
@@ -1619,7 +1619,8 @@ def _get_module(self, module_name: str) -> ModuleType:

class AggregationStrategy(ExplicitEnum):
    """
-    Possible values for the ``aggregation_strategy`` argument in :meth:`TokenClassificationPipeline.__init__`. Useful for tab-completion in an IDE.
+    Possible values for the ``aggregation_strategy`` argument in :meth:`TokenClassificationPipeline.__init__`. Useful
+    for tab-completion in an IDE.
    """

    FIRST = "first"
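For orientation, a minimal sketch of how this enum is meant to surface in the pipeline API. The `aggregation_strategy` keyword and the example sentence are illustrative assumptions, not part of this diff:

    # Hypothetical usage sketch: assumes the NER pipeline forwards the
    # `aggregation_strategy` kwarg and that "first" maps to AggregationStrategy.FIRST.
    from transformers import pipeline

    ner = pipeline("ner", aggregation_strategy="first")
    print(ner("My name is Wolfgang and I live in Berlin"))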
24 changes: 12 additions & 12 deletions src/transformers/pipelines/token_classification.py
@@ -2,7 +2,7 @@

import numpy as np

-from ..file_utils import add_end_docstrings, is_tf_available, is_torch_available, AggregationStrategy
+from ..file_utils import AggregationStrategy, add_end_docstrings, is_tf_available, is_torch_available
from ..modelcard import ModelCard
from ..models.bert.tokenization_bert import BasicTokenizer
from ..tokenization_utils import PreTrainedTokenizer
@@ -22,7 +22,7 @@

from ..models.auto.modeling_auto import MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING


class TokenClassificationArgumentHandler(ArgumentHandler):
"""
Handles arguments for token classification.
@@ -191,23 +191,23 @@ def __call__(self, inputs: Union[str, List[str]], **kwargs):
entity = {}
if offset_mapping is not None:
    start_ind, end_ind = offset_mapping[idx]
-    entity['start'], entity['end'] = (start_ind, end_ind)
+    entity["start"], entity["end"] = (start_ind, end_ind)
    word_ref = sentence[start_ind:end_ind]
    word = self.tokenizer.convert_ids_to_tokens([int(input_ids[idx])])[0]
-    entity['word'] = word
-    entity['is_subword'] = len(word_ref) != len(word)
+    entity["word"] = word
+    entity["is_subword"] = len(word_ref) != len(word)

    if int(input_ids[idx]) == self.tokenizer.unk_token_id:
-        entity['word'] = word_ref
-        entity['is_subword'] = False
+        entity["word"] = word_ref
+        entity["is_subword"] = False
else:
-    entity['word'] = self.tokenizer.convert_ids_to_tokens(int(input_ids[idx]))
+    entity["word"] = self.tokenizer.convert_ids_to_tokens(int(input_ids[idx]))

-    entity['start'] = None
-    entity['end'] = None
+    entity["start"] = None
+    entity["end"] = None

-entity['score'] = score[idx]
-entity['index'] = idx
+entity["score"] = score[idx]
+entity["index"] = idx

entities += [entity]

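The hunk above only swaps quote style, so the logic is unchanged; to make the `is_subword` heuristic concrete, here is a toy sketch with made-up values, where `word_ref` is the raw text behind the token and `word` is the tokenizer's piece:

    # Toy values, not taken from the diff: a WordPiece continuation token.
    sentence = "Accenture"
    start_ind, end_ind = 2, 5                 # pretend offset_mapping[idx] == (2, 5)
    word_ref = sentence[start_ind:end_ind]    # "cen", the raw characters behind the token
    word = "##cen"                            # the tokenizer piece covering the same span
    is_subword = len(word_ref) != len(word)   # True: the "##" prefix skews the length

A leading piece such as "Ac" has matching lengths and is treated as a word start; only continuation pieces differ, and the UNK branch above forces `is_subword` to False because no reliable piece exists to compare.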
14 changes: 7 additions & 7 deletions tests/test_pipelines_ner.py
@@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-import unittest
-
import json
import os
+import unittest
+
import numpy as np

from transformers import AutoTokenizer, pipeline
@@ -27,6 +27,7 @@

VALID_INPUTS = ["A simple string", ["list of strings", "A simple string that is quite a bit longer"]]

+
class NerPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
pipeline_task = "ner"
small_models = [
@@ -562,16 +563,15 @@ def _test_pipeline(self, nlp: Pipeline):
],
]

-
expected_aligned_results_filepath = os.path.join(
-    os.path.dirname(os.path.abspath(__file__)),
-    "fixtures/ner_pipeline_aligned.json")
+    os.path.dirname(os.path.abspath(__file__)), "fixtures/ner_pipeline_aligned.json"
+)
with open(expected_aligned_results_filepath) as expected_aligned_results_file:
expected_aligned_results = json.load(expected_aligned_results_file)

expected_aligned_results_w_subword_filepath = os.path.join(
-    os.path.dirname(os.path.abspath(__file__)),
-    "fixtures/ner_pipeline_aligned_w_subwords.json")
+    os.path.dirname(os.path.abspath(__file__)), "fixtures/ner_pipeline_aligned_w_subwords.json"
+)
with open(expected_aligned_results_w_subword_filepath) as expected_aligned_results_w_subword_file:
expected_aligned_results_w_subword = json.load(expected_aligned_results_w_subword_file)

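Both hunks above only reflow the `os.path.join(...)` calls onto single lines; the fixtures they load would then back assertions further down in `_test_pipeline`. A hedged sketch of that use, relying only on the `self` and `nlp` names from the method signature in the hunk header:

    # Assumed continuation inside _test_pipeline: compare pipeline output
    # against the expected results loaded from the JSON fixtures.
    aligned_results = nlp(VALID_INPUTS[0])
    self.assertEqual(aligned_results, expected_aligned_results)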
