-
-
Notifications
You must be signed in to change notification settings - Fork 42
/
__main__.py
712 lines (600 loc) · 21 KB
/
__main__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
"""Run third-party tool (e.g. :code:`mypy`) against notebook or directory."""
import json
import os
import re
import subprocess
import sys
import tempfile
from importlib import import_module
from pathlib import Path
from textwrap import dedent
from typing import (
Any,
Dict,
Iterator,
Mapping,
MutableMapping,
NamedTuple,
Optional,
Sequence,
Set,
Tuple,
)
import tomli
from pkg_resources import parse_version
from nbqa import replace_source, save_code_source, save_markdown_source
from nbqa.cmdline import CLIArgs
from nbqa.config.config import Configs, get_default_config
from nbqa.find_root import find_project_root
from nbqa.notebook_info import NotebookInfo
from nbqa.optional import metadata
from nbqa.output_parser import Output, map_python_line_to_nb_lines
from nbqa.path_utils import get_relative_and_absolute_paths, remove_suffix
from nbqa.save_code_source import CODE_SEPARATOR
from nbqa.text import BOLD, RESET
# Template used to report a notebook that nbQA could not process. The doubled
# braces survive the f-string, so ``notebook`` and ``exp`` are filled in later
# through ``str.format``.
BASE_ERROR_MESSAGE = (
    f'{BOLD}nbQA failed to process {{notebook}} with exception "{{exp}}"{RESET}\n'
)
# Oldest supported version of a third-party tool, keyed by command name.
MIN_VERSIONS = {"isort": "5.3.0"}
# Link included in the "command not found" message.
VIRTUAL_ENVIRONMENTS_URL = (
    "https://realpython.com/python-virtual-environments-a-primer/"
)
# Regular expression matched against resolved notebook paths; notebooks inside
# any of these directories are skipped when a directory is searched.
EXCLUDES = (
    r"/("
    r"\.direnv|\.eggs|\.git|\.hg|\.ipynb_checkpoints|\.mypy_cache|\.nox|\.svn|\.tox|\.venv|"
    r"_build|buck-out|build|dist|venv"
    r")/"
)
# Function used to write tool results back, keyed by a boolean (indexed with
# the ``diff`` setting where it is used).
REPLACE_FUNCTION = {
    True: replace_source.diff,
    False: replace_source.mutate,
}
# Suffix given to temporary files, keyed by a boolean (indexed with the ``md``
# flag where it is used).
SUFFIX = {False: ".py", True: ".md"}
# Commands whose Python module name differs from the command name.
COMMAND_TO_PYTHON_MODULE = {"blacken-docs": "blacken_docs"}
class TemporaryFile(NamedTuple):
    """File descriptor and path of one temporary file, as a named pair."""

    # File descriptor of the temporary file (first item of ``tempfile.mkstemp``).
    fd: int
    # Path of the temporary file.
    file: str
class SavedSources(NamedTuple):
    """Outcome of saving notebook sources to temporary files."""

    # Per-notebook information for every notebook that was saved successfully.
    nb_info_mapping: Mapping[str, NotebookInfo]
    # Notebooks that could not be processed, mapped to the ``repr`` of the error.
    failed_notebooks: MutableMapping[str, str]
    # Notebooks whose metadata marks them as written in another language.
    non_python_notebooks: Set[str]
class UnsupportedPackageVersionError(Exception):
    """Signal that the installed tool is older than the minimum supported version."""

    def __init__(self, command: str, current_version: str, min_version: str) -> None:
        """Compose the message from the command, its installed and its minimum version."""
        requirement = f"{BOLD}nbqa only works with {command} >= {min_version}, "
        installed = f"while you have {current_version} installed.{RESET}"
        # The message is also kept on the instance as ``msg``.
        self.msg = requirement + installed
        super().__init__(self.msg)
