Skip to content

Commit

Permalink
feat: Set labels on functions using decorators
Browse files Browse the repository at this point in the history
Issue #47: #47
  • Loading branch information
pawamoy committed Apr 15, 2022
1 parent c4a92b7 commit 1c1feb2
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 7 deletions.
62 changes: 56 additions & 6 deletions src/griffe/agents/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,22 @@
Parameters,
)
from griffe.docstrings.parsers import Parser
from griffe.exceptions import LastNodeError
from griffe.exceptions import LastNodeError, NameResolutionError
from griffe.expressions import Expression, Name

builtin_decorators = {
"property",
"staticmethod",
"classmethod",
}

stdlib_decorators = {
"abc.abstractmethod",
"functools.cache",
"functools.cached_property",
"functools.lru_cache",
}


def visit(
module_name: str,
Expand Down Expand Up @@ -210,8 +223,13 @@ def visit_classdef(self, node: ast.ClassDef) -> None:
if node.decorator_list:
lineno = node.decorator_list[0].lineno
for decorator_node in node.decorator_list:
decorators.append(Decorator(decorator_node.lineno, decorator_node.end_lineno)) # type: ignore[attr-defined]
self.visit(decorator_node)
decorators.append(
Decorator(
get_value(decorator_node),
lineno=decorator_node.lineno,
endlineno=decorator_node.end_lineno, # type: ignore[attr-defined]
)
)
else:
lineno = node.lineno

Expand All @@ -235,6 +253,32 @@ def visit_classdef(self, node: ast.ClassDef) -> None:
self.generic_visit(node)
self.current = self.current.parent # type: ignore[assignment]

def decorators_to_labels(self, decorators: list[Decorator]) -> set[str]: # noqa: WPS231
"""Build and return a set of labels based on decorators.
Parameters:
decorators: The decorators to check.
Returns:
A set of labels.
"""
labels = set()
for decorator in decorators:
decorator_value = decorator.value.split("(", 1)[0]
if decorator_value in builtin_decorators:
labels.add(decorator_value)
else:
names = decorator_value.split(".")
with suppress(NameResolutionError):
resolved_first = self.current.resolve(names[0])
resolved_name = ".".join([resolved_first, *names[1:]])
if resolved_name in stdlib_decorators:
if "abstract" in resolved_name:
labels.add("abstract")
elif "cache" in resolved_name:
labels.add("cached")
return labels

def handle_function(self, node: ast.AsyncFunctionDef | ast.FunctionDef, labels: set | None = None): # noqa: WPS231
"""Handle a function definition node.
Expand All @@ -249,12 +293,18 @@ def handle_function(self, node: ast.AsyncFunctionDef | ast.FunctionDef, labels:
if node.decorator_list:
lineno = node.decorator_list[0].lineno
for decorator_node in node.decorator_list:
decorators.append(Decorator(decorator_node.lineno, decorator_node.end_lineno)) # type: ignore[attr-defined]
self.visit(decorator_node)
decorator_value = get_value(decorator_node)
decorators.append(
Decorator(
decorator_value,
lineno=decorator_node.lineno,
endlineno=decorator_node.end_lineno, # type: ignore[attr-defined]
)
)
else:
lineno = node.lineno

# TODO: handle member already exist, setter of property
labels |= self.decorators_to_labels(decorators)

# handle parameters
parameters = Parameters()
Expand Down
8 changes: 7 additions & 1 deletion src/griffe/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,15 @@ class Decorator:
endlineno: The ending line number.
"""

def __init__(self, lineno: int | None, endlineno: int | None) -> None:
def __init__(self, value: str, *, lineno: int | None, endlineno: int | None) -> None:
"""Initialize the decorator.
Parameters:
value: The decorator code.
lineno: The starting line number.
endlineno: The ending line number.
"""
self.value: str = value
self.lineno: int | None = lineno
self.endlineno: int | None = endlineno

Expand All @@ -74,6 +76,7 @@ def as_dict(self, **kwargs: Any) -> dict[str, Any]:
A dictionary.
"""
return {
"value": self.value,
"lineno": self.lineno,
"endlineno": self.endlineno,
}
Expand Down Expand Up @@ -1167,6 +1170,9 @@ def __init__(
self.parameters: Parameters = parameters or Parameters()
self.returns: str | Name | Expression | None = returns
self.decorators: list[Decorator] = decorators or []
self.setter: Function | None = None
self.deleter: Function | None = None
self.overloads: list[Function] | None = None

@property
def annotation(self) -> str | Name | Expression | None:
Expand Down
35 changes: 35 additions & 0 deletions tests/test_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,38 @@ def test_not_defined_at_runtime():
assert "CONST_C" in package.members
assert "TYPE_B" not in package.members
assert "TYPE_C" not in package.members


@pytest.mark.parametrize(
("decorator", "label"),
[
("property", "property"),
("functools.cache", "cached"),
("functools.cached_property", "cached"),
("functools.lru_cache", "cached"),
("functools.lru_cache(maxsize=8)", "cached"),
("cache", "cached"),
("cached_property", "cached"),
("lru_cache", "cached"),
("lru_cache(maxsize=8)", "cached"),
],
)
def test_set_labels_using_decorators(decorator, label):
"""Assert decorators are used to set labels on objects.
Parameters:
decorator: A parametrized decorator.
label: The expected, parametrized label.
"""
code = f"""
import functools
from functools import cache, cached_property, lru_cache
class A:
@{decorator}
def f(self):
return 0
"""
with temporary_visited_module(code) as module:
assert label in module["A.f"].labels

0 comments on commit 1c1feb2

Please sign in to comment.