From 627f44799a9f4948a6a1b8fe9e536eee0e29ea68 Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Wed, 10 May 2023 03:34:48 +0900 Subject: [PATCH] [Doctests] Refactor doctests + add CI (#22987) * intiial commit * new styling * update * just run doctest in CI * remove more test for fast dev * update * update refs * update path and fetch upstream * update documentatyion trests * typo * parse pwd * don't check for files that are in hidden folders * just give paths relative to transformers * update * update * update * major refactoring * make sure options is ok * lest test that mdx is tested * doctest glob * nits * update doctest nightly * some cleaning * run correct test on diff * debug * run on a single worker * skip_cuda_test tampkate * updates * add rA and continue on failure * test options * parse `py` codeblock? * we don't need to replace ignore results, don't remember whyu I put it * cleanup * more cleaning * fix arg * more cleaning * clean an todo * more pre-processing * doctest-module has none so extra `- ` is needed * remove logs * nits * doctest-modules .... * oups * let's use sugar * make dataset go quiet * add proper timeout * nites * spleling timeout * update * properly skip tests that have CUDSA * proper skipping * cleaning main and get tests to run * remove make report? * remove tee * some updates * tee was removed but is the full output still available? * [all-test] * only our tests * don't touch tee in this PR * no atee-sys * proper sub * monkey * only replace call * fix sub * nits * nits * fix invalid syntax * add skip cuda doctest env variable * make sure all packages are installed * move file * update check repo * revert changes * nit * finish cleanup * fix re * findall * update don't test init files * ignore pycache * `-ignore-pycache` when running pytests * try to fix the import missmatch error * install dec * pytest is required as doctest_utils imports things from it * the only log issues were dataset, ignore results should work * more cleaning * Update .circleci/create_circleci_config.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * [ydshieh] empty string if cuda is found * [ydshieh] fix condition * style * [ydshieh] fix * Add comment * style * style * show failure * trigger CI --------- Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Co-authored-by: ydshieh --- .circleci/config.yml | 1 + .circleci/create_circleci_config.py | 85 ++++++++++- .github/workflows/doctests.yml | 8 - Makefile | 8 +- conftest.py | 12 +- docs/source/en/testing.mdx | 12 +- setup.cfg | 1 + src/transformers/utils/__init__.py | 1 + src/transformers/utils/doctest_utils.py | 189 ++++++++++++++++++++++++ utils/prepare_for_doc_test.py | 148 ------------------- 10 files changed, 287 insertions(+), 178 deletions(-) create mode 100644 src/transformers/utils/doctest_utils.py delete mode 100644 utils/prepare_for_doc_test.py diff --git a/.circleci/config.yml b/.circleci/config.yml index 63c6162fc15dfb..9af64904580321 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -166,6 +166,7 @@ jobs: - v0.6-repository_consistency - run: pip install --upgrade pip - run: pip install .[all,quality] + - run: pip install pytest - save_cache: key: 
v0.5-repository_consistency-{{ checksum "setup.py" }}
         paths:

diff --git a/.circleci/create_circleci_config.py b/.circleci/create_circleci_config.py
index 7208d876a97c3a..30898f9e1c2a4d 100644
--- a/.circleci/create_circleci_config.py
+++ b/.circleci/create_circleci_config.py
@@ -51,6 +51,8 @@ class CircleCIJob:
     resource_class: Optional[str] = "xlarge"
     tests_to_run: Optional[List[str]] = None
     working_directory: str = "~/transformers"
+    # This should be only used for the doctest job!
+    command_timeout: Optional[int] = None
 
     def __post_init__(self):
         # Deal with defaults for mutable attributes.
@@ -107,11 +109,15 @@ def to_dict(self):
             steps.append({"store_artifacts": {"path": "~/transformers/installed.txt"}})
 
         all_options = {**COMMON_PYTEST_OPTIONS, **self.pytest_options}
-        pytest_flags = [f"--{key}={value}" if value is not None else f"-{key}" for key, value in all_options.items()]
+        pytest_flags = [f"--{key}={value}" if (value is not None or key in ["doctest-modules"]) else f"-{key}" for key, value in all_options.items()]
         pytest_flags.append(
             f"--make-reports={self.name}" if "examples" in self.name else f"--make-reports=tests_{self.name}"
         )
-        test_command = f"python -m pytest -n {self.pytest_num_workers} " + " ".join(pytest_flags)
+        test_command = ""
+        if self.command_timeout:
+            test_command = f"timeout {self.command_timeout} "
+        test_command += f"python -m pytest -n {self.pytest_num_workers} " + " ".join(pytest_flags)
+
         if self.parallelism == 1:
             if self.tests_to_run is None:
                 test_command += " << pipeline.parameters.tests_to_run >>"
@@ -161,12 +167,37 @@ def to_dict(self):
             steps.append({"store_artifacts": {"path": "~/transformers/tests.txt"}})
             steps.append({"store_artifacts": {"path": "~/transformers/splitted_tests.txt"}})
 
-            test_command = f"python -m pytest -n {self.pytest_num_workers} " + " ".join(pytest_flags)
+            test_command = ""
+            if self.command_timeout:
+                test_command = f"timeout {self.command_timeout} "
+            test_command += f"python -m pytest -n {self.pytest_num_workers} " + " ".join(pytest_flags)
             test_command += " $(cat splitted_tests.txt)"
             if self.marker is not None:
                 test_command += f" -m {self.marker}"
-            test_command += " | tee tests_output.txt"
+
+            if self.name == "pr_documentation_tests":
+                # can't use ` | tee tests_output.txt` as usual
+                test_command += " > tests_output.txt"
+                # Save the return code, so we can check whether it timed out in the next step.
+                test_command += '; touch "$?".txt'
+                # Never fail the test step for the doctest job. We will check the results in the next step, and fail
+                # that step instead if actual test failures are found. This is to avoid the timeout being reported as
+                # a test failure.
+                test_command = f"({test_command}) || true"
+            else:
+                test_command += " | tee tests_output.txt"
             steps.append({"run": {"name": "Run tests", "command": test_command}})
+
+            # return code `124` means the previous (pytest run) step timed out
+            if self.name == "pr_documentation_tests":
+                checkout_doctest_command = 'if [ -s reports/tests_pr_documentation_tests/failures_short.txt ]; '
+                checkout_doctest_command += 'then echo "some test failed"; '
+                checkout_doctest_command += 'cat reports/tests_pr_documentation_tests/failures_short.txt; '
+                checkout_doctest_command += 'cat reports/tests_pr_documentation_tests/summary_short.txt; exit -1; '
+                checkout_doctest_command += 'elif [ -s reports/tests_pr_documentation_tests/stats.txt ]; then echo "All tests pass!"; '
+                checkout_doctest_command += 'elif [ -f 124.txt ]; then echo "doctest timeout!"; else echo "other fatal error"; exit -1; fi;'
+                steps.append({"run": {"name": "Check doctest results", "command": checkout_doctest_command}})
+
         steps.append({"store_artifacts": {"path": "~/transformers/tests_output.txt"}})
         steps.append({"store_artifacts": {"path": "~/transformers/reports"}})
         job["steps"] = steps
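For the doctest job, the command assembled above wraps pytest in GNU `timeout`, redirects output instead of piping through `tee`, records the exit code in a file, and never fails the step itself; the separate "Check doctest results" step then inspects the report files (GNU `timeout` exits with status 124 when the limit is hit, hence the `124.txt` check). A minimal sketch of that assembly, using a made-up helper name and illustrative flag values rather than the real `to_dict` method:

```python
# Illustrative sketch only: `build_doctest_command` is a made-up name, not part of the patch.
# It mirrors how the shell command for `pr_documentation_tests` is pieced together above.
def build_doctest_command(command_timeout: int, num_workers: int, pytest_flags: list) -> str:
    command = ""
    if command_timeout:
        # GNU `timeout` exits with status 124 when the wrapped command exceeds the limit.
        command = f"timeout {command_timeout} "
    command += f"python -m pytest -n {num_workers} " + " ".join(pytest_flags)
    # Redirect instead of `| tee`, keep the exit code in a file named after it (e.g. `124.txt`),
    # and never fail this step: a later step reads the reports to decide pass / fail / timeout.
    command += " > tests_output.txt"
    command += '; touch "$?".txt'
    return f"({command}) || true"


print(build_doctest_command(1200, 1, ["--doctest-modules", "--doctest-glob=*.mdx"]))
```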
+ test_command = f"({test_command}) || true" + else: + test_command += " | tee tests_output.txt" steps.append({"run": {"name": "Run tests", "command": test_command}}) + + # return code `124` means the previous (pytest run) step is timeout + if self.name == "pr_documentation_tests": + checkout_doctest_command = 'if [ -s reports/tests_pr_documentation_tests/failures_short.txt ]; ' + checkout_doctest_command += 'then echo "some test failed"; ' + checkout_doctest_command += 'cat reports/tests_pr_documentation_tests/failures_short.txt; ' + checkout_doctest_command += 'cat reports/tests_pr_documentation_tests/summary_short.txt; exit -1; ' + checkout_doctest_command += 'elif [ -s reports/tests_pr_documentation_tests/stats.txt ]; then echo "All tests pass!"; ' + checkout_doctest_command += 'elif [ -f 124.txt ]; then echo "doctest timeout!"; else echo "other fatal error)"; exit -1; fi;' + steps.append({"run": {"name": "Check doctest results", "command": checkout_doctest_command}}) + steps.append({"store_artifacts": {"path": "~/transformers/tests_output.txt"}}) steps.append({"store_artifacts": {"path": "~/transformers/reports"}}) job["steps"] = steps @@ -401,6 +432,51 @@ def job_name(self): tests_to_run="tests/repo_utils", ) +# At this moment, only the files that are in `utils/documentation_tests.txt` will be kept (together with a dummy file). +py_command = 'import os; import json; fp = open("pr_documentation_tests.txt"); data_1 = fp.read().strip().split("\\n"); fp = open("utils/documentation_tests.txt"); data_2 = fp.read().strip().split("\\n"); to_test = [x for x in data_1 if x in set(data_2)] + ["dummy.py"]; to_test = " ".join(to_test); print(to_test)' +py_command = f"$(python3 -c '{py_command}')" +command = f'echo "{py_command}" > pr_documentation_tests_filtered.txt' +doc_test_job = CircleCIJob( + "pr_documentation_tests", + additional_env={"TRANSFORMERS_VERBOSITY": "error", "DATASETS_VERBOSITY": "error", "SKIP_CUDA_DOCTEST": "1"}, + install_steps=[ + "sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev espeak-ng time", + "pip install --upgrade pip", + "pip install -e .[dev]", + "pip install git+https://github.com/huggingface/accelerate", + "pip install --upgrade pytest pytest-sugar", + "find -name __pycache__ -delete", + "find . -name \*.pyc -delete", + # Add an empty file to keep the test step running correctly even no file is selected to be tested. + "touch dummy.py", + { + "name": "Get files to test", + "command": + "git remote add upstream https://github.com/huggingface/transformers.git && git fetch upstream \n" + "git diff --name-only --relative --diff-filter=AMR refs/remotes/upstream/main...HEAD | grep -E '\.(py|mdx)$' | grep -Ev '^\..*|/\.' 
 REGULAR_TESTS = [
     torch_and_tf_job,
     torch_and_flax_job,
@@ -411,6 +487,7 @@ def job_name(self):
     hub_job,
     onnx_job,
     exotic_models_job,
+    doc_test_job
 ]
 EXAMPLES_TESTS = [
     examples_torch_job,

diff --git a/.github/workflows/doctests.yml b/.github/workflows/doctests.yml
index a0efe40cbbe982..55c09b1acc829b 100644
--- a/.github/workflows/doctests.yml
+++ b/.github/workflows/doctests.yml
@@ -37,18 +37,10 @@ jobs:
       - name: Show installed libraries and their versions
         run: pip freeze
 
-      - name: Prepare files for doctests
-        run: |
-          python3 utils/prepare_for_doc_test.py src docs
-
       - name: Run doctests
         run: |
           python3 -m pytest -v --make-reports doc_tests_gpu --doctest-modules $(cat utils/documentation_tests.txt) -sv --doctest-continue-on-failure --doctest-glob="*.mdx"
 
-      - name: Clean files after doctests
-        run: |
-          python3 utils/prepare_for_doc_test.py src docs --remove_new_line
-
       - name: Failure short reports
         if: ${{ failure() }}
         continue-on-error: true

diff --git a/Makefile b/Makefile
index 5e5a11a1fee033..d6d6966a1dadfd 100644
--- a/Makefile
+++ b/Makefile
@@ -47,10 +47,10 @@ repo-consistency:
 # this target runs checks on all files
 
 quality:
-	black --check $(check_dirs) setup.py
+	black --check $(check_dirs) setup.py conftest.py
 	python utils/custom_init_isort.py --check_only
 	python utils/sort_auto_mappings.py --check_only
-	ruff $(check_dirs) setup.py
+	ruff $(check_dirs) setup.py conftest.py
 	doc-builder style src/transformers docs/source --max_len 119 --check_only --path_to_docs docs/source
 	python utils/check_doc_toc.py
 
@@ -65,8 +65,8 @@ extra_style_checks:
 # this target runs checks on all files and potentially modifies some of them
 
 style:
-	black $(check_dirs) setup.py
-	ruff $(check_dirs) setup.py --fix
+	black $(check_dirs) setup.py conftest.py
+	ruff $(check_dirs) setup.py conftest.py --fix
 	${MAKE} autogenerate_code
 	${MAKE} extra_style_checks

diff --git a/conftest.py b/conftest.py
index 53efec7a6c2d20..c57fac2b1d9cb6 100644
--- a/conftest.py
+++ b/conftest.py
@@ -20,6 +20,10 @@
 import warnings
 from os.path import abspath, dirname, join
 
+import _pytest
+
+from transformers.utils.doctest_utils import HfDoctestModule, HfDocTestParser
+
 # allow having multiple repository checkouts and not needing to remember to rerun
 # 'pip install -e .[dev]' when switching between checkouts and running tests.
@@ -38,9 +42,7 @@ def pytest_configure(config):
     config.addinivalue_line(
         "markers", "is_pt_flax_cross_test: mark test to run only when PT and FLAX interactions are tested"
     )
-    config.addinivalue_line(
-        "markers", "is_pipeline_test: mark test to run only when pipelines are tested"
-    )
+    config.addinivalue_line("markers", "is_pipeline_test: mark test to run only when pipelines are tested")
     config.addinivalue_line("markers", "is_staging_test: mark test to run only in the staging environment")
     config.addinivalue_line("markers", "accelerate_tests: mark test that require accelerate")
 
@@ -66,7 +68,7 @@ def pytest_sessionfinish(session, exitstatus):
 
 
 # Doctest custom flag to ignore output.
-IGNORE_RESULT = doctest.register_optionflag('IGNORE_RESULT')
+IGNORE_RESULT = doctest.register_optionflag("IGNORE_RESULT")
 
 OutputChecker = doctest.OutputChecker
 
@@ -79,3 +81,5 @@ def check_output(self, want, got, optionflags):
 
 
 doctest.OutputChecker = CustomOutputChecker
+_pytest.doctest.DoctestModule = HfDoctestModule
+doctest.DocTestParser = HfDocTestParser
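The conftest.py hunks above register a custom `IGNORE_RESULT` doctest flag and swap pytest's doctest machinery for the HF versions. The body of `CustomOutputChecker.check_output` only appears as unchanged context in this diff, so the following is a typical implementation of such a checker (a sketch, not a verbatim copy of the file), shown to make clear why a `# doctest: +IGNORE_RESULT` marker suppresses output comparison:

```python
# Sketch of the pieces conftest.py relies on; the checker body below is a typical
# implementation of an "ignore output" doctest flag, not copied from the repository.
import doctest

IGNORE_RESULT = doctest.register_optionflag("IGNORE_RESULT")

OutputChecker = doctest.OutputChecker


class CustomOutputChecker(OutputChecker):
    def check_output(self, want, got, optionflags):
        # When an example is tagged with `# doctest: +IGNORE_RESULT`, accept any output.
        if IGNORE_RESULT & optionflags:
            return True
        return OutputChecker.check_output(self, want, got, optionflags)


doctest.OutputChecker = CustomOutputChecker
```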
diff --git a/docs/source/en/testing.mdx b/docs/source/en/testing.mdx
index 4663b8ac4d9338..5adbc8e44db796 100644
--- a/docs/source/en/testing.mdx
+++ b/docs/source/en/testing.mdx
@@ -212,20 +212,12 @@ Example:
 ```"""
 ```
 
-3 steps are required to debug the docstring examples:
-1. In order to properly run the test, **an extra line has to be added** at the end of the docstring. This can be automatically done on any file using:
-```bash
-python utils/prepare_for_doc_test.py
-```
-2. Then, you can use the following line to automatically test every docstring example in the desired file:
+Just run the following line to automatically test every docstring example in the desired file:
 ```bash
 pytest --doctest-modules <path_to_file_or_directory>
 ```
-3. Once you are done debugging, you need to remove the extra line added in step **1.** by running the following:
-```bash
-python utils/prepare_for_doc_test.py --remove_new_line
-```
+If the file has a markdown extension, you should add the `--doctest-glob="*.mdx"` argument.
 
 ### Run only modified tests
 
diff --git a/setup.cfg b/setup.cfg
index 5f47c5c6be699a..8b84d3a6d9b93b 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,2 +1,3 @@
 [tool:pytest]
 doctest_optionflags=NUMBER NORMALIZE_WHITESPACE ELLIPSIS
+doctest_glob=**/*.mdx
\ No newline at end of file
diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py
index 42e856d9e4acb3..0600eb382818d9 100644
--- a/src/transformers/utils/__init__.py
+++ b/src/transformers/utils/__init__.py
@@ -27,6 +27,7 @@
     copy_func,
     replace_return_docstrings,
 )
+from .doctest_utils import HfDocTestParser
 from .generic import (
     ContextManagers,
     ExplicitEnum,
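For reference, the options enabled in `setup.cfg` above make doctest output comparison fairly tolerant. The docstring below is illustrative (not from the repository) and shows what each flag allows when the file is run with `pytest --doctest-modules`:

```python
# Illustrative module showing what the doctest options enabled in setup.cfg tolerate.
# NUMBER and the *.mdx glob handling are provided by pytest's doctest integration.
import math


def scaled_pi(factor):
    """
    >>> scaled_pi(2)            # NUMBER: floats only have to match the precision written here
    6.28
    >>> print("a", "b", "c")    # NORMALIZE_WHITESPACE: runs of whitespace compare equal
    a  b  c
    >>> list(range(20))         # ELLIPSIS: `...` stands for any omitted output
    [0, 1, 2, ..., 19]
    """
    return factor * math.pi
```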
+""" +Utils to run the documentation tests without having to overwrite any files. + +The `preprocess_string` function adds `# doctest: +IGNORE_RESULT` markers on the fly anywhere a `load_dataset` call is +made as a print would otherwise fail the corresonding line. + +To skip cuda tests, make sure to call `SKIP_CUDA_DOCTEST=1 pytest --doctest-modules +""" +import doctest +import inspect +import os +import re +from typing import Iterable + +from _pytest.doctest import ( + Module, + _get_checker, + _get_continue_on_failure, + _get_runner, + _is_mocked, + _patch_unwrap_mock_aware, + get_optionflags, + import_path, +) +from _pytest.outcomes import skip +from pytest import DoctestItem + + +def preprocess_string(string, skip_cuda_tests): + """Prepare a docstring or a `.mdx` file to be run by doctest. + + The argument `string` would be the whole file content if it is a `.mdx` file. For a python file, it would be one of + its docstring. In each case, it may contain multiple python code examples. If `skip_cuda_tests` is `True` and a + cuda stuff is detective (with a heuristic), this method will return an empty string so no doctest will be run for + `string`. + """ + codeblock_pattern = r"(```(?:python|py)\s*\n\s*>>> )((?:.*?\n)*?.*?```)" + codeblocks = re.split(re.compile(codeblock_pattern, flags=re.MULTILINE | re.DOTALL), string) + is_cuda_found = False + for i, codeblock in enumerate(codeblocks): + if "load_dataset(" in codeblock and "# doctest: +IGNORE_RESULT" not in codeblock: + codeblocks[i] = re.sub(r"(>>> .*load_dataset\(.*)", r"\1 # doctest: +IGNORE_RESULT", codeblock) + if ( + (">>>" in codeblock or "..." in codeblock) + and re.search(r"cuda|to\(0\)|device=0", codeblock) + and skip_cuda_tests + ): + is_cuda_found = True + break + modified_string = "" + if not is_cuda_found: + modified_string = "".join(codeblocks) + return modified_string + + +class HfDocTestParser(doctest.DocTestParser): + """ + Overwrites the DocTestParser from doctest to properly parse the codeblocks that are formatted with black. This + means that there are no extra lines at the end of our snippets. The `# doctest: +IGNORE_RESULT` marker is also + added anywhere a `load_dataset` call is made as a print would otherwise fail the corresponding line. + + Tests involving cuda are skipped base on a naive pattern that should be updated if it is not enough. + """ + + # This regular expression is used to find doctest examples in a + # string. It defines three groups: `source` is the source code + # (including leading indentation and prompts); `indent` is the + # indentation of the first (PS1) line of the source code; and + # `want` is the expected output (including leading indentation). + # fmt: off + _EXAMPLE_RE = re.compile(r''' + # Source consists of a PS1 line followed by zero or more PS2 lines. + (?P + (?:^(?P [ ]*) >>> .*) # PS1 line + (?:\n [ ]* \.\.\. .*)*) # PS2 lines + \n? + # Want consists of any non-blank lines that do not start with PS1. + (?P (?:(?![ ]*$) # Not a blank line + (?![ ]*>>>) # Not a line starting with PS1 + # !!!!!!!!!!! HF Specific !!!!!!!!!!! + (?:(?!```).)* # Match any character except '`' until a '```' is found (this is specific to HF because black removes the last line) + # !!!!!!!!!!! HF Specific !!!!!!!!!!! + (?:\n|$) # Match a new line or end of string + )*) + ''', re.MULTILINE | re.VERBOSE + ) + # fmt: on + + # !!!!!!!!!!! HF Specific !!!!!!!!!!! + skip_cuda_tests: bool = bool(os.environ.get("SKIP_CUDA_DOCTEST", False)) + # !!!!!!!!!!! HF Specific !!!!!!!!!!! 
+
+class HfDoctestModule(Module):
+    """
+    Overwrites the `DoctestModule` of the pytest package to make sure the HfDocTestParser is used when discovering
+    tests.
+    """
+
+    def collect(self) -> Iterable[DoctestItem]:
+        class MockAwareDocTestFinder(doctest.DocTestFinder):
+            """A hackish doctest finder that overrides stdlib internals to fix a stdlib bug.
+
+            https://github.com/pytest-dev/pytest/issues/3456 https://bugs.python.org/issue25532
+            """
+
+            def _find_lineno(self, obj, source_lines):
+                """Doctest code does not take into account `@property`, this
+                is a hackish way to fix it. https://bugs.python.org/issue17446
+
+                Wrapped Doctests will need to be unwrapped so the correct line number is returned. This will be
+                reported upstream. #8796
+                """
+                if isinstance(obj, property):
+                    obj = getattr(obj, "fget", obj)
+
+                if hasattr(obj, "__wrapped__"):
+                    # Get the main obj in case of it being wrapped
+                    obj = inspect.unwrap(obj)
+
+                # Type ignored because this is a private function.
+                return super()._find_lineno(  # type:ignore[misc]
+                    obj,
+                    source_lines,
+                )
+
+            def _find(self, tests, obj, name, module, source_lines, globs, seen) -> None:
+                if _is_mocked(obj):
+                    return
+                with _patch_unwrap_mock_aware():
+                    # Type ignored because this is a private function.
+                    super()._find(  # type:ignore[misc]
+                        tests, obj, name, module, source_lines, globs, seen
+                    )
+
+        if self.path.name == "conftest.py":
+            module = self.config.pluginmanager._importconftest(
+                self.path,
+                self.config.getoption("importmode"),
+                rootpath=self.config.rootpath,
+            )
+        else:
+            try:
+                module = import_path(
+                    self.path,
+                    root=self.config.rootpath,
+                    mode=self.config.getoption("importmode"),
+                )
+            except ImportError:
+                if self.config.getvalue("doctest_ignore_import_errors"):
+                    skip("unable to import module %r" % self.path)
+                else:
+                    raise
+
+        # !!!!!!!!!!! HF Specific !!!!!!!!!!!
+        finder = MockAwareDocTestFinder(parser=HfDocTestParser())
+        # !!!!!!!!!!! HF Specific !!!!!!!!!!!
+        optionflags = get_optionflags(self)
+        runner = _get_runner(
+            verbose=False,
+            optionflags=optionflags,
+            checker=_get_checker(),
+            continue_on_failure=_get_continue_on_failure(self.config),
+        )
+        for test in finder.find(module, module.__name__):
+            if test.examples:  # skip empty doctests and cuda
+                yield DoctestItem.from_parent(self, name=test.name, runner=runner, dtest=test)

diff --git a/utils/prepare_for_doc_test.py b/utils/prepare_for_doc_test.py
deleted file mode 100644
index c55f3540d99414..00000000000000
--- a/utils/prepare_for_doc_test.py
+++ /dev/null
@@ -1,148 +0,0 @@
-# coding=utf-8
-# Copyright 2022 The HuggingFace Inc. team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-""" Style utils to preprocess files for doc tests.
- - The doc precossing function can be run on a list of files and/org - directories of files. It will recursively check if the files have - a python code snippet by looking for a ```python or ```py syntax. - In the default mode - `remove_new_line==False` the script will - add a new line before every python code ending ``` line to make - the docstrings ready for pytest doctests. - However, we don't want to have empty lines displayed in the - official documentation which is why the new line command can be - reversed by adding the flag `--remove_new_line` which sets - `remove_new_line==True`. - - When debugging the doc tests locally, please make sure to - always run: - - ```python utils/prepare_for_doc_test.py src docs``` - - before running the doc tests: - - ```pytest --doctest-modules $(cat utils/documentation_tests.txt) -sv --doctest-continue-on-failure --doctest-glob="*.mdx"``` - - Afterwards you should revert the changes by running - - ```python utils/prepare_for_doc_test.py src docs --remove_new_line``` -""" - -import argparse -import os - - -def process_code_block(code, add_new_line=True): - if add_new_line: - return maybe_append_new_line(code) - else: - return maybe_remove_new_line(code) - - -def maybe_append_new_line(code): - """ - Append new line if code snippet is a - Python code snippet - """ - lines = code.split("\n") - - if lines[0] in ["py", "python"]: - # add new line before last line being ``` - last_line = lines[-1] - lines.pop() - lines.append("\n" + last_line) - - return "\n".join(lines) - - -def maybe_remove_new_line(code): - """ - Remove new line if code snippet is a - Python code snippet - """ - lines = code.split("\n") - - if lines[0] in ["py", "python"]: - # add new line before last line being ``` - lines = lines[:-2] + lines[-1:] - - return "\n".join(lines) - - -def process_doc_file(code_file, add_new_line=True): - """ - Process given file. - - Args: - code_file (`str` or `os.PathLike`): The file in which we want to style the docstring. - """ - with open(code_file, "r", encoding="utf-8", newline="\n") as f: - code = f.read() - - # fmt: off - splits = code.split("```") - if len(splits) % 2 != 1: - raise ValueError("The number of occurrences of ``` should be an even number.") - - splits = [s if i % 2 == 0 else process_code_block(s, add_new_line=add_new_line) for i, s in enumerate(splits)] - clean_code = "```".join(splits) - # fmt: on - - diff = clean_code != code - if diff: - print(f"Overwriting content of {code_file}.") - with open(code_file, "w", encoding="utf-8", newline="\n") as f: - f.write(clean_code) - - -def process_doc_files(*files, add_new_line=True): - """ - Applies doc styling or checks everything is correct in a list of files. - - Args: - files (several `str` or `os.PathLike`): The files to treat. - Whether to restyle file or just check if they should be restyled. - - Returns: - List[`str`]: The list of files changed or that should be restyled. 
- """ - for file in files: - # Treat folders - if os.path.isdir(file): - files = [os.path.join(file, f) for f in os.listdir(file)] - files = [f for f in files if os.path.isdir(f) or f.endswith(".mdx") or f.endswith(".py")] - process_doc_files(*files, add_new_line=add_new_line) - else: - try: - process_doc_file(file, add_new_line=add_new_line) - except Exception: - print(f"There is a problem in {file}.") - raise - - -def main(*files, add_new_line=True): - process_doc_files(*files, add_new_line=add_new_line) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("files", nargs="+", help="The file(s) or folder(s) to restyle.") - parser.add_argument( - "--remove_new_line", - action="store_true", - help="Whether to remove new line after each python code block instead of adding one.", - ) - args = parser.parse_args() - - main(*args.files, add_new_line=not args.remove_new_line)