def _get_notebooks(root_dir: str) -> Iterator[Path]:
"""
Get generator with all notebooks in directory.
Parameters
----------
root_dir
Notebook or directory to run third-party tool on.
Returns
-------
notebooks
All Jupyter Notebooks found in directory.
"""
if not os.path.isdir(root_dir):
return iter((Path(root_dir),))
return (
i
for i in Path(root_dir).rglob("*.ipynb")
if not re.search(EXCLUDES, str(i.resolve().as_posix()))
)
def _filter_by_include_exclude(
notebooks: Iterator[Path],
include: Optional[str],
exclude: Optional[str],
) -> Iterator[str]:
"""
Include files which match include, exclude those matching exclude.
notebooks
Notebooks (not directories) to run code quality tool on.
include:
Global file include pattern.
exclude:
Global file exclude pattern.
Returns
-------
Iterator
Notebooks matching include and not matching exclude.
"""
include = include or ""
exclude = exclude or "^$"
include_re, exclude_re = re.compile(include), re.compile(exclude)
return (
str(notebook)
for notebook in notebooks
if include_re.search(str(notebook.as_posix()))
if not exclude_re.search(str(notebook.as_posix()))
)
def _get_all_notebooks(
    root_dirs: Sequence[str], files: Optional[str], exclude: Optional[str]
) -> Iterator[str]:
    """
    Collect the notebooks named on the command line and apply include/exclude.

    Parameters
    ----------
    root_dirs
        All the notebooks/directories passed in via the command-line.
    files
        Pattern of files to include.
    exclude
        Pattern of files to exclude.

    Returns
    -------
    Iterator
        All Jupyter Notebooks found in all passed directories/notebooks that
        pass the include and exclude patterns.
    """
    # Notebooks are produced in the order of ``root_dirs``.
    discovered = (
        notebook for root_dir in root_dirs for notebook in _get_notebooks(root_dir)
    )
    return _filter_by_include_exclude(discovered, files, exclude)
def _replace_temp_python_file_references_in_out_err(
    temp_python_file: str,
    notebook: str,
    out: str,
    err: str,
    *,
    md: bool,
) -> Output:
    """
    Make captured tool output mention the notebook instead of the temporary file.

    Parameters
    ----------
    temp_python_file
        Temporary file the notebook was converted to.
    notebook
        Original Jupyter notebook.
    out
        Captured stdout from third-party tool.
    err
        Captured stderr from third-party tool.
    md
        Selects which temporary-file suffix (``SUFFIX[md]``) is stripped to
        obtain the temporary file's stem.

    Returns
    -------
    Output
        Stdout and stderr with the temporary file's base name (and its stem)
        replaced by the notebook's base name (and its stem).
    """
    tmp_name = os.path.basename(temp_python_file)
    nb_name = os.path.basename(notebook)
    tmp_stem = remove_suffix(tmp_name, SUFFIX[md])
    nb_stem = remove_suffix(nb_name, ".ipynb")

    # Full file names are swapped first, then any remaining bare stems.
    for old, new in ((tmp_name, nb_name), (tmp_stem, nb_stem)):
        out = out.replace(old, new)
        err = err.replace(old, new)
    return Output(out, err)
def _get_mtimes(arg: str) -> Set[float]:
"""
Get the modification times of any converted notebooks.
Parameters
----------
arg
Notebook to run 3rd party tool on.
Returns
-------
Set
Modification times of any converted notebooks.
"""
return {os.path.getmtime(arg)}
def _run_command(
    command: str,
    cmd_args: Sequence[str],
    args: Sequence[str],
) -> Tuple[Output, int, bool]:
    """
    Run the third-party tool as ``python -m <module>`` on the given files.

    Parameters
    ----------
    command
        Third-party tool (e.g. :code:`mypy`) to run.
    cmd_args
        Flags to pass to third-party tool (e.g. :code:`--verbose`); they are
        placed after the file arguments.
    args
        Files the third-party tool is run on.

    Returns
    -------
    output
        Captured stdout, stderr from running third-party tool.
    output_code
        Return code from third-party tool.
    mutated
        Whether the modification time of any file in ``args`` changed.
    """
    python_module = COMMAND_TO_PYTHON_MODULE.get(command, command)

    env = os.environ.copy()
    if command == "mypy":
        # Only set the variable when the caller's environment has not already.
        env.setdefault("MYPY_FORCE_COLOR", "1")

    mtimes_before = [_get_mtimes(path) for path in args]
    completed = subprocess.run(
        [sys.executable, "-m", python_module, *args, *cmd_args],
        capture_output=True,
        text=True,
        env=env,
    )
    mtimes_after = [_get_mtimes(path) for path in args]

    return (
        Output(completed.stdout, completed.stderr),
        completed.returncode,
        mtimes_after != mtimes_before,
    )
