Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reconstruct defaulted arg values from docstring #147

Merged
merged 9 commits into from
Sep 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,11 @@ pybind11-stubgen [-h]
[--ignore-invalid-identifiers REGEX]
[--ignore-unresolved-names REGEX]
[--ignore-all-errors]
[--enum-class-locations [REGEX:LOC ...]]
[--numpy-array-wrap-with-annotated|
--numpy-array-remove-parameters]
[--print-invalid-expressions-as-is]
[--print-safe-value-reprs REGEX]
[--exit-code]
[--stub-extension EXT]
MODULE_NAME
Expand Down
34 changes: 34 additions & 0 deletions pybind11_stubgen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,10 @@
FixTypingExtTypeNames,
FixTypingTypeNames,
FixValueReprRandomAddress,
OverridePrintSafeValues,
RemoveSelfAnnotation,
ReplaceReadWritePropertyWithField,
RewritePybind11EnumValueRepr,
)
from pybind11_stubgen.parser.mixins.parse import (
BaseParser,
Expand All @@ -62,6 +64,12 @@ def regex(pattern_str: str) -> re.Pattern:
except re.error as e:
raise ValueError(f"Invalid REGEX pattern: {e}")

def regex_colon_path(regex_path: str) -> tuple[re.Pattern, str]:
    """Parse a ``<regex>:<path>`` command-line value.

    The value is split on its LAST colon, so the regex part may itself
    contain colons.  The path part must be a dot-separated sequence of
    valid Python identifiers.

    Args:
        regex_path: Raw argument text, e.g. ``MyEnum:foo.bar.Baz``.

    Returns:
        A ``(compiled_pattern, path)`` pair.

    Raises:
        ValueError: If the colon separator is missing, if any path component
            is not an identifier, or if ``regex`` rejects the pattern.
    """
    pattern_str, separator, path = regex_path.rpartition(":")
    if not separator:
        # Without this check the failure would surface as an unrelated
        # tuple-unpacking error instead of naming the real problem.
        raise ValueError(
            f"Expected <regex>:<path> format (missing ':'): {regex_path}"
        )
    if any(not part.isidentifier() for part in path.split(".")):
        raise ValueError(f"Invalid PATH: {path}")
    return regex(pattern_str), path

parser = ArgumentParser(
prog="pybind11-stubgen", description="Generates stubs for specified modules"
)
Expand Down Expand Up @@ -109,6 +117,18 @@ def regex(pattern_str: str) -> re.Pattern:
help="Ignore all errors during module parsing",
)

# Register `--enum-class-locations`: zero or more REGEX:LOC values
# (nargs="*"), each converted by `regex_colon_path`; the parsed values are
# stored in `args.enum_class_locations`, which defaults to an empty list.
parser.add_argument(
"--enum-class-locations",
dest="enum_class_locations",
metavar="REGEX:LOC",
default=[],
nargs="*",
type=regex_colon_path,
help="Locations of enum classes in "
"<enum-class-name-regex>:<path-to-class> format. "
"Example: `MyEnum:foo.bar.Baz`",
)

numpy_array_fix = parser.add_mutually_exclusive_group()
numpy_array_fix.add_argument(
"--numpy-array-wrap-with-annotated",
Expand All @@ -133,6 +153,14 @@ def regex(pattern_str: str) -> re.Pattern:
help="Suppress invalid expression replacement with '...'",
)

# Register `--print-safe-value-reprs`: a single REGEX converted by `regex`;
# the default is None when the option is not given.
parser.add_argument(
"--print-safe-value-reprs",
metavar="REGEX",
default=None,
type=regex,
help="Override the print-safe check for values matching REGEX",
)

