From f5ac84c7d44c287547fa46d2b1b33a03c616de9a Mon Sep 17 00:00:00 2001 From: Sylvain Gugger Date: Tue, 22 Jun 2021 14:59:42 -0400 Subject: [PATCH] Add possibility to maintain full copies of files --- .../tensorflow/question-answering/utils_qa.py | 2 ++ utils/check_copies.py | 28 +++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/examples/tensorflow/question-answering/utils_qa.py b/examples/tensorflow/question-answering/utils_qa.py index 36d911b9e9acfb..2f8f0a60c45fe5 100644 --- a/examples/tensorflow/question-answering/utils_qa.py +++ b/examples/tensorflow/question-answering/utils_qa.py @@ -38,6 +38,7 @@ def postprocess_qa_predictions( null_score_diff_threshold: float = 0.0, output_dir: Optional[str] = None, prefix: Optional[str] = None, + is_world_process_zero: bool = True, ): """ Post-processes the predictions of a question-answering model to convert them to answers that are substrings of the @@ -90,6 +91,7 @@ def postprocess_qa_predictions( scores_diff_json = collections.OrderedDict() # Logging. + logger.setLevel(logging.INFO if is_world_process_zero else logging.WARN) logger.info(f"Post-processing {len(examples)} example predictions split into {len(features)} features.") # Let's loop over all the examples! diff --git a/utils/check_copies.py b/utils/check_copies.py index c1ed7c1a222995..c9e7514c5b1754 100644 --- a/utils/check_copies.py +++ b/utils/check_copies.py @@ -27,6 +27,9 @@ PATH_TO_DOCS = "docs/source" REPO_PATH = "." 
+# Mapping for files that are full copies of others (keys are copies, values the file to keep them up to date with) +FULL_COPIES = {"examples/tensorflow/question-answering/utils_qa.py": "examples/pytorch/question-answering/utils_qa.py"} + def _should_continue(line, indent): return line.startswith(indent) or len(line) <= 1 or re.search(r"^\s*\):\s*$", line) is not None @@ -192,6 +195,30 @@ def check_copies(overwrite: bool = False): check_model_list_copy(overwrite=overwrite) +def check_full_copies(overwrite: bool = False): + diffs = [] + for target, source in FULL_COPIES.items(): + with open(source, "r", encoding="utf-8") as f: + source_code = f.read() + with open(target, "r", encoding="utf-8") as f: + target_code = f.read() + if source_code != target_code: + if overwrite: + with open(target, "w", encoding="utf-8") as f: + print(f"Replacing the content of {target} by the one of {source}.") + f.write(source_code) + else: + diffs.append(f"- {target}: copy does not match {source}.") + + if not overwrite and len(diffs) > 0: + diff = "\n".join(diffs) + raise Exception( + "Found the following copy inconsistencies:\n" + + diff + + "\nRun `make fix-copies` or `python utils/check_copies.py --fix_and_overwrite` to fix them." + ) + + def get_model_list(): """Extracts the model list from the README.""" # If the introduction or the conclusion of the list change, the prompts may need to be updated. @@ -324,3 +351,4 @@ def check_model_list_copy(overwrite=False, max_per_line=119): args = parser.parse_args() check_copies(args.fix_and_overwrite) + check_full_copies(args.fix_and_overwrite)