def _get_command_not_found_msg(command: str) -> str:
    """Return the message to display when the command is not found by nbqa.

    Parameters
    ----------
    command : str
        Command passed to nbqa to find.

    Returns
    -------
    str
        Message telling the user to install ``command`` in the same Python
        environment as nbqa.
    """
    # The f-string is already fully interpolated. It must not be passed through
    # ``str.format`` again: a command containing ``{`` or ``}`` would make
    # that call raise instead of producing the message.
    return dedent(
        f"""\
        {BOLD}Command `{command}` not found by nbqa.{RESET}
        Please make sure you have it installed in the same Python environment as nbqa. See
        e.g. {VIRTUAL_ENVIRONMENTS_URL} for how to set up
        a virtual environment in Python, and run:
        `python -m pip install {command}`.
        """
    )
def _get_configs(cli_args: CLIArgs, project_root: Path) -> Configs:
    """
    Assemble the configuration to use for the requested third-party tool.

    Settings are layered: the defaults come first, then the ``[tool.nbqa]``
    table of ``pyproject.toml`` in ``project_root`` (if that file exists), then
    values given on the command line.

    Parameters
    ----------
    cli_args
        Commandline arguments passed to nbqa
    project_root
        Directory in which ``pyproject.toml`` is looked up.

    Returns
    -------
    Configs
        Defaults, overridden by ``pyproject.toml``, overridden by the CLI.
    """
    config = get_default_config()

    # Layer 1: per-command values from pyproject.toml.
    pyproject_path = project_root / "pyproject.toml"
    if pyproject_path.is_file():
        pyproject = tomli.loads(pyproject_path.read_text("utf-8"))
        if "tool" in pyproject and "nbqa" in pyproject["tool"]:
            nbqa_table = pyproject["tool"]["nbqa"]
            for section in config:
                if section not in nbqa_table:
                    continue
                if cli_args.command in nbqa_table[section]:
                    # TypedDict key must be a string literal
                    config[section] = nbqa_table[section][cli_args.command]  # type: ignore

    # Layer 2: values passed on the command line.
    for section in config:
        cli_value = getattr(cli_args, section)
        if cli_value is None:
            continue
        if section == "addopts":
            # addopts from the CLI are appended to the configured ones rather
            # than replacing them
            config["addopts"] = (*config["addopts"], *cli_value)
        else:
            # TypedDict key must be a string literal
            config[section] = cli_value  # type: ignore

    # isort always gets these extra options, after any user-supplied ones.
    if cli_args.command == "isort":
        config["addopts"] = (
            *config["addopts"],
            "--treat-comment-as-code",
            CODE_SEPARATOR.rstrip("\n"),
        )

    return config
