diff --git a/src/griffe/agents/nodes.py b/src/griffe/agents/nodes.py index ae6f1a36..8c0f68b5 100644 --- a/src/griffe/agents/nodes.py +++ b/src/griffe/agents/nodes.py @@ -71,22 +71,14 @@ from ast import arguments as NodeArguments from ast import comprehension as NodeComprehension from ast import keyword as NodeKeyword -from contextlib import suppress +from contextlib import contextmanager, suppress from functools import partial -from typing import TYPE_CHECKING, Any, Callable, Sequence +from typing import TYPE_CHECKING, Any, Callable, Iterator, Sequence from griffe.exceptions import LastNodeError, RootNodeError from griffe.expressions import Expression, Name from griffe.logger import LogLevel, get_logger -if TYPE_CHECKING: - from pathlib import Path - - from griffe.collections import LinesCollection - - -logger = get_logger(__name__) - # TODO: remove condition once Python 3.7 support is dropped if sys.version_info >= (3, 8): from ast import NamedExpr as NodeNamedExpr @@ -106,9 +98,15 @@ from ast import Index as NodeIndex if TYPE_CHECKING: + from pathlib import Path + + from griffe.collections import LinesCollection from griffe.dataclasses import Class, Module +logger = get_logger(__name__) + + class ASTNode: """This class is dynamically added to the bases of each AST node class.""" @@ -523,62 +521,63 @@ def _join(sequence: Sequence, item: str) -> list: return new_sequence -def _parse__all__constant(node: NodeConstant, parent: Module) -> list[str]: # noqa: ARG001 - try: - return [node.value] - except AttributeError: - return [node.s] # TODO: remove once Python 3.7 is dropped - - -def _parse__all__name(node: NodeName, parent: Module) -> list[Name]: - return [Name(node.id, partial(parent.resolve, node.id))] - - -def _parse__all__starred(node: NodeStarred, parent: Module) -> list[str | Name]: - return _parse__all__(node.value, parent) - - -def _parse__all__sequence(node: NodeList | NodeSet | NodeTuple, parent: Module) -> list[str | Name]: - sequence = [] - for elt in node.elts: - sequence.extend(_parse__all__(elt, parent)) - return sequence +# =========================================================== +# __all__ assignments +class _AllExtractor: + def __init__(self, parent: Module) -> None: + self.parent = parent + self._node_map: dict[type, Callable[[Any], list[str | Name]]] = { + NodeConstant: self._extract_constant, # type: ignore[dict-item] + NodeName: self._extract_name, # type: ignore[dict-item] + NodeStarred: self._extract_starred, + NodeList: self._extract_sequence, + NodeSet: self._extract_sequence, + NodeTuple: self._extract_sequence, + NodeBinOp: self._extract_binop, + } + # TODO: remove once Python 3.7 support is dropped + if sys.version_info < (3, 8): + self._node_map[NodeNameConstant] = self._extract_nameconstant # type: ignore[assignment] + self._node_map[NodeStr] = self._extract_str # type: ignore[assignment] -def _parse__all__binop(node: NodeBinOp, parent: Module) -> list[str | Name]: - left = _parse__all__(node.left, parent) - right = _parse__all__(node.right, parent) - return left + right + def _extract_constant(self, node: NodeConstant) -> list[str]: + try: + return [node.value] + except AttributeError: + return [node.s] # TODO: remove once Python 3.7 is dropped + def _extract_name(self, node: NodeName) -> list[Name]: + return [Name(node.id, partial(self.parent.resolve, node.id))] -_node__all__map: dict[type, Callable[[Any, Module], list[str | Name]]] = { - NodeConstant: _parse__all__constant, # type: ignore[dict-item] - NodeName: _parse__all__name, # type: ignore[dict-item] - NodeStarred: _parse__all__starred, - NodeList: _parse__all__sequence, - NodeSet: _parse__all__sequence, - NodeTuple: _parse__all__sequence, - NodeBinOp: _parse__all__binop, -} + def _extract_starred(self, node: NodeStarred) -> list[str | Name]: + return self._extract(node.value) -# TODO: remove once Python 3.7 support is dropped -if sys.version_info < (3, 8): + def _extract_sequence(self, node: NodeList | NodeSet | NodeTuple) -> list[str | Name]: + sequence = [] + for elt in node.elts: + sequence.extend(self._extract(elt)) + return sequence - def _parse__all__nameconstant(node: NodeNameConstant, parent: Module) -> list[Name]: # noqa: ARG001 - return [node.value] + def _extract_binop(self, node: NodeBinOp) -> list[str | Name]: + left = self._extract(node.left) + right = self._extract(node.right) + return left + right - def _parse__all__str(node: NodeStr, parent: Module) -> list[str]: # noqa: ARG001 - return [node.s] + # TODO: remove once Python 3.7 support is dropped + if sys.version_info < (3, 8): - _node__all__map[NodeNameConstant] = _parse__all__nameconstant # type: ignore[assignment] - _node__all__map[NodeStr] = _parse__all__str # type: ignore[assignment] + def _extract_nameconstant(self, node: NodeNameConstant) -> list[Name]: + return [node.value] + def _extract_str(self, node: NodeStr) -> list[str]: + return [node.s] -def _parse__all__(node: AST, parent: Module) -> list[str | Name]: - return _node__all__map[type(node)](node, parent) + def _extract(self, node: AST) -> list[str | Name]: + return self._node_map[type(node)](node) -def parse__all__(node: NodeAssign | NodeAugAssign, parent: Module) -> list[str | Name]: +def get__all__(node: NodeAssign | NodeAugAssign, parent: Module) -> list[str | Name]: """Get the values declared in `__all__`. Parameters: @@ -588,218 +587,292 @@ def parse__all__(node: NodeAssign | NodeAugAssign, parent: Module) -> list[str | Returns: A set of names. """ - try: - return _parse__all__(node.value, parent) - except KeyError as error: - logger.debug(f"Cannot parse __all__ assignment: {get_value(node.value)} ({error})") + if node.value is None: return [] + extractor = _AllExtractor(parent) + return extractor._extract(node.value) -# ========================================================== -# annotations -def _get_attribute_annotation(node: NodeAttribute, parent: Module | Class) -> Expression: - left = _get_annotation(node.value, parent) - - def resolver() -> str: - return f"{left.full}.{node.attr}" # type: ignore[union-attr] - - right = Name(node.attr, resolver) - return Expression(left, ".", right) - - -def _get_binop_annotation(node: NodeBinOp, parent: Module | Class) -> Expression: - left = _get_annotation(node.left, parent) - right = _get_annotation(node.right, parent) - return Expression(left, _get_annotation(node.op, parent), right) - - -def _get_bitand_annotation(node: NodeBitAnd, parent: Module | Class) -> str: # noqa: ARG001 - return " & " - - -def _get_bitor_annotation(node: NodeBitOr, parent: Module | Class) -> str: # noqa: ARG001 - return " | " - - -def _get_call_annotation(node: NodeCall, parent: Module | Class) -> Expression: - posargs = Expression(*_join([_get_annotation(arg, parent) for arg in node.args], ", ")) - kwargs = Expression(*_join([_get_annotation(kwarg, parent) for kwarg in node.keywords], ", ")) - args: Expression | str - if posargs and kwargs: - args = Expression(posargs, ", ", kwargs) - elif posargs: - args = posargs - elif kwargs: - args = kwargs - else: - args = "" - return Expression(_get_annotation(node.func, parent), "(", args, ")") - - -def _get_constant_annotation(node: NodeConstant, parent: Module | Class) -> str | Name | Expression: - if isinstance(node.value, str): - # a string in an annotation is a stringified annotation: we parse it again - # literal strings must be wrapped in Literal[...] to be picked up as such - parsed = compile(node.value, mode="eval", filename="", flags=PyCF_ONLY_AST, optimize=1) - return _get_annotation(parsed.body, parent=parent) # type: ignore[attr-defined] - return _get_literal_annotation(node, parent) - - -def _get_literal_annotation(node: NodeConstant, parent: Module | Class) -> str: # noqa: ARG001 - return {type(...): lambda _: "..."}.get(type(node.value), repr)(node.value) +def safe_get__all__( + node: NodeAssign | NodeAugAssign, + parent: Module, + log_level: LogLevel = LogLevel.debug, # TODO: set to error when we handle more things +) -> list[str | Name]: + """Safely (no exception) extract values in `__all__`. + Parameters: + node: The `__all__` assignment node. + parent: The parent used to resolve the names. + log_level: Log level to use to log a message. -def _get_ellipsis_annotation(node: NodeEllipsis, parent: Module | Class) -> str: # noqa: ARG001 - return "..." - - -def _get_ifexp_annotation(node: NodeIfExp, parent: Module | Class) -> Expression: - return Expression( - _get_annotation(node.body, parent), - " if ", - _get_annotation(node.test, parent), - " else", - _get_annotation(node.orelse, parent), - ) - - -def _get_invert_annotation(node: NodeInvert, parent: Module | Class) -> str: # noqa: ARG001 - return "~" - - -def _get_keyword_annotation(node: NodeKeyword, parent: Module | Class) -> Expression: - return Expression(f"{node.arg}=", _get_annotation(node.value, parent)) - - -def _get_list_annotation(node: NodeList, parent: Module | Class) -> Expression: - return Expression("[", *_join([_get_annotation(el, parent) for el in node.elts], ", "), "]") - - -def _get_name_annotation(node: NodeName, parent: Module | Class) -> Name: - return Name(node.id, partial(parent.resolve, node.id)) - - -def _get_subscript_annotation(node: NodeSubscript, parent: Module | Class) -> Expression: - left = _get_annotation(node.value, parent) - if left.full in {"typing.Literal", "typing_extensions.Literal"}: # type: ignore[union-attr] - _node_annotation_map[NodeConstant] = _get_literal_annotation - subscript = _get_annotation(node.slice, parent) - _node_annotation_map[NodeConstant] = _get_constant_annotation - else: - subscript = _get_annotation(node.slice, parent) - return Expression(left, "[", subscript, "]") - + Returns: + A list of strings or resovable names. + """ + try: + return get__all__(node, parent) + except Exception as error: # noqa: BLE001 + message = f"Failed to extract `__all__` value: {get_value(node.value)}" + with suppress(Exception): + message += f" at {parent.relative_filepath}:{node.lineno}" # type: ignore[union-attr] + if isinstance(error, KeyError): + message += f": unsupported node {error}" + else: + message += f": {error}" + getattr(logger, log_level.value)(message) + return [] -def _get_tuple_annotation(node: NodeTuple, parent: Module | Class) -> Expression: - return Expression(*_join([_get_annotation(el, parent) for el in node.elts], ", ")) +# =========================================================== +# annotations, base classes, type-guarding conditions, values +class _ExpressionBuilder: + __slots__ = ("parent", "_node_map", "_literal_strings", "_parse_strings") + + def __init__(self, parent: Module | Class, *, parse_strings: bool | None = None) -> None: + self.parent = parent + self._node_map: dict[type, Callable[[Any], str | Name | Expression]] = { + NodeAttribute: self._build_attribute, + NodeBinOp: self._build_binop, + NodeBitAnd: self._build_bitand, + NodeBitOr: self._build_bitor, + NodeCall: self._build_call, + NodeConstant: self._build_constant, + NodeEllipsis: self._build_ellipsis, + NodeIfExp: self._build_ifexp, + NodeInvert: self._build_invert, + NodeKeyword: self._build_keyword, + NodeList: self._build_list, + NodeName: self._build_name, + NodeSubscript: self._build_subscript, + NodeTuple: self._build_tuple, + NodeUnaryOp: self._build_unaryop, + NodeUAdd: self._build_uadd, + NodeUSub: self._build_usub, + } + + self._literal_strings = False + if parse_strings is None: + try: + module = parent.module + except ValueError: + self._parse_strings = False + else: + self._parse_strings = not module.imports_future_annotations + else: + self._parse_strings = parse_strings + + # TODO: remove once Python 3.8 support is dropped + if sys.version_info < (3, 9): + self._node_map[NodeIndex] = self._build_index -def _get_unaryop_annotation(node: NodeUnaryOp, parent: Module | Class) -> Expression: - return Expression(_get_annotation(node.op, parent), _get_annotation(node.operand, parent)) + # TODO: remove once Python 3.7 support is dropped + if sys.version_info < (3, 8): + self._node_map[NodeBytes] = self._build_bytes + self._node_map[NodeNameConstant] = self._build_nameconstant + self._node_map[NodeNum] = self._build_num + self._node_map[NodeStr] = self._build_str + + @contextmanager + def literal_strings(self) -> Iterator[None]: + self._literal_strings = True + try: + yield + finally: + self._literal_strings = False + + def _build_attribute(self, node: NodeAttribute) -> Expression: + left = self._build(node.value) + + def resolver() -> str: + return f"{left.source}.{node.attr}" # type: ignore[union-attr] + + right = Name(node.attr, resolver) + return Expression(left, ".", right) + + def _build_binop(self, node: NodeBinOp) -> Expression: + left = self._build(node.left) + right = self._build(node.right) + return Expression(left, self._build(node.op), right) + + def _build_bitand(self, node: NodeBitAnd) -> str: # noqa: ARG002 + return " & " + + def _build_bitor(self, node: NodeBitOr) -> str: # noqa: ARG002 + return " | " + + def _build_call(self, node: NodeCall) -> Expression: + posargs = Expression(*_join([self._build(arg) for arg in node.args], ", ")) + kwargs = Expression(*_join([self._build(kwarg) for kwarg in node.keywords], ", ")) + args: Expression | str + if posargs and kwargs: + args = Expression(posargs, ", ", kwargs) + elif posargs: + args = posargs + elif kwargs: + args = kwargs + else: + args = "" + return Expression(self._build(node.func), "(", args, ")") + + def _build_constant(self, node: NodeConstant) -> str | Name | Expression: + if self._parse_strings and isinstance(node.value, str) and not self._literal_strings: + # a string in an annotation is a stringified annotation: we build it again + # literal strings must be wrapped in Literal[...] to be picked up as such + parsed = compile(node.value, mode="eval", filename="", flags=PyCF_ONLY_AST, optimize=1) + return self._build(parsed.body) # type: ignore[attr-defined] + return self._build_literal(node) + + def _build_literal(self, node: NodeConstant) -> str: + return {type(...): lambda _: "..."}.get(type(node.value), repr)(node.value) + + def _build_ellipsis(self, node: NodeEllipsis) -> str: # noqa: ARG002 + return "..." + + def _build_ifexp(self, node: NodeIfExp) -> Expression: + return Expression( + self._build(node.body), + " if ", + self._build(node.test), + " else", + self._build(node.orelse), + ) + def _build_invert(self, node: NodeInvert) -> str: # noqa: ARG002 + return "~" -def _get_uadd_annotation(node: NodeUAdd, parent: Module | Class) -> str: # noqa: ARG001 - return "+" + def _build_keyword(self, node: NodeKeyword) -> Expression: + return Expression(f"{node.arg}=", self._build(node.value)) + def _build_list(self, node: NodeList) -> Expression: + return Expression("[", *_join([self._build(el) for el in node.elts], ", "), "]") -def _get_usub_annotation(node: NodeUSub, parent: Module | Class) -> str: # noqa: ARG001 - return "-" + def _build_name(self, node: NodeName) -> Name: + return Name(node.id, partial(self.parent.resolve, node.id)) + def _build_subscript(self, node: NodeSubscript) -> Expression: + left = self._build(node.value) + if self._parse_strings and left.full in {"typing.Literal", "typing_extensions.Literal"}: # type: ignore[union-attr] + with self.literal_strings(): + subscript = self._build(node.slice) + else: + subscript = self._build(node.slice) + return Expression(left, "[", subscript, "]") -_node_annotation_map: dict[type, Callable[[Any, Module | Class], str | Name | Expression]] = { - NodeAttribute: _get_attribute_annotation, - NodeBinOp: _get_binop_annotation, - NodeBitAnd: _get_bitand_annotation, - NodeBitOr: _get_bitor_annotation, - NodeCall: _get_call_annotation, - NodeConstant: _get_constant_annotation, - NodeEllipsis: _get_ellipsis_annotation, - NodeIfExp: _get_ifexp_annotation, - NodeInvert: _get_invert_annotation, - NodeKeyword: _get_keyword_annotation, - NodeList: _get_list_annotation, - NodeName: _get_name_annotation, - NodeSubscript: _get_subscript_annotation, - NodeTuple: _get_tuple_annotation, - NodeUnaryOp: _get_unaryop_annotation, - NodeUAdd: _get_uadd_annotation, - NodeUSub: _get_usub_annotation, -} + def _build_tuple(self, node: NodeTuple) -> Expression: + return Expression(*_join([self._build(el) for el in node.elts], ", ")) -# TODO: remove once Python 3.8 support is dropped -if sys.version_info < (3, 9): + def _build_unaryop(self, node: NodeUnaryOp) -> Expression: + return Expression(self._build(node.op), self._build(node.operand)) - def _get_index_annotation(node: NodeIndex, parent: Module | Class) -> str | Name | Expression: - return _get_annotation(node.value, parent) + def _build_uadd(self, node: NodeUAdd) -> str: # noqa: ARG002 + return "+" - _node_annotation_map[NodeIndex] = _get_index_annotation + def _build_usub(self, node: NodeUSub) -> str: # noqa: ARG002 + return "-" -# TODO: remove once Python 3.7 support is dropped -if sys.version_info < (3, 8): + # TODO: remove once Python 3.8 support is dropped + if sys.version_info < (3, 9): - def _get_bytes_annotation(node: NodeBytes, parent: Module | Class) -> str: # noqa: ARG001 - return repr(node.s) + def _build_index(self, node: NodeIndex) -> str | Name | Expression: + return self._build(node.value) - def _get_nameconstant_annotation(node: NodeNameConstant, parent: Module | Class) -> str: # noqa: ARG001 - return repr(node.value) + # TODO: remove once Python 3.7 support is dropped + if sys.version_info < (3, 8): - def _get_num_annotation(node: NodeNum, parent: Module | Class) -> str: # noqa: ARG001 - return repr(node.n) + def _build_bytes(self, node: NodeBytes) -> str: + return repr(node.s) - def _get_str_annotation(node: NodeStr, parent: Module | Class) -> str | Name: - node.value = node.s # type: ignore[attr-defined] # fake node as constant - return _node_annotation_map[NodeConstant](node, parent) # type: ignore[return-value] + def _build_nameconstant(self, node: NodeNameConstant) -> str: + return repr(node.value) - _node_annotation_map[NodeBytes] = _get_bytes_annotation - _node_annotation_map[NodeNameConstant] = _get_nameconstant_annotation - _node_annotation_map[NodeNum] = _get_num_annotation - _node_annotation_map[NodeStr] = _get_str_annotation + def _build_num(self, node: NodeNum) -> str: + return repr(node.n) + def _build_str(self, node: NodeStr) -> str | Name: + node.value = node.s # type: ignore[attr-defined] # fake node as constant + return self._node_map[NodeConstant](node) # type: ignore[return-value] -def _get_annotation(node: AST, parent: Module | Class) -> str | Name | Expression: - return _node_annotation_map[type(node)](node, parent) + def _build(self, node: AST) -> str | Name | Expression: + return self._node_map[type(node)](node) -def get_annotation(node: AST | None, parent: Module | Class) -> str | Name | Expression | None: - """Extract a resolvable annotation. +def get_expression( + node: AST | None, + parent: Module | Class, + *, + parse_strings: bool | None = None, +) -> str | Name | Expression | None: + """Build an expression from an AST. Parameters: node: The annotation node. parent: The parent used to resolve the name. + parse_strings: Whether to try and parse strings as type annotations. Returns: A string or resovable name or expression. """ if node is None: return None - return _get_annotation(node, parent) + builder = _ExpressionBuilder(parent, parse_strings=parse_strings) + return builder._build(node) -def safe_get_annotation( +def safe_get_expression( node: AST | None, parent: Module | Class, - log_level: LogLevel = LogLevel.error, + *, + parse_strings: bool | None = None, + log_level: LogLevel | None = LogLevel.error, + msg_format: str = "{path}:{lineno}: Failed to get expression from {node_class}: {error}", ) -> str | Name | Expression | None: - """Safely (no exception) extract a resolvable annotation. + """Safely (no exception) build a resolvable annotation. Parameters: node: The annotation node. parent: The parent used to resolve the name. - log_level: Log level to use to log a message. + parse_strings: Whether to try and parse strings as type annotations. + log_level: Log level to use to log a message. None to disable logging. + msg_format: A format string for the log message. Available placeholders: + path, lineno, node, error. Returns: A string or resovable name or expression. """ try: - return get_annotation(node, parent) + return get_expression(node, parent, parse_strings=parse_strings) except Exception as error: # noqa: BLE001 - message = f"Failed to parse annotation from '{node.__class__.__name__}' node" - with suppress(Exception): - message += f" at {parent.relative_filepath}:{node.lineno}" # type: ignore[union-attr] - if not isinstance(error, KeyError): - message += f": {error}" + if log_level is None: + return None + node_class = node.__class__.__name__ + try: + path: Path | str = parent.relative_filepath + except ValueError: + path = "" + lineno = node.lineno # type: ignore[union-attr] + message = msg_format.format(path=path, lineno=lineno, node_class=node_class, error=error) getattr(logger, log_level.value)(message) - return None + return None + + +_msg_format = "{path}:{lineno}: Failed to get %s expression from {node_class}: {error}" +get_annotation = partial(get_expression, parse_strings=None) +safe_get_annotation = partial( + safe_get_expression, + parse_strings=None, + msg_format=_msg_format % "annotation", +) +get_base_class = partial(get_expression, parse_strings=False) +safe_get_base_class = partial( + safe_get_expression, + parse_strings=False, + msg_format=_msg_format % "base class", +) +get_condition = partial(get_expression, parse_strings=False) +safe_get_condition = partial( + safe_get_expression, + parse_strings=False, + msg_format=_msg_format % "condition", +) # ========================================================== @@ -841,419 +914,369 @@ def get_docstring( # ========================================================== # values -def _get_add_value(node: NodeAdd) -> str: # noqa: ARG001 - return "+" - - -def _get_and_value(node: NodeAnd) -> str: # noqa: ARG001 - return " and " - - -def _get_arguments_value(node: NodeArguments) -> str: - return ", ".join(arg.arg for arg in node.args) - - -def _get_attribute_value(node: NodeAttribute) -> str: - return f"{_get_value(node.value)}.{node.attr}" - - -def _get_binop_value(node: NodeBinOp) -> str: - return f"{_get_value(node.left)} {_get_value(node.op)} {_get_value(node.right)}" - - -def _get_bitor_value(node: NodeBitOr) -> str: # noqa: ARG001 - return "|" - - -def _get_bitand_value(node: NodeBitAnd) -> str: # noqa: ARG001 - return "&" - - -def _get_bitxor_value(node: NodeBitXor) -> str: # noqa: ARG001 - return "^" - - -def _get_boolop_value(node: NodeBoolOp) -> str: - return _get_value(node.op).join(_get_value(value) for value in node.values) - - -def _get_call_value(node: NodeCall) -> str: - posargs = ", ".join(_get_value(arg) for arg in node.args) - kwargs = ", ".join(_get_value(kwarg) for kwarg in node.keywords) - if posargs and kwargs: - args = f"{posargs}, {kwargs}" - elif posargs: - args = posargs - elif kwargs: - args = kwargs - else: - args = "" - return f"{_get_value(node.func)}({args})" - - -def _get_compare_value(node: NodeCompare) -> str: - left = _get_value(node.left) - ops = [_get_value(op) for op in node.ops] - comparators = [_get_value(comparator) for comparator in node.comparators] - return f"{left} " + " ".join(f"{op} {comp}" for op, comp in zip(ops, comparators)) - - -def _get_comprehension_value(node: NodeComprehension) -> str: - target = _get_value(node.target) - iterable = _get_value(node.iter) - conditions = [_get_value(condition) for condition in node.ifs] - value = f"for {target} in {iterable}" - if conditions: - value = f"{value} if " + " if ".join(conditions) - if node.is_async: - return f"async {value}" - return value - - -def _get_constant_value(node: NodeConstant) -> str: - return repr(node.value) - - -def _get_constant_value_no_string_repr(node: NodeConstant) -> str: - if isinstance(node.value, str): - return node.value - return repr(node.value) - - -def _get_dict_value(node: NodeDict) -> str: - pairs = zip(node.keys, node.values) - gen = (f"{'None' if key is None else _get_value(key)}: {_get_value(value)}" for key, value in pairs) - return "{" + ", ".join(gen) + "}" - - -def _get_dictcomp_value(node: NodeDictComp) -> str: - key = _get_value(node.key) - value = _get_value(node.value) - generators = [_get_value(gen) for gen in node.generators] - return f"{{{key}: {value} " + " ".join(generators) + "}" - - -def _get_div_value(node: NodeDiv) -> str: # noqa: ARG001 - return "/" - - -def _get_ellipsis_value(node: NodeEllipsis) -> str: # noqa: ARG001 - return "..." - - -def _get_eq_value(node: NodeEq) -> str: # noqa: ARG001 - return "==" - - -def _get_floordiv_value(node: NodeFloorDiv) -> str: # noqa: ARG001 - return "//" - - -def _get_formatted_value(node: NodeFormattedValue) -> str: - return f"{{{_get_value(node.value)}}}" - - -def _get_generatorexp_value(node: NodeGeneratorExp) -> str: - element = _get_value(node.elt) - generators = [_get_value(gen) for gen in node.generators] - return f"{element} " + " ".join(generators) - +class _ValueExtractor: + __slots__ = ("_node_map",) + + def __init__(self) -> None: + self._node_map: dict[type, Callable[[Any], str]] = { + NodeAdd: self._extract_add, + NodeAnd: self._extract_and, + NodeArguments: self._extract_arguments, + NodeAttribute: self._extract_attribute, + NodeBinOp: self._extract_binop, + NodeBitAnd: self._extract_bitand, + NodeBitOr: self._extract_bitor, + NodeBitXor: self._extract_bitxor, + NodeBoolOp: self._extract_boolop, + NodeCall: self._extract_call, + NodeCompare: self._extract_compare, + NodeComprehension: self._extract_comprehension, + NodeConstant: self._extract_constant, + NodeDictComp: self._extract_dictcomp, + NodeDict: self._extract_dict, + NodeDiv: self._extract_div, + NodeEllipsis: self._extract_ellipsis, + NodeEq: self._extract_eq, + NodeFloorDiv: self._extract_floordiv, + NodeFormattedValue: self._extract_formatted, + NodeGeneratorExp: self._extract_generatorexp, + NodeGtE: self._extract_gte, + NodeGt: self._extract_gt, + NodeIfExp: self._extract_ifexp, + NodeIn: self._extract_in, + NodeInvert: self._extract_invert, + NodeIs: self._extract_is, + NodeIsNot: self._extract_isnot, + NodeJoinedStr: self._extract_joinedstr, + NodeKeyword: self._extract_keyword, + NodeLambda: self._extract_lambda, + NodeListComp: self._extract_listcomp, + NodeList: self._extract_list, + NodeLShift: self._extract_lshift, + NodeLtE: self._extract_lte, + NodeLt: self._extract_lt, + NodeMatMult: self._extract_matmult, + NodeMod: self._extract_mod, + NodeMult: self._extract_mult, + NodeName: self._extract_name, + NodeNotEq: self._extract_noteq, + NodeNot: self._extract_not, + NodeNotIn: self._extract_notin, + NodeOr: self._extract_or, + NodePow: self._extract_pow, + NodeRShift: self._extract_rshift, + NodeSetComp: self._extract_setcomp, + NodeSet: self._extract_set, + NodeSlice: self._extract_slice, + NodeStarred: self._extract_starred, + NodeSub: self._extract_sub, + NodeSubscript: self._extract_subscript, + NodeTuple: self._extract_tuple, + NodeUAdd: self._extract_uadd, + NodeUnaryOp: self._extract_unaryop, + NodeUSub: self._extract_usub, + NodeYield: self._extract_yield, + } + + # TODO: remove condition once Python 3.7 support is dropped + if sys.version_info >= (3, 8): + self._node_map[NodeNamedExpr] = self._extract_named_expr + + # TODO: remove once Python 3.8 support is dropped + if sys.version_info < (3, 9): + self._node_map[NodeExtSlice] = self._extract_extslice + self._node_map[NodeIndex] = self._extract_index -def _get_gte_value(node: NodeNotEq) -> str: # noqa: ARG001 - return ">=" - - -def _get_gt_value(node: NodeNotEq) -> str: # noqa: ARG001 - return ">" - - -def _get_ifexp_value(node: NodeIfExp) -> str: - return f"{_get_value(node.body)} if {_get_value(node.test)} else {_get_value(node.orelse)}" - - -def _get_invert_value(node: NodeInvert) -> str: # noqa: ARG001 - return "~" - - -def _get_in_value(node: NodeIn) -> str: # noqa: ARG001 - return "in" - - -def _get_is_value(node: NodeIs) -> str: # noqa: ARG001 - return "is" - - -def _get_isnot_value(node: NodeIsNot) -> str: # noqa: ARG001 - return "is not" - - -def _get_joinedstr_value(node: NodeJoinedStr) -> str: - _node_value_map[NodeConstant] = _get_constant_value_no_string_repr - try: - return "f" + repr("".join(_get_value(value) for value in node.values)) - finally: - _node_value_map[NodeConstant] = _get_constant_value - - -def _get_keyword_value(node: NodeKeyword) -> str: - return f"{node.arg}={_get_value(node.value)}" - - -def _get_lambda_value(node: NodeLambda) -> str: - return f"lambda {_get_value(node.args)}: {_get_value(node.body)}" - - -def _get_list_value(node: NodeList) -> str: - return "[" + ", ".join(_get_value(el) for el in node.elts) + "]" - - -def _get_listcomp_value(node: NodeListComp) -> str: - element = _get_value(node.elt) - generators = [_get_value(gen) for gen in node.generators] - return f"[{element} " + " ".join(generators) + "]" - - -def _get_lshift_value(node: NodeLShift) -> str: # noqa: ARG001 - return "<<" - - -def _get_lte_value(node: NodeNotEq) -> str: # noqa: ARG001 - return "<=" - - -def _get_lt_value(node: NodeNotEq) -> str: # noqa: ARG001 - return "<" - - -def _get_matmult_value(node: NodeMatMult) -> str: # noqa: ARG001 - return "@" + # TODO: remove once Python 3.7 support is dropped + if sys.version_info < (3, 8): + self._node_map[NodeBytes] = self._extract_bytes + self._node_map[NodeNameConstant] = self._extract_nameconstant + self._node_map[NodeNum] = self._extract_num + self._node_map[NodeStr] = self._extract_str + + def _extract_add(self, node: NodeAdd) -> str: # noqa: ARG002 + return "+" + + def _extract_and(self, node: NodeAnd) -> str: # noqa: ARG002 + return " and " + + def _extract_arguments(self, node: NodeArguments) -> str: + return ", ".join(arg.arg for arg in node.args) + + def _extract_attribute(self, node: NodeAttribute) -> str: + return f"{self._extract(node.value)}.{node.attr}" + + def _extract_binop(self, node: NodeBinOp) -> str: + return f"{self._extract(node.left)} {self._extract(node.op)} {self._extract(node.right)}" + + def _extract_bitor(self, node: NodeBitOr) -> str: # noqa: ARG002 + return "|" + + def _extract_bitand(self, node: NodeBitAnd) -> str: # noqa: ARG002 + return "&" + + def _extract_bitxor(self, node: NodeBitXor) -> str: # noqa: ARG002 + return "^" + + def _extract_boolop(self, node: NodeBoolOp) -> str: + return self._extract(node.op).join(self._extract(value) for value in node.values) + + def _extract_call(self, node: NodeCall) -> str: + posargs = ", ".join(self._extract(arg) for arg in node.args) + kwargs = ", ".join(self._extract(kwarg) for kwarg in node.keywords) + if posargs and kwargs: + args = f"{posargs}, {kwargs}" + elif posargs: + args = posargs + elif kwargs: + args = kwargs + else: + args = "" + return f"{self._extract(node.func)}({args})" + + def _extract_compare(self, node: NodeCompare) -> str: + left = self._extract(node.left) + ops = [self._extract(op) for op in node.ops] + comparators = [self._extract(comparator) for comparator in node.comparators] + return f"{left} " + " ".join(f"{op} {comp}" for op, comp in zip(ops, comparators)) + + def _extract_comprehension(self, node: NodeComprehension) -> str: + target = self._extract(node.target) + iterable = self._extract(node.iter) + conditions = [self._extract(condition) for condition in node.ifs] + value = f"for {target} in {iterable}" + if conditions: + value = f"{value} if " + " if ".join(conditions) + if node.is_async: + return f"async {value}" + return value + + def _extract_constant(self, node: NodeConstant) -> str: + return repr(node.value) + def _extract_constant_no_string_repr(self, node: NodeConstant) -> str: + if isinstance(node.value, str): + return node.value + return repr(node.value) -def _get_mod_value(node: NodeMod) -> str: # noqa: ARG001 - return "%" + def _extract_dict(self, node: NodeDict) -> str: + pairs = zip(node.keys, node.values) + gen = (f"{'None' if key is None else self._extract(key)}: {self._extract(value)}" for key, value in pairs) + return "{" + ", ".join(gen) + "}" + def _extract_dictcomp(self, node: NodeDictComp) -> str: + key = self._extract(node.key) + value = self._extract(node.value) + generators = [self._extract(gen) for gen in node.generators] + return f"{{{key}: {value} " + " ".join(generators) + "}" -def _get_mult_value(node: NodeMult) -> str: # noqa: ARG001 - return "*" + def _extract_div(self, node: NodeDiv) -> str: # noqa: ARG002 + return "/" + def _extract_ellipsis(self, node: NodeEllipsis) -> str: # noqa: ARG002 + return "..." -def _get_name_value(node: NodeName) -> str: - return node.id + def _extract_eq(self, node: NodeEq) -> str: # noqa: ARG002 + return "==" + def _extract_floordiv(self, node: NodeFloorDiv) -> str: # noqa: ARG002 + return "//" -def _get_not_value(node: NodeNot) -> str: # noqa: ARG001 - return "not " + def _extract_formatted(self, node: NodeFormattedValue) -> str: + return f"{{{self._extract(node.value)}}}" + def _extract_generatorexp(self, node: NodeGeneratorExp) -> str: + element = self._extract(node.elt) + generators = [self._extract(gen) for gen in node.generators] + return f"{element} " + " ".join(generators) -def _get_noteq_value(node: NodeNotEq) -> str: # noqa: ARG001 - return "!=" + def _extract_gte(self, node: NodeNotEq) -> str: # noqa: ARG002 + return ">=" + def _extract_gt(self, node: NodeNotEq) -> str: # noqa: ARG002 + return ">" -def _get_notin_value(node: NodeNotIn) -> str: # noqa: ARG001 - return "not in" + def _extract_ifexp(self, node: NodeIfExp) -> str: + return f"{self._extract(node.body)} if {self._extract(node.test)} else {self._extract(node.orelse)}" + def _extract_invert(self, node: NodeInvert) -> str: # noqa: ARG002 + return "~" -def _get_or_value(node: NodeOr) -> str: # noqa: ARG001 - return " or " + def _extract_in(self, node: NodeIn) -> str: # noqa: ARG002 + return "in" + def _extract_is(self, node: NodeIs) -> str: # noqa: ARG002 + return "is" -def _get_pow_value(node: NodePow) -> str: # noqa: ARG001 - return "**" + def _extract_isnot(self, node: NodeIsNot) -> str: # noqa: ARG002 + return "is not" + def _extract_joinedstr(self, node: NodeJoinedStr) -> str: + self._node_map[NodeConstant] = self._extract_constant_no_string_repr + try: + return "f" + repr("".join(self._extract(value) for value in node.values)) + finally: + self._node_map[NodeConstant] = self._extract_constant -def _get_rshift_value(node: NodeRShift) -> str: # noqa: ARG001 - return ">>" + def _extract_keyword(self, node: NodeKeyword) -> str: + return f"{node.arg}={self._extract(node.value)}" + def _extract_lambda(self, node: NodeLambda) -> str: + return f"lambda {self._extract(node.args)}: {self._extract(node.body)}" -def _get_set_value(node: NodeSet) -> str: - return "{" + ", ".join(_get_value(el) for el in node.elts) + "}" + def _extract_list(self, node: NodeList) -> str: + return "[" + ", ".join(self._extract(el) for el in node.elts) + "]" + def _extract_listcomp(self, node: NodeListComp) -> str: + element = self._extract(node.elt) + generators = [self._extract(gen) for gen in node.generators] + return f"[{element} " + " ".join(generators) + "]" -def _get_setcomp_value(node: NodeSetComp) -> str: - element = _get_value(node.elt) - generators = [_get_value(gen) for gen in node.generators] - return f"{{{element} " + " ".join(generators) + "}" + def _extract_lshift(self, node: NodeLShift) -> str: # noqa: ARG002 + return "<<" + def _extract_lte(self, node: NodeNotEq) -> str: # noqa: ARG002 + return "<=" -def _get_slice_value(node: NodeSlice) -> str: - value = f"{_get_value(node.lower) if node.lower else ''}:{_get_value(node.upper) if node.upper else ''}" - if node.step: - return f"{value}:{_get_value(node.step)}" - return value + def _extract_lt(self, node: NodeNotEq) -> str: # noqa: ARG002 + return "<" + def _extract_matmult(self, node: NodeMatMult) -> str: # noqa: ARG002 + return "@" -def _get_starred_value(node: NodeStarred) -> str: - return _get_value(node.value) + def _extract_mod(self, node: NodeMod) -> str: # noqa: ARG002 + return "%" + def _extract_mult(self, node: NodeMult) -> str: # noqa: ARG002 + return "*" -def _get_sub_value(node: NodeSub) -> str: # noqa: ARG001 - return "-" + def _extract_name(self, node: NodeName) -> str: + return node.id + def _extract_not(self, node: NodeNot) -> str: # noqa: ARG002 + return "not " -def _get_subscript_value(node: NodeSubscript) -> str: - subscript = _get_value(node.slice) - if isinstance(subscript, str) and subscript.startswith("(") and subscript.endswith(")"): - subscript = subscript[1:-1] - return f"{_get_value(node.value)}[{subscript}]" + def _extract_noteq(self, node: NodeNotEq) -> str: # noqa: ARG002 + return "!=" + def _extract_notin(self, node: NodeNotIn) -> str: # noqa: ARG002 + return "not in" -def _get_tuple_value(node: NodeTuple) -> str: - return "(" + ", ".join(_get_value(el) for el in node.elts) + ")" + def _extract_or(self, node: NodeOr) -> str: # noqa: ARG002 + return " or " + def _extract_pow(self, node: NodePow) -> str: # noqa: ARG002 + return "**" -def _get_uadd_value(node: NodeUAdd) -> str: # noqa: ARG001 - return "+" + def _extract_rshift(self, node: NodeRShift) -> str: # noqa: ARG002 + return ">>" + def _extract_set(self, node: NodeSet) -> str: + return "{" + ", ".join(self._extract(el) for el in node.elts) + "}" -def _get_unaryop_value(node: NodeUnaryOp) -> str: - return f"{_get_value(node.op)}{_get_value(node.operand)}" + def _extract_setcomp(self, node: NodeSetComp) -> str: + element = self._extract(node.elt) + generators = [self._extract(gen) for gen in node.generators] + return f"{{{element} " + " ".join(generators) + "}" + def _extract_slice(self, node: NodeSlice) -> str: + value = f"{self._extract(node.lower) if node.lower else ''}:{self._extract(node.upper) if node.upper else ''}" + if node.step: + return f"{value}:{self._extract(node.step)}" + return value -def _get_usub_value(node: NodeUSub) -> str: # noqa: ARG001 - return "-" + def _extract_starred(self, node: NodeStarred) -> str: + return self._extract(node.value) + def _extract_sub(self, node: NodeSub) -> str: # noqa: ARG002 + return "-" -def _get_yield_value(node: NodeYield) -> str: - if node.value is None: - return repr(None) - return _get_value(node.value) - - -_node_value_map: dict[type, Callable[[Any], str]] = { - # type(None): lambda _: repr(None), - NodeAdd: _get_add_value, - NodeAnd: _get_and_value, - NodeArguments: _get_arguments_value, - NodeAttribute: _get_attribute_value, - NodeBinOp: _get_binop_value, - NodeBitAnd: _get_bitand_value, - NodeBitOr: _get_bitor_value, - NodeBitXor: _get_bitxor_value, - NodeBoolOp: _get_boolop_value, - NodeCall: _get_call_value, - NodeCompare: _get_compare_value, - NodeComprehension: _get_comprehension_value, - NodeConstant: _get_constant_value, - NodeDictComp: _get_dictcomp_value, - NodeDict: _get_dict_value, - NodeDiv: _get_div_value, - NodeEllipsis: _get_ellipsis_value, - NodeEq: _get_eq_value, - NodeFloorDiv: _get_floordiv_value, - NodeFormattedValue: _get_formatted_value, - NodeGeneratorExp: _get_generatorexp_value, - NodeGtE: _get_gte_value, - NodeGt: _get_gt_value, - NodeIfExp: _get_ifexp_value, - NodeIn: _get_in_value, - NodeInvert: _get_invert_value, - NodeIs: _get_is_value, - NodeIsNot: _get_isnot_value, - NodeJoinedStr: _get_joinedstr_value, - NodeKeyword: _get_keyword_value, - NodeLambda: _get_lambda_value, - NodeListComp: _get_listcomp_value, - NodeList: _get_list_value, - NodeLShift: _get_lshift_value, - NodeLtE: _get_lte_value, - NodeLt: _get_lt_value, - NodeMatMult: _get_matmult_value, - NodeMod: _get_mod_value, - NodeMult: _get_mult_value, - NodeName: _get_name_value, - NodeNotEq: _get_noteq_value, - NodeNot: _get_not_value, - NodeNotIn: _get_notin_value, - NodeOr: _get_or_value, - NodePow: _get_pow_value, - NodeRShift: _get_rshift_value, - NodeSetComp: _get_setcomp_value, - NodeSet: _get_set_value, - NodeSlice: _get_slice_value, - NodeStarred: _get_starred_value, - NodeSub: _get_sub_value, - NodeSubscript: _get_subscript_value, - NodeTuple: _get_tuple_value, - NodeUAdd: _get_uadd_value, - NodeUnaryOp: _get_unaryop_value, - NodeUSub: _get_usub_value, - NodeYield: _get_yield_value, -} + def _extract_subscript(self, node: NodeSubscript) -> str: + subscript = self._extract(node.slice) + if isinstance(subscript, str) and subscript.startswith("(") and subscript.endswith(")"): + subscript = subscript[1:-1] + return f"{self._extract(node.value)}[{subscript}]" -# TODO: remove condition once Python 3.7 support is dropped -if sys.version_info >= (3, 8): + def _extract_tuple(self, node: NodeTuple) -> str: + return "(" + ", ".join(self._extract(el) for el in node.elts) + ")" - def _get_named_expr_value(node: NodeNamedExpr) -> str: - return f"({_get_value(node.target)} := {_get_value(node.value)})" + def _extract_uadd(self, node: NodeUAdd) -> str: # noqa: ARG002 + return "+" - _node_value_map[NodeNamedExpr] = _get_named_expr_value + def _extract_unaryop(self, node: NodeUnaryOp) -> str: + return f"{self._extract(node.op)}{self._extract(node.operand)}" -# TODO: remove once Python 3.8 support is dropped -if sys.version_info < (3, 9): + def _extract_usub(self, node: NodeUSub) -> str: # noqa: ARG002 + return "-" - def _get_extslice_value(node: NodeExtSlice) -> str: - return ",".join(_get_value(dim) for dim in node.dims) + def _extract_yield(self, node: NodeYield) -> str: + if node.value is None: + return repr(None) + return self._extract(node.value) - def _get_index_value(node: NodeIndex) -> str: - return _get_value(node.value) + # TODO: remove condition once Python 3.7 support is dropped + if sys.version_info >= (3, 8): - _node_value_map[NodeExtSlice] = _get_extslice_value - _node_value_map[NodeIndex] = _get_index_value + def _extract_named_expr(self, node: NodeNamedExpr) -> str: + return f"({self._extract(node.target)} := {self._extract(node.value)})" + # TODO: remove once Python 3.8 support is dropped + if sys.version_info < (3, 9): -# TODO: remove once Python 3.7 support is dropped -if sys.version_info < (3, 8): + def _extract_extslice(self, node: NodeExtSlice) -> str: + return ",".join(self._extract(dim) for dim in node.dims) - def _get_bytes_value(node: NodeBytes) -> str: - return repr(node.s) + def _extract_index(self, node: NodeIndex) -> str: + return self._extract(node.value) - def _get_nameconstant_value(node: NodeNameConstant) -> str: - return repr(node.value) + # TODO: remove once Python 3.7 support is dropped + if sys.version_info < (3, 8): - def _get_num_value(node: NodeNum) -> str: - return repr(node.n) + def _extract_bytes(self, node: NodeBytes) -> str: + return repr(node.s) - def _get_str_value(node: NodeStr) -> str: - return repr(node.s) + def _extract_nameconstant(self, node: NodeNameConstant) -> str: + return repr(node.value) - _node_value_map[NodeBytes] = _get_bytes_value - _node_value_map[NodeNameConstant] = _get_nameconstant_value - _node_value_map[NodeNum] = _get_num_value - _node_value_map[NodeStr] = _get_str_value + def _extract_num(self, node: NodeNum) -> str: + return repr(node.n) + def _extract_str(self, node: NodeStr) -> str: + return repr(node.s) -def _get_value(node: AST) -> str: - return _node_value_map[type(node)](node) + def _extract(self, node: AST) -> str: + return self._node_map[type(node)](node) def get_value(node: AST | None) -> str | None: - """Unparse a node to its string representation. + """Get the string representation of a node. Parameters: - node: The node to unparse. + node: The node to represent. Returns: - The unparsed code of the node. + The representing code for the node. """ if node is None: return None - return _node_value_map[type(node)](node) + extractor = _ValueExtractor() + return extractor._extract(node) def safe_get_value(node: AST | None, filepath: str | Path | None = None) -> str | None: - """Safely (no exception) unparse a node to its string representation. + """Safely (no exception) get the string representation of a node. Parameters: - node: The node to unparse. + node: The node to represent. filepath: An optional filepath from where the node comes. Returns: - The unparsed code of the node. + The representing code for the node. """ try: return get_value(node) except Exception as error: - message = f"Failed to unparse node {node}" + message = f"Failed to represent node {node}" if filepath: message += f" at {filepath}:{node.lineno}" # type: ignore[union-attr] message += f": {error}" @@ -1345,8 +1368,9 @@ def get_parameter_default(node: AST | None, filepath: Path, lines_collection: Li """ if node is None: return None - with suppress(KeyError): - return _get_value(node) + default = safe_get_value(node) + if default is not None: + return default if node.lineno == node.end_lineno: # type: ignore[attr-defined] return lines_collection[filepath][node.lineno - 1][node.col_offset : node.end_col_offset] # type: ignore[attr-defined] # TODO: handle multiple line defaults diff --git a/src/griffe/agents/visitor.py b/src/griffe/agents/visitor.py index 5a13cb3d..b429679d 100644 --- a/src/griffe/agents/visitor.py +++ b/src/griffe/agents/visitor.py @@ -17,14 +17,15 @@ from griffe.agents.base import BaseVisitor from griffe.agents.nodes import ( ASTNode, - get_annotation, get_docstring, get_instance_names, get_names, get_parameter_default, - parse__all__, relative_to_absolute, + safe_get__all__, safe_get_annotation, + safe_get_base_class, + safe_get_condition, safe_get_value, ) from griffe.collections import LinesCollection, ModulesCollection @@ -246,7 +247,7 @@ def visit_classdef(self, node: ast.ClassDef) -> None: bases = [] if node.bases: for base in node.bases: - bases.append(safe_get_annotation(base, parent=self.current)) + bases.append(safe_get_base_class(base, parent=self.current)) class_ = Class( name=node.name, @@ -627,7 +628,7 @@ def handle_attribute( if name == "__all__": with suppress(AttributeError): - parent.exports = parse__all__(node, self.current) # type: ignore[assignment,arg-type] + parent.exports = safe_get__all__(node, self.current) # type: ignore[assignment,arg-type] def visit_assign(self, node: ast.Assign) -> None: """Visit an assignment node. @@ -659,7 +660,7 @@ def visit_augassign(self, node: ast.AugAssign) -> None: ) if all_augment: # we assume exports is not None at this point - self.current.exports.extend(parse__all__(node, self.current)) # type: ignore[arg-type,union-attr] + self.current.exports.extend(safe_get__all__(node, self.current)) # type: ignore[arg-type,union-attr] def visit_if(self, node: ast.If) -> None: """Visit an "if" node. @@ -668,10 +669,9 @@ def visit_if(self, node: ast.If) -> None: node: The node to visit. """ if isinstance(node.parent, (ast.Module, ast.ClassDef)): # type: ignore[attr-defined] - with suppress(KeyError): # unhandled AST nodes - condition = get_annotation(node.test, parent=self.current) - if str(condition) in {"typing.TYPE_CHECKING", "TYPE_CHECKING"}: - self.type_guarded = True + condition = safe_get_condition(node.test, parent=self.current, log_level=None) + if str(condition) in {"typing.TYPE_CHECKING", "TYPE_CHECKING"}: + self.type_guarded = True self.generic_visit(node) self.type_guarded = False diff --git a/tests/test_nodes.py b/tests/test_nodes.py index 4a6e5b85..c83dfb3b 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -2,6 +2,7 @@ from __future__ import annotations +import logging import sys from ast import PyCF_ONLY_AST @@ -249,3 +250,20 @@ def kwargs(self) -> 'dict[str, Any] | None': assert init_args_annotation.is_tuple kwargs_return_annotation = module["ArgsKwargs.kwargs"].annotation assert isinstance(kwargs_return_annotation, Expression) + + +def test_parsing_dynamic_base_classes(caplog: pytest.LogCaptureFixture) -> None: + """Assert parsing dynamic base classes does not trigger errors. + + Parameters: + caplog: Pytest fixture to capture logs. + """ + with caplog.at_level(logging.ERROR), temporary_visited_module( + """ + from collections import namedtuple + class Thing(namedtuple('Thing', 'attr1 attr2')): + ... + """, + ): + pass + assert not caplog.records diff --git a/tests/test_visitor.py b/tests/test_visitor.py index 1a7c3efe..95cd5a2f 100644 --- a/tests/test_visitor.py +++ b/tests/test_visitor.py @@ -339,3 +339,16 @@ def __init__(self) -> None: assert module["C.b"].annotation.full == "bytes" assert module["C.b"].labels == {"instance-attribute"} + + +def test_visiting_if_statement_in_class_for_type_guards() -> None: + """Don't fail on various if statements when checking for type-guards.""" + with temporary_visited_module( + """ + class A: + if something("string1 string2"): + class B: + pass + """, + ) as module: + assert module["A.B"].runtime