Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions src/argparse_pydantic/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

class Arg(NamedTuple):
name: str
model: BaseModel
model: Type[BaseModel]


def get_args(func: Callable[..., Any]) -> List[Arg]:
Expand All @@ -29,7 +29,7 @@ def get_args(func: Callable[..., Any]) -> List[Arg]:
return params


def get_models(args: List[Arg]) -> List[BaseModel]:
def get_models(args: List[Arg]) -> List[Type[BaseModel]]:
"""get list models from args"""
return [arg.model for arg in args]

Expand All @@ -40,7 +40,7 @@ def app(
usage: str | None = None,
description: str | None = None,
epilog: str | None = None,
parents: Sequence[ArgumentParser] = None,
parents: Sequence[ArgumentParser] | None = None,
formatter_class: Type[HelpFormatter] = HelpFormatter,
prefix_chars: str = "-",
fromfile_prefix_chars: str | None = None,
Expand Down Expand Up @@ -72,7 +72,7 @@ def app(

# Create app.
# Simple variant - expecting function with one argument.
def create_app(func: Callable[[Type[Any]], None]):
def create_app(func: Callable[..., Any]):
args = get_args(func)
app_cfg = args[0].model

Expand Down Expand Up @@ -102,7 +102,7 @@ def __init__(
usage: str | None = None,
description: str | None = None,
epilog: str | None = None,
parents: Sequence[ArgumentParser] = None,
parents: Sequence[ArgumentParser] | None = None,
formatter_class: Type[HelpFormatter] = HelpFormatter,
prefix_chars: str = "-",
fromfile_prefix_chars: str | None = None,
Expand Down Expand Up @@ -138,7 +138,9 @@ def main(self, func: Callable[[Type[Any]], None]):
self.commands["main"] = func
self.configs["main"] = get_args(func)

def command(self, func: Callable[[Type[Any]], None] = None, *, name: str = ""):
def command(
self, func: Callable[[Type[Any]], None] | str | None = None, *, name: str = ""
):
if func is None:
return partial(self.command, name=name)
if isinstance(func, str):
Expand Down Expand Up @@ -189,14 +191,14 @@ def __call__(self, args: Optional[Sequence[str]] = None) -> None:


def run(
func: Callable[[BaseModel], None],
*args: Callable[[BaseModel], None],
**kwargs: Callable[[BaseModel], None],
func: Callable[[Type[BaseModel]], None],
*args: Callable[[Type[BaseModel]], None],
**kwargs: Callable[[Type[BaseModel]], None] | ArgumentParserCfg,
) -> None:
"""Parse command line arguments and run function.
Pass ArgumentParser Cfg as `parser_cfg=parser_cfg`.
Pass command functions as arguments or `command=func`."""
parser_cfg = kwargs.pop("parser_cfg", None)
parser_cfg: ArgumentParserCfg | None = kwargs.pop("parser_cfg", None)
run_app = App(parser_cfg=parser_cfg)
run_app.main(func)
if args:
Expand Down
28 changes: 16 additions & 12 deletions src/argparse_pydantic/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def argument_kwargs(
return {key: val for key, val in kwargs.items() if val is not None}


def get_field_type(field_info: FieldInfo) -> Type:
def get_field_type(field_info: FieldInfo) -> Type[Any] | None:
"""get field type, convert to base type."""
field_type = field_info.annotation
if is_union(field_type):
Expand All @@ -99,7 +99,7 @@ def parse_field_kwargs(json_schema_extra: dict[str, Any]) -> dict[str, Any]:


def add_field_arg(
parser: argparse.ArgumentParser,
parser: argparse.ArgumentParser | argparse._ArgumentGroup,
field_name: str,
field_info: FieldInfo,
undefined_positional: bool = True,
Expand Down Expand Up @@ -143,7 +143,8 @@ def add_field_arg(
default = ""
else:
default = f"default: {field_info.default}"
kwargs["help"] = kwargs.get("help", "") + f" [{field_type.__name__}] {default}"
field_type_name = field_type.__name__ if field_type is not None else "None"
kwargs["help"] = kwargs.get("help", "") + f" [{field_type_name}] {default}"

dest = kwargs.get("dest", None)
if dest and not check_dest_ok(dest, parser):
Expand All @@ -155,7 +156,9 @@ def add_field_arg(
parser.add_argument(*flags, **kwargs)


def check_dest_ok(dest: str, parser: argparse.ArgumentParser) -> bool:
def check_dest_ok(
dest: str, parser: argparse.ArgumentParser | argparse._ArgumentGroup
) -> bool:
"""check dest not exist"""
if dest in [
action.dest
Expand All @@ -166,7 +169,9 @@ def check_dest_ok(dest: str, parser: argparse.ArgumentParser) -> bool:
return True


def check_flags(flags: list[str], parser: argparse.ArgumentParser) -> list[str]:
def check_flags(
flags: list[str], parser: argparse.ArgumentParser | argparse._ArgumentGroup
) -> list[str]:
"""check and filter flags - return only valid flags"""
if flags:
dest_list = [
Expand Down Expand Up @@ -216,7 +221,7 @@ def validate_action(action: str, default: Optional[Type]) -> None:

def add_args_from_model(
parser: argparse.ArgumentParser,
model: BaseModel | list[BaseModel],
model: Type[BaseModel] | list[Type[BaseModel]],
undefined_positional: bool = True,
help_def_type: bool = False,
create_group: bool = False,
Expand All @@ -225,18 +230,17 @@ def add_args_from_model(
if not isinstance(model, list):
model = [model]
for item in model: # if same name check at add_field_arg
if create_group:
arg_group = parser.add_argument_group(item.__name__)
else:
arg_group = parser
arg_parser = (
parser.add_argument_group(item.__name__) if create_group else parser
)
for field_name, field_info in item.model_fields.items():
add_field_arg(
arg_group, field_name, field_info, undefined_positional, help_def_type
arg_parser, field_name, field_info, undefined_positional, help_def_type
)
return parser


def create_model_obj(model: BaseModel, args: argparse.Namespace) -> BaseModel:
def create_model_obj(model: Type[BaseModel], args: argparse.Namespace) -> BaseModel:
"""create model from parsed args"""
kwargs = {
key: val for key, val in args.__dict__.items() if key in model.model_fields
Expand Down