def _clean_up_tmp_files(nb_to_py_mapping: Mapping[str, Tuple[int, str]]) -> None:
"""Remove temporary files."""
for file_descriptor, tmp_path in nb_to_py_mapping.values():
try:
os.close(file_descriptor)
except OSError:
# was already closed
pass
os.remove(tmp_path)
def _get_nb_to_tmp_mapping(
    root_dirs: Sequence[str], files: Optional[str], exclude: Optional[str], md: bool
) -> Dict[str, TemporaryFile]:
    """
    Create one temporary file per notebook and return the association.

    Parameters
    ----------
    root_dirs
        All the notebooks/directories passed in via the command-line.
    files
        Pattern of files to include.
    exclude
        Pattern of files to exclude.
    md
        Selects the suffix (``SUFFIX[md]``) of the temporary files.

    Returns
    -------
    Dict[str, TemporaryFile]
        Mapping between notebooks and temporary files; each temporary file is
        created next to its notebook and recorded by its relative path.

    Raises
    ------
    FileNotFoundError
        If notebook isn't found. Temporary files created so far are removed
        before raising.
    """
    mapping: Dict[str, TemporaryFile] = {}
    for notebook in _get_all_notebooks(root_dirs, files, exclude):
        if not os.path.exists(notebook):
            _clean_up_tmp_files(mapping)
            raise FileNotFoundError(
                f"{BOLD}No such file or directory: {notebook}{RESET}\n"
            )

        # The temporary file sits in the notebook's directory and starts with
        # the notebook's stem.
        descriptor, tmp_path = tempfile.mkstemp(
            dir=os.path.dirname(notebook),
            prefix=remove_suffix(os.path.basename(notebook), ".ipynb"),
            suffix=SUFFIX[md],
        )
        relative_path, _ = get_relative_and_absolute_paths(tmp_path)
        mapping[notebook] = TemporaryFile(descriptor, relative_path)
    return mapping
def _print_failed_notebook_errors(failed_notebooks: Mapping[str, str]) -> None:
    """Write one error line per failed notebook to stderr, then a bug-report hint."""
    stream = sys.stderr
    stream.write("\n")
    for notebook, exception_repr in failed_notebooks.items():
        message = BASE_ERROR_MESSAGE.format(notebook=notebook, exp=exception_repr)
        stream.write(message)
    bug_report_hint = (
        f"{BOLD}\n"
        "If you believe the notebook(s) to be valid, please "
        f"report a bug at https://github.com/nbQA-dev/nbQA/issues {RESET}\n"
    )
    stream.write(bug_report_hint)
    stream.write("\n")
def _is_non_python_notebook(notebook: MutableMapping[str, Any]) -> bool:
"""
If notebook is marked as non-Python, don't format it.
All notebook metadata fields are optional, see
https://nbformat.readthedocs.io/en/latest/format_description.html. So
if a notebook has empty metadata, we will try to parse it anyway.
"""
language = notebook.get("metadata", {}).get("language_info", {}).get("name", None)
return language is not None and language != "python"
def _save_code_sources(
    nb_to_py_mapping: Dict[str, TemporaryFile],
    process_cells: Sequence[str],
    skip_celltags: Sequence[str],
    dont_skip_bad_cells: bool,
    command: str,
) -> SavedSources:
    """
    Write the code of each notebook to its temporary file.

    Notebooks that raise while being parsed or saved are recorded as failed
    (with the ``repr`` of the exception); notebooks marked as non-Python are
    recorded separately and not saved.
    """
    failures = {}
    non_python = set()
    saved: MutableMapping[str, NotebookInfo] = {}

    for notebook, (descriptor, _) in nb_to_py_mapping.items():
        # Reading happens outside the ``try``: an unreadable file propagates
        # instead of being recorded as a failed notebook.
        with open(str(notebook), encoding="utf-8") as handle:
            raw = handle.read()

        try:
            parsed = json.loads(raw)
            if _is_non_python_notebook(parsed):
                non_python.add(notebook)
                continue
            saved[notebook] = save_code_source.main(
                parsed,
                descriptor,
                process_cells,
                command,
                skip_celltags,
                dont_skip_bad_cells=dont_skip_bad_cells,
            )
        except Exception as exc:  # pylint: disable=W0703
            failures[notebook] = repr(exc)

    return SavedSources(saved, failures, non_python)
