From e406685178fca18a44054c7ad8b317ab8ac52b2a Mon Sep 17 00:00:00 2001
From: Jesse Swanson
Date: Mon, 26 Apr 2021 21:52:08 -0400
Subject: [PATCH] black

---
 jiant/tasks/lib/templates/squad_style/core.py  | 10 +++++-----
 jiant/tasks/lib/templates/squad_style/utils.py | 10 +++++-----
 2 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/jiant/tasks/lib/templates/squad_style/core.py b/jiant/tasks/lib/templates/squad_style/core.py
index 685483784..17cb426ad 100644
--- a/jiant/tasks/lib/templates/squad_style/core.py
+++ b/jiant/tasks/lib/templates/squad_style/core.py
@@ -91,7 +91,7 @@ def to_feature_list(
             end_position = self.end_position

             # If the answer cannot be found in the text, then skip this example.
-            actual_text = " ".join(self.doc_tokens[start_position: (end_position + 1)])
+            actual_text = " ".join(self.doc_tokens[start_position : (end_position + 1)])
             cleaned_answer_text = " ".join(whitespace_tokenize(self.answer_text))
             if actual_text.find(cleaned_answer_text) == -1:
                 logger.warning(
@@ -193,7 +193,7 @@ def to_feature_list(
                     - 1
                     - encoded_dict["input_ids"][::-1].index(tokenizer.pad_token_id)
                 )
-                non_padded_ids = encoded_dict["input_ids"][last_padding_id_position + 1:]
+                non_padded_ids = encoded_dict["input_ids"][last_padding_id_position + 1 :]

            else:
                non_padded_ids = encoded_dict["input_ids"]
@@ -249,9 +249,9 @@
            # Original TF implementation also keep the classification token (set to 0)
            p_mask = np.ones_like(span["token_type_ids"])
            if tokenizer.padding_side == "right":
-                p_mask[len(truncated_query) + sequence_added_tokens:] = 0
+                p_mask[len(truncated_query) + sequence_added_tokens :] = 0
            else:
-                p_mask[-len(span["tokens"]): -(len(truncated_query) + sequence_added_tokens)] = 0
+                p_mask[-len(span["tokens"]) : -(len(truncated_query) + sequence_added_tokens)] = 0

            pad_token_indices = np.where(span["input_ids"] == tokenizer.pad_token_id)
            special_token_indices = np.asarray(
@@ -538,7 +538,7 @@ def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, orig_ans

     for new_start in range(input_start, input_end + 1):
         for new_end in range(input_end, new_start - 1, -1):
-            text_span = " ".join(doc_tokens[new_start: (new_end + 1)])
+            text_span = " ".join(doc_tokens[new_start : (new_end + 1)])
             if text_span == tok_answer_text:
                 return new_start, new_end

diff --git a/jiant/tasks/lib/templates/squad_style/utils.py b/jiant/tasks/lib/templates/squad_style/utils.py
index 20ff029b4..45474ce88 100644
--- a/jiant/tasks/lib/templates/squad_style/utils.py
+++ b/jiant/tasks/lib/templates/squad_style/utils.py
@@ -177,10 +177,10 @@ def compute_predictions_logits_v2(
                break
            feature = features[pred.feature_index]
            if pred.start_index > 0:  # this is a non-null prediction
-                tok_tokens = feature.tokens[pred.start_index: (pred.end_index + 1)]
+                tok_tokens = feature.tokens[pred.start_index : (pred.end_index + 1)]
                orig_doc_start = feature.token_to_orig_map[pred.start_index]
                orig_doc_end = feature.token_to_orig_map[pred.end_index]
-                orig_tokens = example.doc_tokens[orig_doc_start: (orig_doc_end + 1)]
+                orig_tokens = example.doc_tokens[orig_doc_start : (orig_doc_end + 1)]

                tok_text = tokenizer.convert_tokens_to_string(tok_tokens)

@@ -375,10 +375,10 @@ def compute_predictions_logits(
                break
            feature = features[pred.feature_index]
            if pred.start_index > 0:  # this is a non-null prediction
-                tok_tokens = feature.tokens[pred.start_index: (pred.end_index + 1)]
+                tok_tokens = feature.tokens[pred.start_index : (pred.end_index + 1)]
                orig_doc_start = feature.token_to_orig_map[pred.start_index]
                orig_doc_end = feature.token_to_orig_map[pred.end_index]
-                orig_tokens = example.doc_tokens[orig_doc_start: (orig_doc_end + 1)]
+                orig_tokens = example.doc_tokens[orig_doc_start : (orig_doc_end + 1)]

                tok_text = tokenizer.convert_tokens_to_string(tok_tokens)

@@ -549,7 +549,7 @@ def _strip_spaces(text):
    if orig_end_position is None:
        return orig_text

-    output_text = orig_text[orig_start_position: (orig_end_position + 1)]
+    output_text = orig_text[orig_start_position : (orig_end_position + 1)]
    return output_text
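
Note on the change: every hunk in this patch is the same mechanical black fix. PEP 8 treats the colon in a slice like a binary operator and asks for equal spacing on either side when a bound is a complex expression, so black rewrites `x[a + 1:]` as `x[a + 1 :]`. Below is a minimal runnable sketch of the rule on toy data; the variable names only echo the patched code and are illustrative stand-ins, not imports from jiant:

    # Toy stand-ins for jiant's doc_tokens / start_position / end_position.
    doc_tokens = ["the", "quick", "brown", "fox"]
    start_position, end_position = 1, 2

    # Pre-patch spacing: no space before the slice colon.
    actual_text = " ".join(doc_tokens[start_position: (end_position + 1)])

    # Post-patch (black) spacing: symmetric, since both bounds are expressions.
    actual_text = " ".join(doc_tokens[start_position : (end_position + 1)])

    assert actual_text == "quick brown"  # doc_tokens[1:3] == ["quick", "brown"]

Both spellings compile to identical bytecode; the patch is formatting-only, which matches the diffstat (10 insertions, 10 deletions, all whitespace).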