
black
Jesse Swanson authored and Jesse Swanson committed Apr 27, 2021
1 parent df43f26 commit e406685
Showing 2 changed files with 10 additions and 10 deletions.
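The change itself is purely formatting: Black follows PEP 8's slice rule, spacing the colon like a binary operator whenever a slice bound is more than a plain name or literal. A minimal sketch of the rule (illustrative, not taken from the commit):

    # Black/PEP 8 slice spacing: complex bounds get a spaced colon,
    # simple bounds do not, and an omitted bound takes no space on its side.
    tokens = ["a", "b", "c", "d", "e"]
    start, end = 1, 3

    a = tokens[1:4]              # simple literals: no spaces around ":"
    b = tokens[start : end + 1]  # complex bound: space on both sides
    c = tokens[start + 1 :]      # omitted stop: no space after the colon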
10 changes: 5 additions & 5 deletions jiant/tasks/lib/templates/squad_style/core.py
@@ -91,7 +91,7 @@ def to_feature_list(
end_position = self.end_position

# If the answer cannot be found in the text, then skip this example.
actual_text = " ".join(self.doc_tokens[start_position: (end_position + 1)])
actual_text = " ".join(self.doc_tokens[start_position : (end_position + 1)])
cleaned_answer_text = " ".join(whitespace_tokenize(self.answer_text))
if actual_text.find(cleaned_answer_text) == -1:
logger.warning(
@@ -193,7 +193,7 @@ def to_feature_list(
- 1
- encoded_dict["input_ids"][::-1].index(tokenizer.pad_token_id)
)
-non_padded_ids = encoded_dict["input_ids"][last_padding_id_position + 1:]
+non_padded_ids = encoded_dict["input_ids"][last_padding_id_position + 1 :]

else:
non_padded_ids = encoded_dict["input_ids"]
@@ -249,9 +249,9 @@ def to_feature_list(
# Original TF implementation also keep the classification token (set to 0)
p_mask = np.ones_like(span["token_type_ids"])
if tokenizer.padding_side == "right":
-p_mask[len(truncated_query) + sequence_added_tokens:] = 0
+p_mask[len(truncated_query) + sequence_added_tokens :] = 0
else:
p_mask[-len(span["tokens"]): -(len(truncated_query) + sequence_added_tokens)] = 0
p_mask[-len(span["tokens"]) : -(len(truncated_query) + sequence_added_tokens)] = 0

pad_token_indices = np.where(span["input_ids"] == tokenizer.pad_token_id)
special_token_indices = np.asarray(
@@ -538,7 +538,7 @@ def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, orig_ans

for new_start in range(input_start, input_end + 1):
for new_end in range(input_end, new_start - 1, -1):
text_span = " ".join(doc_tokens[new_start: (new_end + 1)])
text_span = " ".join(doc_tokens[new_start : (new_end + 1)])
if text_span == tok_answer_text:
return new_start, new_end

10 changes: 5 additions & 5 deletions jiant/tasks/lib/templates/squad_style/utils.py
@@ -177,10 +177,10 @@ def compute_predictions_logits_v2(
break
feature = features[pred.feature_index]
if pred.start_index > 0: # this is a non-null prediction
-tok_tokens = feature.tokens[pred.start_index: (pred.end_index + 1)]
+tok_tokens = feature.tokens[pred.start_index : (pred.end_index + 1)]
orig_doc_start = feature.token_to_orig_map[pred.start_index]
orig_doc_end = feature.token_to_orig_map[pred.end_index]
-orig_tokens = example.doc_tokens[orig_doc_start: (orig_doc_end + 1)]
+orig_tokens = example.doc_tokens[orig_doc_start : (orig_doc_end + 1)]

tok_text = tokenizer.convert_tokens_to_string(tok_tokens)

@@ -375,10 +375,10 @@ def compute_predictions_logits(
break
feature = features[pred.feature_index]
if pred.start_index > 0: # this is a non-null prediction
-tok_tokens = feature.tokens[pred.start_index: (pred.end_index + 1)]
+tok_tokens = feature.tokens[pred.start_index : (pred.end_index + 1)]
orig_doc_start = feature.token_to_orig_map[pred.start_index]
orig_doc_end = feature.token_to_orig_map[pred.end_index]
-orig_tokens = example.doc_tokens[orig_doc_start: (orig_doc_end + 1)]
+orig_tokens = example.doc_tokens[orig_doc_start : (orig_doc_end + 1)]

tok_text = tokenizer.convert_tokens_to_string(tok_tokens)

@@ -549,7 +549,7 @@ def _strip_spaces(text):
if orig_end_position is None:
return orig_text

-output_text = orig_text[orig_start_position: (orig_end_position + 1)]
+output_text = orig_text[orig_start_position : (orig_end_position + 1)]
return output_text


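All ten changed lines follow from that one slice-spacing rule; they are consistent with simply running Black over the two files (e.g. `black jiant/tasks/lib/templates/squad_style/`), though the exact output can vary with the Black version used.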