def _save_markdown_sources(
    nb_to_md_mapping: Dict[str, TemporaryFile],
    process_cells: Sequence[str],  # pylint: disable=W0613
    skip_celltags: Sequence[str],
    dont_skip_bad_cells: bool,  # pylint: disable=W0613
    command: str,  # pylint: disable=W0613
) -> SavedSources:
    """
    Write the markdown of each notebook to its temporary file.

    Notebooks that raise while being parsed or saved are recorded as failed
    (with the ``repr`` of the exception). ``process_cells``,
    ``dont_skip_bad_cells`` and ``command`` are accepted but not used.
    """
    failures = {}
    saved: MutableMapping[str, NotebookInfo] = {}

    for notebook, (descriptor, _) in nb_to_md_mapping.items():
        # Reading happens outside the ``try``: an unreadable file propagates
        # instead of being recorded as a failed notebook.
        with open(str(notebook), encoding="utf-8") as handle:
            raw = handle.read()

        try:
            parsed = json.loads(raw)
            saved[notebook] = save_markdown_source.main(
                parsed,
                descriptor,
                skip_celltags,
            )
        except Exception as exc:  # pylint: disable=W0703
            failures[notebook] = repr(exc)

    # The set of non-Python notebooks is always empty in markdown mode.
    return SavedSources(saved, failures, set())
# Function used to save notebook sources, keyed by a boolean (indexed with the
# ``md`` setting where it is used): code sources for False, markdown for True.
SAVE_SOURCES = {False: _save_code_sources, True: _save_markdown_sources}
def _post_process_notebooks(  # pylint: disable=R0913
    saved_sources: SavedSources,
    nb_to_py_mapping: Mapping[str, TemporaryFile],
    mutated: bool,
    diff: bool,
    command: str,
    output: Output,
    *,
    md: bool,
) -> Tuple[bool, Output]:
    """
    Rewrite tool output to refer to notebooks and write changes back.

    Failed and non-Python notebooks are skipped. For every other notebook the
    captured output is rewritten; if ``mutated`` is set, the function chosen
    by ``REPLACE_FUNCTION[diff]`` is applied as well. A notebook for which
    that function raises is added to ``saved_sources.failed_notebooks``.

    Returns
    -------
    Tuple[bool, Output]
        Whether any replace call reported a change, and the rewritten output.
    """
    actually_mutated = False
    for notebook, (_, temp_file) in nb_to_py_mapping.items():
        is_failed = notebook in saved_sources.failed_notebooks
        if is_failed or notebook in saved_sources.non_python_notebooks:
            continue

        nb_info = saved_sources.nb_info_mapping[notebook]
        output = _replace_temp_python_file_references_in_out_err(
            temp_file, notebook, output.out, output.err, md=md
        )
        output = map_python_line_to_nb_lines(
            command,
            output.out,
            output.err,
            notebook,
            nb_info.cell_mappings,
        )

        if not mutated:
            continue
        try:
            replaced = REPLACE_FUNCTION[diff](temp_file, notebook, nb_info, md=md)
        except Exception as exc:  # pylint: disable=W0703
            saved_sources.failed_notebooks[notebook] = repr(exc)
        else:
            actually_mutated = replaced or actually_mutated
    return actually_mutated, output
