Commit

flake8
Jesse Swanson authored and Jesse Swanson committed Apr 27, 2021
1 parent 28fea64 commit df43f26
Showing 2 changed files with 17 additions and 15 deletions.
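Both files receive style-only fixes: whitespace before the slice colon is removed (pycodestyle E203), over-long lines are wrapped (E501), and two docstrings are reflowed. A minimal sketch of re-checking the touched files, assuming flake8 is installed and using its documented legacy Python API (the project's own lint configuration is not part of this commit):

    # Sketch only: re-run flake8 over the two files touched by this commit.
    from flake8.api import legacy as flake8

    style_guide = flake8.get_style_guide()
    report = style_guide.check_files(
        [
            "jiant/tasks/lib/templates/squad_style/core.py",
            "jiant/tasks/lib/templates/squad_style/utils.py",
        ]
    )
    print(report.total_errors)  # count of remaining style violations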
13 changes: 7 additions & 6 deletions jiant/tasks/lib/templates/squad_style/core.py
@@ -91,7 +91,7 @@ def to_feature_list(
end_position = self.end_position

# If the answer cannot be found in the text, then skip this example.
actual_text = " ".join(self.doc_tokens[start_position : (end_position + 1)])
actual_text = " ".join(self.doc_tokens[start_position: (end_position + 1)])
cleaned_answer_text = " ".join(whitespace_tokenize(self.answer_text))
if actual_text.find(cleaned_answer_text) == -1:
logger.warning(
@@ -193,7 +193,7 @@ def to_feature_list(
- 1
- encoded_dict["input_ids"][::-1].index(tokenizer.pad_token_id)
)
- non_padded_ids = encoded_dict["input_ids"][last_padding_id_position + 1 :]
+ non_padded_ids = encoded_dict["input_ids"][last_padding_id_position + 1:]

else:
non_padded_ids = encoded_dict["input_ids"]
@@ -244,13 +244,14 @@ def to_feature_list(
# Identify the position of the CLS token
cls_index = span["input_ids"].index(tokenizer.cls_token_id)

- # p_mask: mask with 1 for token than cannot be in the answer (0 for token which can be in an answer)
+ # p_mask: mask with 1 for token than cannot be in the answer (0 for token
+ # which can be in an answer)
# Original TF implementation also keep the classification token (set to 0)
p_mask = np.ones_like(span["token_type_ids"])
if tokenizer.padding_side == "right":
- p_mask[len(truncated_query) + sequence_added_tokens :] = 0
+ p_mask[len(truncated_query) + sequence_added_tokens:] = 0
else:
- p_mask[-len(span["tokens"]) : -(len(truncated_query) + sequence_added_tokens)] = 0
+ p_mask[-len(span["tokens"]): -(len(truncated_query) + sequence_added_tokens)] = 0

pad_token_indices = np.where(span["input_ids"] == tokenizer.pad_token_id)
special_token_indices = np.asarray(
@@ -537,7 +538,7 @@ def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, orig_ans

for new_start in range(input_start, input_end + 1):
for new_end in range(input_end, new_start - 1, -1):
- text_span = " ".join(doc_tokens[new_start : (new_end + 1)])
+ text_span = " ".join(doc_tokens[new_start: (new_end + 1)])
if text_span == tok_answer_text:
return new_start, new_end

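The slice changes in core.py above all follow one pattern: pycodestyle reports E203 ("whitespace before ':'") when a slice colon is preceded by a space, so the space is dropped; the slicing behaviour itself is unchanged. A small self-contained illustration of the rewrite, using the diff's variable names but invented data:

    # Toy reproduction of the slice-spacing fix; data is made up for illustration.
    doc_tokens = ["The", "quick", "brown", "fox"]
    start_position, end_position = 1, 2

    # Before: the space before ':' is what flake8/pycodestyle flags as E203.
    # actual_text = " ".join(doc_tokens[start_position : (end_position + 1)])

    # After: no whitespace before the colon; the result is identical.
    actual_text = " ".join(doc_tokens[start_position: (end_position + 1)])

    assert actual_text == "quick brown"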
19 changes: 10 additions & 9 deletions jiant/tasks/lib/templates/squad_style/utils.py
@@ -35,20 +35,21 @@ class SquadResult:


class ExplicitEnum(Enum):
"""
Enum with more explicit error message for missing values.
"""Enum with more explicit error message for missing values.
"""

@classmethod
def _missing_(cls, value):
raise ValueError(
f"{value} is not a valid {cls.__name__}, please select one of {list(cls._value2member_map_.keys())}"
f"{value} is not a valid {cls.__name__}, please select one \
of {list(cls._value2member_map_.keys())}"
)


class TruncationStrategy(ExplicitEnum):
"""
- Possible values for the ``truncation`` argument in :meth:`PreTrainedTokenizerBase.__call__`. Useful for
+ Possible values for the ``truncation`` argument in
+ :meth:`PreTrainedTokenizerBase.__call__`. Useful for
tab-completion in an IDE.
"""

@@ -176,10 +177,10 @@ def compute_predictions_logits_v2(
break
feature = features[pred.feature_index]
if pred.start_index > 0: # this is a non-null prediction
- tok_tokens = feature.tokens[pred.start_index : (pred.end_index + 1)]
+ tok_tokens = feature.tokens[pred.start_index: (pred.end_index + 1)]
orig_doc_start = feature.token_to_orig_map[pred.start_index]
orig_doc_end = feature.token_to_orig_map[pred.end_index]
- orig_tokens = example.doc_tokens[orig_doc_start : (orig_doc_end + 1)]
+ orig_tokens = example.doc_tokens[orig_doc_start: (orig_doc_end + 1)]

tok_text = tokenizer.convert_tokens_to_string(tok_tokens)

@@ -374,10 +375,10 @@ def compute_predictions_logits(
break
feature = features[pred.feature_index]
if pred.start_index > 0: # this is a non-null prediction
- tok_tokens = feature.tokens[pred.start_index : (pred.end_index + 1)]
+ tok_tokens = feature.tokens[pred.start_index: (pred.end_index + 1)]
orig_doc_start = feature.token_to_orig_map[pred.start_index]
orig_doc_end = feature.token_to_orig_map[pred.end_index]
- orig_tokens = example.doc_tokens[orig_doc_start : (orig_doc_end + 1)]
+ orig_tokens = example.doc_tokens[orig_doc_start: (orig_doc_end + 1)]

tok_text = tokenizer.convert_tokens_to_string(tok_tokens)

@@ -548,7 +549,7 @@ def _strip_spaces(text):
if orig_end_position is None:
return orig_text

- output_text = orig_text[orig_start_position : (orig_end_position + 1)]
+ output_text = orig_text[orig_start_position: (orig_end_position + 1)]
return output_text


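The utils.py changes combine the same slice-spacing fix with line wrapping for E501. The ValueError message above is wrapped with a backslash continuation inside the f-string, which keeps the source line short but makes the next line's leading whitespace part of the message; an alternative, sketched below on the assumption that the rest of the class stays as in the diff, is implicit concatenation of adjacent f-string literals:

    # Sketch of an alternative E501 fix for the error message; not what this commit does.
    from enum import Enum


    class ExplicitEnum(Enum):
        """Enum with more explicit error message for missing values."""

        @classmethod
        def _missing_(cls, value):
            # Adjacent f-string literals are implicitly joined, so the message
            # stays on short source lines with no stray indentation inside it.
            raise ValueError(
                f"{value} is not a valid {cls.__name__}, please select one "
                f"of {list(cls._value2member_map_.keys())}"
            )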
