Skip to content

Commit

Permalink
5762 pprint head and tail bundle script (#5969)
Browse files Browse the repository at this point in the history
Signed-off-by: Wenqi Li <wenqil@nvidia.com>

Fixes #5762

### Description

limiting the number of printing lines

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Wenqi Li <wenqil@nvidia.com>
  • Loading branch information
wyli authored Feb 10, 2023
1 parent 94feae5 commit 5657b8f
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 3 deletions.
6 changes: 3 additions & 3 deletions monai/bundle/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import ast
import json
import os
import pprint
import re
import time
import warnings
Expand All @@ -37,7 +36,7 @@
from monai.data import load_net_with_metadata, save_net_with_metadata
from monai.networks import convert_to_torchscript, copy_model_state, get_state_dict, save_state
from monai.utils import check_parent_dir, get_equivalent_dtype, min_version, optional_import
from monai.utils.misc import ensure_tuple
from monai.utils.misc import ensure_tuple, pprint_edges

validate, _ = optional_import("jsonschema", name="validate")
ValidationError, _ = optional_import("jsonschema.exceptions", name="ValidationError")
Expand All @@ -48,6 +47,7 @@

# set BUNDLE_DOWNLOAD_SRC="ngc" to use NGC source in default for bundle download
download_source = os.environ.get("BUNDLE_DOWNLOAD_SRC", "github")
PPRINT_CONFIG_N = 5


def _update_args(args: str | dict | None = None, ignore_none: bool = True, **kwargs: Any) -> dict:
Expand Down Expand Up @@ -88,7 +88,7 @@ def _pop_args(src: dict, *args: Any, **kwargs: Any) -> tuple:
def _log_input_summary(tag: str, args: dict) -> None:
logger.info(f"--- input summary of monai.bundle.scripts.{tag} ---")
for name, val in args.items():
logger.info(f"> {name}: {pprint.pformat(val)}")
logger.info(f"> {name}: {pprint_edges(val, PPRINT_CONFIG_N)}")
logger.info("---\n\n")


Expand Down
1 change: 1 addition & 0 deletions monai/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
issequenceiterable,
list_to_dict,
path_to_uri,
pprint_edges,
progress_bar,
sample_slices,
save_obj,
Expand Down
16 changes: 16 additions & 0 deletions monai/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import inspect
import itertools
import os
import pprint
import random
import shutil
import tempfile
Expand Down Expand Up @@ -60,6 +61,7 @@
"save_obj",
"label_union",
"path_to_uri",
"pprint_edges",
]

_seed = None
Expand Down Expand Up @@ -626,3 +628,17 @@ def path_to_uri(path: PathLike) -> str:
"""
return Path(path).absolute().as_uri()


def pprint_edges(val: Any, n_lines: int = 20) -> str:
"""
Pretty print the head and tail ``n_lines`` of ``val``, and omit the middle part if the part has more than 3 lines.
Returns: the formatted string.
"""
val_str = pprint.pformat(val).splitlines(True)
n_lines = max(n_lines, 1)
if len(val_str) > n_lines * 2 + 3:
hidden_n = len(val_str) - n_lines * 2
val_str = val_str[:n_lines] + [f"\n ... omitted {hidden_n} line(s)\n\n"] + val_str[-n_lines:]
return "".join(val_str)
12 changes: 12 additions & 0 deletions tests/test_bundle_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from monai.bundle.utils import load_bundle_config
from monai.networks.nets import UNet
from monai.utils import pprint_edges
from tests.utils import command_line_tests, skip_if_windows

metadata = """
Expand Down Expand Up @@ -117,5 +118,16 @@ def test_load_config_ts(self):
self.assertEqual(p["test_dict"]["b"], "c")


class TestPPrintEdges(unittest.TestCase):
def test_str(self):
self.assertEqual(pprint_edges("", 0), "''")
self.assertEqual(pprint_edges({"a": 1, "b": 2}, 0), "{'a': 1, 'b': 2}")
self.assertEqual(
pprint_edges([{"a": 1, "b": 2}] * 20, 1),
"[{'a': 1, 'b': 2},\n\n ... omitted 18 line(s)\n\n {'a': 1, 'b': 2}]",
)
self.assertEqual(pprint_edges([{"a": 1, "b": 2}] * 8, 4), pprint_edges([{"a": 1, "b": 2}] * 8, 3))


if __name__ == "__main__":
unittest.main()

0 comments on commit 5657b8f

Please sign in to comment.