def _main(cli_args: CLIArgs, configs: Configs) -> int:
    """
    Run third-party tool on the notebooks and directories given on the command line.

    Parameters
    ----------
    cli_args
        Commandline arguments passed to nbqa.
    configs
        Configuration passed to nbqa from commandline or via a config file

    Returns
    -------
    int
        1 if a given path does not exist, 0 if no notebooks were found, 123 if
        any notebook failed to process, the result of ``int(mutated)`` in diff
        mode, and otherwise the third-party tool's return code (reset to 0
        when files were touched but no notebook actually changed).
    """
    try:
        nb_to_tmp_mapping = _get_nb_to_tmp_mapping(
            cli_args.root_dirs, configs["files"], configs["exclude"], configs["md"]
        )
    except FileNotFoundError as exc:
        sys.stderr.write(str(exc))
        return 1

    try:  # pylint: disable=R0912
        if not nb_to_tmp_mapping:
            searched_dirs = " ".join(i for i in cli_args.root_dirs if os.path.isdir(i))
            sys.stderr.write(
                f"No .ipynb notebooks found in given directories: {searched_dirs}\n"
            )
            return 0

        save_sources = SAVE_SOURCES[configs["md"]]
        saved_sources = save_sources(
            nb_to_tmp_mapping,
            configs["process_cells"],
            configs["skip_celltags"],
            configs["dont_skip_bad_cells"],
            cli_args.command,
        )

        if len(saved_sources.failed_notebooks) == len(nb_to_tmp_mapping):
            sys.stderr.write("No valid .ipynb notebooks found\n")
            _print_failed_notebook_errors(saved_sources.failed_notebooks)
            return 123

        # Only temporary files of notebooks that were not recorded as failed
        # are handed to the tool.
        files_to_check = [
            tmp.file
            for notebook, tmp in nb_to_tmp_mapping.items()
            if notebook not in saved_sources.failed_notebooks
        ]
        output, output_code, mutated = _run_command(
            cli_args.command,
            configs["addopts"],
            files_to_check,
        )

        actually_mutated, output = _post_process_notebooks(
            saved_sources,
            nb_to_tmp_mapping,
            mutated,
            configs["diff"],
            cli_args.command,
            output,
            md=configs["md"],
        )
        sys.stdout.write(output.out)
        sys.stderr.write(output.err)

        if mutated and not actually_mutated:
            # The temporary files changed but no notebook did.
            output_code = 0
            mutated = False

        if saved_sources.failed_notebooks:
            output_code = 123
            _print_failed_notebook_errors(saved_sources.failed_notebooks)

        if configs["diff"]:
            if mutated:
                diff_note = "To apply these changes, remove the `--nbqa-diff` flag\n"
            else:
                diff_note = "Notebook(s) would be left unchanged\n"
            sys.stdout.write(diff_note)
            # For diff, we return 0 if no mutation would've occurred, and 1 otherwise.
            return int(mutated)
    finally:
        # Temporary files are removed on every path out of the block above.
        _clean_up_tmp_files(nb_to_tmp_mapping)

    return output_code
def _check_command_is_installed(command: str) -> None:
    """
    Verify that the third-party tool can be used.

    Parameters
    ----------
    command
        Third-party tool being run on notebook(s).

    Raises
    ------
    ModuleNotFoundError
        If no package metadata is found, the module cannot be imported, and
        neither a directory nor a ``.py`` file of that name exists relative to
        the current directory.
    UnsupportedPackageVersionError
        If third-party tool is of an unsupported version.
    """
    module_name = COMMAND_TO_PYTHON_MODULE.get(command, command)

    try:
        installed_version = metadata.version(module_name)
    except metadata.PackageNotFoundError:
        # No package metadata: fall back to importing the module, and to
        # looking for it as a local directory or file.
        try:
            import_module(module_name)
        except ImportError as exc:
            local_file = f"{os.path.join(*module_name.split('.'))}.py"
            if not (
                os.path.isdir(module_name) or os.path.isfile(local_file)
            ):  # pragma: nocover(py<37)
                # I presume the lack of coverage in Python3.6 here is a bug, as all
                # these branches are actually covered.
                raise ModuleNotFoundError(_get_command_not_found_msg(command)) from exc
        # Without package metadata there is no version to compare.
        return

    min_version = MIN_VERSIONS.get(command)
    if min_version is None:
        return
    if parse_version(installed_version) < parse_version(min_version):
        raise UnsupportedPackageVersionError(command, installed_version, min_version)
def main(argv: Optional[Sequence[str]] = None) -> int:
    """
    Run third-party tool (e.g. :code:`mypy`) against notebook or directory.

    Parameters
    ----------
    argv
        Command-line arguments (if calling this function directly), defaults to
        :code:`None` if calling via command-line.

    Returns
    -------
    int
        Exit code produced by ``_main``.
    """
    cli_args: CLIArgs = CLIArgs.parse_args(argv)
    # Fail early if the tool is missing or too old.
    _check_command_is_installed(cli_args.command)
    root: Path = find_project_root(tuple(cli_args.root_dirs))
    return _main(cli_args, _get_configs(cli_args, root))
# Script entry point: the value returned by ``main`` becomes the exit status.
if __name__ == "__main__":
    sys.exit(main())