Skip to content

Commit

Permalink
Merge pull request #7 from soldni/soldni/rich_argparse
Browse files Browse the repository at this point in the history
Adding more rich features
  • Loading branch information
soldni authored Oct 19, 2022
2 parents dc05648 + 88825bb commit 629c579
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 21 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "springs"
version = "1.5.2"
version = "1.6.0"
description = "A set of utilities to create and manage typed configuration files effectively, built on top of OmegaConf."
authors = [
{name = "Luca Soldaini", email = "luca@soldaini.net" }
Expand All @@ -14,7 +14,7 @@ dependencies = [
"typing_extensions>=4.2.0",
"get-annotations>=0.1.2",
"platformdirs>=2.5.0",
"rich>=12.0.0",
"rich>=11.0.0",
]
classifiers = [
"Development Status :: 5 - Production/Stable",
Expand Down
41 changes: 24 additions & 17 deletions src/springs/commandline.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import os
import sys
from argparse import Action, ArgumentParser
from argparse import Action
from dataclasses import dataclass, fields, is_dataclass
from inspect import getfile, getfullargspec, isclass
from pathlib import Path
from inspect import getfullargspec, isclass
from typing import (
Any,
Callable,
Expand All @@ -27,7 +25,12 @@
merge,
unsafe_merge,
)
from .rich_utils import add_pretty_traceback, print_config_as_tree, print_table
from .rich_utils import (
RichArgumentParser,
add_pretty_traceback,
print_config_as_tree,
print_table,
)

# parameters for the main function
MP = ParamSpec("MP")
Expand Down Expand Up @@ -64,7 +67,7 @@ def usage(self) -> str:
def long(self) -> str:
return f"--{self.name}"

def add_argparse(self, parser: ArgumentParser) -> Action:
def add_argparse(self, parser: RichArgumentParser) -> Action:
kwargs: Dict[str, Any] = {"help": self.help}
if self.action is not MISSING:
kwargs["action"] = self.action
Expand Down Expand Up @@ -136,29 +139,22 @@ def flags(self) -> Iterable[Flag]:
if isinstance(maybe_flag, Flag):
yield maybe_flag

def add_argparse(self, parser: ArgumentParser) -> Sequence[Action]:
def add_argparse(self, parser: RichArgumentParser) -> Sequence[Action]:
return [flag.add_argparse(parser) for flag in self.flags]

@property
def usage(self) -> str:
"""Print the usage string for the CLI flags."""
return " ".join(flag.usage for flag in self.flags)

def make_cli(self, func: Callable, name: str) -> ArgumentParser:
def make_cli(self, func: Callable, name: str) -> RichArgumentParser:
"""Sets up argument parser ahead of running the CLI. This includes
creating a help message, and adding a series of flags."""

# we find the path to the script we are decorating with the
# cli so that we can display that to the user.
current_dir = Path(os.getcwd())
path_to_fn_file = Path(getfile(func))
rel_fn_file_path = str(path_to_fn_file).replace(str(current_dir), "")

# Program name and usage added here.
ap = ArgumentParser(
ap = RichArgumentParser(
description=f"Parser for configuration {name}",
usage=(
f"python3 {rel_fn_file_path} {self.usage} "
f"python3 {sys.argv[0]} {self.usage} "
"param1=value1 … paramN=valueN"
),
)
Expand Down Expand Up @@ -241,6 +237,11 @@ def wrap_main_method(
title="Registered Resolvers",
columns=["Resolver Name"],
values=[(r,) for r in sorted(all_resolvers())],
caption=(
"Resolvers use syntax ${resolver_name:'arg1','arg2'}. "
"For more information, visit https://omegaconf.readthedocs.io/"
"en/latest/custom_resolvers.html"
)
)

if opts.nicknames:
Expand All @@ -250,6 +251,12 @@ def wrap_main_method(
title="Registered Nicknames",
columns=["Nickname", "Path"],
values=NicknameRegistry().all(),
caption=(
"Nicknames are invoked via: "
"${sp.from_node:nickname,'path.to.key1=value1',...}. "
"\nOverride keys are optional (but quotes are required)."

)
)

# Print default options if requested py the user
Expand Down
37 changes: 35 additions & 2 deletions src/springs/rich_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import Any, Dict, Optional, Sequence, Union
from argparse import ArgumentParser, HelpFormatter
import os
from typing import IO, Any, Dict, Optional, Sequence, Type, Union

from omegaconf import DictConfig, ListConfig
from rich.console import Console
Expand All @@ -20,6 +22,7 @@ def print_table(
columns: Sequence[str],
values: Sequence[Sequence[Any]],
colors: Optional[Sequence[str]] = None,
caption: Optional[str] = None,
):
colors = list(
colors or ["magenta", "cyan", "red", "green", "yellow", "blue"]
Expand All @@ -28,13 +31,21 @@ def print_table(
# repeat colors if we have more columns than colors
colors = colors * (len(columns) // len(colors) + 1)

min_width = min(
max(len(title), len(caption or '')) + 2,
os.get_terminal_size().columns - 2
)

table = Table(
*(
Column(column, justify="center", style=color, vertical="middle")
for column, color in zip(columns, colors)
),
title=f"\n{title}",
min_width=len(title) + 2,
min_width=min_width,
caption=caption,
title_style="bold",
caption_style="grey74"
)
for row in values:
table.add_row(*row)
Expand Down Expand Up @@ -86,3 +97,25 @@ def get_parent_path(path: str) -> str:
)

Console().print(root)


class RichFormatter(HelpFormatter):
...


class RichArgumentParser(ArgumentParser):
def __init__(
self,
*args,
formatter_class: Type[HelpFormatter] = RichFormatter,
console_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
) -> None:
super().__init__(*args, **kwargs)
self.console_kwargs = console_kwargs or {}

def _print_message(
self, message: Any, file: Optional[IO[str]] = None
) -> None:
console = Console(**{**self.console_kwargs, "file": file})
console.print(message)

0 comments on commit 629c579

Please sign in to comment.