parser.add_argument(
"--exit-code",
action="store_true",
Expand Down Expand Up @@ -202,10 +230,12 @@ class Parser(
FixTypingExtTypeNames,
FixMissingFixedSizeImport,
FixMissingEnumMembersAnnotation,
OverridePrintSafeValues,
*numpy_fixes, # type: ignore[misc]
FixNumpyArrayFlags,
FixCurrentModulePrefixInTypeNames,
FixBuiltinTypes,
RewritePybind11EnumValueRepr,
FilterClassMembers,
ReplaceReadWritePropertyWithField,
FilterInvalidIdentifiers,
Expand All @@ -224,12 +254,16 @@ class Parser(

parser = Parser()

if args.enum_class_locations:
parser.set_pybind11_enum_locations(dict(args.enum_class_locations))
if args.ignore_invalid_identifiers is not None:
parser.set_ignored_invalid_identifiers(args.ignore_invalid_identifiers)
if args.ignore_invalid_expressions is not None:
parser.set_ignored_invalid_expressions(args.ignore_invalid_expressions)
if args.ignore_unresolved_names is not None:
parser.set_ignored_unresolved_names(args.ignore_unresolved_names)
if args.print_safe_value_reprs is not None:
parser.set_print_safe_value_pattern(args.print_safe_value_reprs)
return parser


Expand Down
2 changes: 1 addition & 1 deletion pybind11_stubgen/parser/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def parse_annotation_str(
...

@abc.abstractmethod
def parse_value_str(self, value: str) -> Value:
def parse_value_str(self, value: str) -> Value | InvalidExpression:
...

@abc.abstractmethod
Expand Down
65 changes: 64 additions & 1 deletion pybind11_stubgen/parser/mixins/fix.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import inspect
import re
import types
from logging import getLogger
from typing import Any

from pybind11_stubgen.parser.errors import NameResolutionError, ParserError
Expand All @@ -29,6 +30,8 @@
)
from pybind11_stubgen.typing_ext import DynamicSize, FixedSize

logger = getLogger("pybind11_stubgen")


class RemoveSelfAnnotation(IParser):
def handle_method(self, path: QualifiedName, method: Any) -> list[Method]:
Expand Down Expand Up @@ -363,7 +366,7 @@ class FixTypingExtTypeNames(IParser):
__typing_names: set[Identifier] = set(
map(
Identifier,
["buffer"],
["buffer", "Buffer"],
)
)

Expand Down Expand Up @@ -751,3 +754,63 @@ def handle_class_member(
method.modifier = None
method.function.doc = None
return result


class OverridePrintSafeValues(IParser):
    """Parser mixin that force-marks selected values as print-safe.

    When a pattern has been configured, any parsed ``Value`` that the rest
    of the parser chain did not consider print-safe is flagged as print-safe
    if its ``repr`` matches that pattern (matched from the start of the repr).
    """

    # User-supplied override pattern; None means no override is applied.
    _print_safe_values: re.Pattern | None

    def __init__(self):
        super().__init__()
        self._print_safe_values = None

    def set_print_safe_value_pattern(self, pattern: re.Pattern):
        """Store the pattern used to override the print-safe check."""
        self._print_safe_values = pattern

    def parse_value_str(self, value: str) -> Value | InvalidExpression:
        """Delegate parsing, then apply the print-safe override if configured."""
        parsed = super().parse_value_str(value)
        override = self._print_safe_values

        # Nothing to do without a pattern, or for invalid expressions.
        if override is None or not isinstance(parsed, Value):
            return parsed
        # Values that are already print-safe are left untouched.
        if parsed.is_print_safe:
            return parsed

        if override.match(parsed.repr) is not None:
            parsed.is_print_safe = True
        return parsed


class RewritePybind11EnumValueRepr(IParser):
    """Parser mixin that rewrites enum-like reprs into printable expressions.

    A value shaped like ``<Color.Red: 1>`` is replaced by
    ``<resolved enum class name>.Red`` when the enum class name matches one
    of the configured location patterns and ``<location>.<enum class>``
    parses to a ``ResolvedType``.  Enum class names that could not be mapped
    are collected and reported once from ``finalize``.
    """

    # ``<Qualified.Enum.Entry: INT>``; the integer may be negative, so an
    # optional leading minus sign is accepted.
    _pybind11_enum_pattern = re.compile(
        r"<(?P<enum>\w+(\.\w+)+): (?P<value>-?\d+)>"
    )

    def __init__(self):
        super().__init__()
        self._pybind11_enum_locations: dict[re.Pattern, str] = {}
        # Kept per instance: a class-level set would be shared by every
        # parser object, leaking unresolved names from one into another.
        self._unknown_enum_classes: set[str] = set()

    def set_pybind11_enum_locations(self, locations: dict[re.Pattern, str]):
        """Set the mapping from enum-class-name pattern to its location path."""
        self._pybind11_enum_locations = locations

    def parse_value_str(self, value: str) -> Value | InvalidExpression:
        """Rewrite an enum-like repr if possible, else defer to the next parser.

        Args:
            value: Raw value text; surrounding whitespace is stripped.

        Returns:
            A print-safe ``Value`` naming the enum entry when the enum class
            could be resolved; otherwise whatever the next ``parse_value_str``
            in the chain returns for the stripped text.
        """
        value = value.strip()
        match = self._pybind11_enum_pattern.match(value)
        if match is not None:
            enum_class_str, entry = match.group("enum").rsplit(".", maxsplit=1)
            for pattern, prefix in self._pybind11_enum_locations.items():
                if pattern.match(enum_class_str) is None:
                    continue
                enum_class = self.parse_annotation_str(f"{prefix}.{enum_class_str}")
                if isinstance(enum_class, ResolvedType):
                    return Value(repr=f"{enum_class.name}.{entry}", is_print_safe=True)
            # No configured location produced a resolvable type.
            self._unknown_enum_classes.add(enum_class_str)
        return super().parse_value_str(value)

    def finalize(self):
        """Warn about unmapped enum classes, then continue the finalize chain."""
        if self._unknown_enum_classes:
            logger.warning(
                "Enum-like str representations were found with no "
                "matching mapping to the enum class location.\n"
                "Use `--enum-class-locations` to specify "
                "full path to the following enum(s):\n"
                # Sorted so the message is stable between runs.
                + "\n".join(f" - {c}" for c in sorted(self._unknown_enum_classes))
            )
        super().finalize()
33 changes: 20 additions & 13 deletions pybind11_stubgen/parser/mixins/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,8 +380,8 @@ def handle_type(self, type_: type) -> QualifiedName:
)
)

def parse_value_str(self, value: str) -> Value:
return Value(value)
def parse_value_str(self, value: str) -> Value | InvalidExpression:
return self._parse_expression_str(value)

def report_error(self, error: ParserError):
if isinstance(error, NameResolutionError):
Expand Down Expand Up @@ -428,6 +428,21 @@ def _get_full_name(self, path: QualifiedName, origin: Any) -> QualifiedName | No
return None
return origin_name

def _parse_expression_str(self, expr_str: str) -> Value | InvalidExpression:
    """Classify a stripped expression string.

    Returns an ``InvalidExpression`` (after reporting an
    ``InvalidExpressionError``) when the text does not parse as Python.
    Otherwise returns a ``Value`` whose ``is_print_safe`` flag is True only
    if ``ast.literal_eval`` accepts the text.
    """
    expression = expr_str.strip()

    # Step 1: the text must at least be syntactically valid Python.
    try:
        ast.parse(expression)
    except SyntaxError:
        self.report_error(InvalidExpressionError(expression))
        return InvalidExpression(expression)

    # Step 2: only literals that evaluate cleanly are considered print-safe.
    try:
        ast.literal_eval(expression)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        is_literal = False
    else:
        is_literal = True

    return Value(expression, is_print_safe=is_literal)


class ExtractSignaturesFromPybind11Docstrings(IParser):
_arg_star_name_regex = re.compile(
Expand Down Expand Up @@ -571,7 +586,7 @@ def parse_type_str(
annotation_str = annotation_str.strip()
match = qname_regex.match(annotation_str)
if match is None:
return self._parse_expression_str(annotation_str)
return self.parse_value_str(annotation_str)
qual_name = QualifiedName(
Identifier(part)
for part in match.group("qual_name").replace(" ", "").split(".")
Expand All @@ -582,25 +597,17 @@ def parse_type_str(
parameters = None
else:
if parameters_str[0] != "[" or parameters_str[-1] != "]":
return self._parse_expression_str(annotation_str)
return self.parse_value_str(annotation_str)

split_parameters = self._split_parameters_str(parameters_str[1:-1])
if split_parameters is None:
return self._parse_expression_str(annotation_str)
return self.parse_value_str(annotation_str)

parameters = [
self.parse_annotation_str(param_str) for param_str in split_parameters
]
return ResolvedType(name=qual_name, parameters=parameters)

def _parse_expression_str(self, expr_str: str) -> Value | InvalidExpression:
try:
ast.parse(expr_str)
return self.parse_value_str(expr_str)
except SyntaxError:
self.report_error(InvalidExpressionError(expr_str))
return InvalidExpression(expr_str)

def parse_function_docstring(
self, func_name: Identifier, doc_lines: list[str]
) -> list[Function]:
Expand Down
4 changes: 3 additions & 1 deletion pybind11_stubgen/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,13 @@ def print_argument(self, arg: Argument) -> str:
parts.append(f"{arg.name}")
if arg.annotation is not None:
parts.append(f": {self.print_annotation(arg.annotation)}")
if arg.default is not None:
if isinstance(arg.default, Value):
if arg.default.is_print_safe:
parts.append(f" = {self.print_value(arg.default)}")
else:
parts.append(" = ...")
elif isinstance(arg.default, InvalidExpression):
parts.append(f" = {self.print_invalid_exp(arg.default)}")

return "".join(parts)

Expand Down
2 changes: 1 addition & 1 deletion pybind11_stubgen/structs.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ class Argument:
kw_only: bool = field_(default=False)
variadic: bool = field_(default=False) # *args
kw_variadic: bool = field_(default=False) # **kwargs
default: Value | None = field_(default=None)
default: Value | InvalidExpression | None = field_(default=None)
annotation: Annotation | None = field_(default=None)

def __str__(self):
Expand Down
4 changes: 3 additions & 1 deletion tests/check-demo-stubs-generation.sh
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,10 @@ run_stubgen() {
demo \
--output-dir=${STUBS_DIR} \
--numpy-array-wrap-with-annotated \
--ignore-invalid-expressions="\(anonymous namespace\)::(Enum|Unbound)" \
--ignore-invalid-expressions="\(anonymous namespace\)::(Enum|Unbound)|<demo\._bindings\.flawed_bindings\..*" \
--ignore-unresolved-names="typing\.Annotated" \
--enum-class-locations="ConsoleForegroundColor:demo._bindings.enum" \
--print-safe-value-reprs="Foo\(\d+\)" \
--exit-code
}

Expand Down
2 changes: 1 addition & 1 deletion tests/py-demo/bindings/src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

PYBIND11_MODULE(_bindings, m) {
bind_classes_module(m.def_submodule("classes"));
bind_aliases_module(m.def_submodule("aliases"));
bind_eigen_module(m.def_submodule("eigen"));
bind_enum_module(m.def_submodule("enum"));
bind_aliases_module(m.def_submodule("aliases"));
bind_flawed_bindings_module(m.def_submodule("flawed_bindings"));
bind_functions_module(m.def_submodule("functions"));
bind_issues_module(m.def_submodule("issues"));
Expand Down
24 changes: 18 additions & 6 deletions tests/py-demo/bindings/src/modules/aliases.cpp
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
#include "modules.h"

#include <demo/Foo.h>
#include <demo/sublibA/ConsoleColors.h>

namespace {
class Dummy {};
class Dummy {
};

struct Color {};
struct Color {
};

struct Bar1 {};
struct Bar2 {};
struct Bar3 {};
struct Bar1 {
};
struct Bar2 {
};
struct Bar3 {
};
} // namespace

void bind_aliases_module(py::module_ &&m) {
Expand All @@ -18,7 +24,7 @@ void bind_aliases_module(py::module_ &&m) {
auto &&pyDummy = py::class_<Dummy>(m, "Dummy");

pyDummy.def_property_readonly_static(
"linalg", [](py::object &) { return py::module::import("numpy.linalg"); });
"linalg", [](py::object &) { return py::module::import("numpy.linalg"); });

m.add_object("random", py::module::import("numpy.random"));
}
Expand Down Expand Up @@ -63,4 +69,10 @@ void bind_aliases_module(py::module_ &&m) {
m.attr("foreign_type_alias") = m.attr("foreign_method_arg").attr("Bar2");
m.attr("foreign_class_alias") = m.attr("foreign_return").attr("get_foo");
}

m.def(
"foreign_enum_default",
[](const py::object & /* color */) {},
py::arg("color") = demo::sublibA::ConsoleForegroundColor::Blue
);
}
Loading