Skip to content

Commit

Permalink
Close #7341: py domain: type annotations are converted to cross refs
Browse files Browse the repository at this point in the history
  • Loading branch information
tk0miya committed Mar 22, 2020
1 parent dd85cb6 commit b0a6b3f
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 16 deletions.
1 change: 1 addition & 0 deletions CHANGES
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ Features added
* #6417: py domain: Allow to make a style for arguments of functions and methods
* #7238, #7239: py domain: Emit a warning on describing a python object if the
entry is already added as the same name
* #7341: py domain: type annotations in singature are converted to cross refs
* Support priority of event handlers. For more detail, see
:py:meth:`.Sphinx.connect()`
* #3077: Implement the scoping for :rst:dir:`productionlist` as indicated
Expand Down
59 changes: 57 additions & 2 deletions sphinx/domains/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from sphinx.domains import Domain, ObjType, Index, IndexEntry
from sphinx.environment import BuildEnvironment
from sphinx.locale import _, __
from sphinx.pycode.ast import ast, parse as ast_parse
from sphinx.roles import XRefRole
from sphinx.util import logging
from sphinx.util.docfields import Field, GroupedField, TypedField
Expand Down Expand Up @@ -67,6 +68,58 @@
}


def _parse_annotation(annotation: str) -> List[Node]:
"""Parse type annotation."""
def make_xref(text: str) -> addnodes.pending_xref:
return pending_xref('', nodes.Text(text),
refdomain='py', reftype='class', reftarget=text)

def unparse(node: ast.AST) -> List[Node]:
if isinstance(node, ast.Attribute):
return [nodes.Text("%s.%s" % (unparse(node.value)[0], node.attr))]
elif isinstance(node, ast.Expr):
return unparse(node.value)
elif isinstance(node, ast.Index):
return unparse(node.value)
elif isinstance(node, ast.List):
result = [addnodes.desc_sig_punctuation('', '[')] # type: List[Node]
for elem in node.elts:
result.extend(unparse(elem))
result.append(addnodes.desc_sig_punctuation('', ', '))
result.pop()
result.append(addnodes.desc_sig_punctuation('', ']'))
return result
elif isinstance(node, ast.Module):
return sum((unparse(e) for e in node.body), [])
elif isinstance(node, ast.Name):
return [nodes.Text(node.id)]
elif isinstance(node, ast.Subscript):
result = unparse(node.value)
result.append(addnodes.desc_sig_punctuation('', '['))
result.extend(unparse(node.slice))
result.append(addnodes.desc_sig_punctuation('', ']'))
return result
elif isinstance(node, ast.Tuple):
result = []
for elem in node.elts:
result.extend(unparse(elem))
result.append(addnodes.desc_sig_punctuation('', ', '))
result.pop()
return result
else:
raise SyntaxError # unsupported syntax

try:
tree = ast_parse(annotation)
result = unparse(tree)
for i, node in enumerate(result):
if isinstance(node, nodes.Text):
result[i] = make_xref(str(node))
return result
except SyntaxError:
return [make_xref(annotation)]


def _parse_arglist(arglist: str) -> addnodes.desc_parameterlist:
"""Parse a list of arguments using AST parser"""
params = addnodes.desc_parameterlist(arglist)
Expand All @@ -93,9 +146,10 @@ def _parse_arglist(arglist: str) -> addnodes.desc_parameterlist:
node += addnodes.desc_sig_name('', param.name)

if param.annotation is not param.empty:
children = _parse_annotation(param.annotation)
node += addnodes.desc_sig_punctuation('', ':')
node += nodes.Text(' ')
node += addnodes.desc_sig_name('', param.annotation)
node += addnodes.desc_sig_name('', '', *children) # type: ignore
if param.default is not param.empty:
if param.annotation is not param.empty:
node += nodes.Text(' ')
Expand Down Expand Up @@ -354,7 +408,8 @@ def handle_signature(self, sig: str, signode: desc_signature) -> Tuple[str, str]
signode += addnodes.desc_parameterlist()

if retann:
signode += addnodes.desc_returns(retann, retann)
children = _parse_annotation(retann)
signode += addnodes.desc_returns(retann, '', *children)

anno = self.options.get('annotation')
if anno:
Expand Down
2 changes: 1 addition & 1 deletion sphinx/testing/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def assert_node(node: Node, cls: Any = None, xpath: str = "", **kwargs: Any) ->
'The node%s has %d child nodes, not one' % (xpath, len(node))
assert_node(node[0], cls[1:], xpath=xpath + "[0]", **kwargs)
elif isinstance(cls, tuple):
assert isinstance(node, nodes.Element), \
assert isinstance(node, (list, nodes.Element)), \
'The node%s does not have any items' % xpath
assert len(node) == len(cls), \
'The node%s has %d child nodes, not %r' % (xpath, len(node), len(cls))
Expand Down
57 changes: 44 additions & 13 deletions tests/test_domain_py.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@
from sphinx.addnodes import (
desc, desc_addname, desc_annotation, desc_content, desc_name, desc_optional,
desc_parameter, desc_parameterlist, desc_returns, desc_signature,
desc_sig_name, desc_sig_operator, desc_sig_punctuation,
desc_sig_name, desc_sig_operator, desc_sig_punctuation, pending_xref,
)
from sphinx.domains import IndexEntry
from sphinx.domains.python import (
py_sig_re, _pseudo_parse_arglist, PythonDomain, PythonModuleIndex
py_sig_re, _parse_annotation, _pseudo_parse_arglist, PythonDomain, PythonModuleIndex
)
from sphinx.testing import restructuredtext
from sphinx.testing.util import assert_node
Expand Down Expand Up @@ -78,7 +78,7 @@ def assert_refnode(node, module_name, class_name, target, reftype=None,
assert_node(node, **attributes)

doctree = app.env.get_doctree('roles')
refnodes = list(doctree.traverse(addnodes.pending_xref))
refnodes = list(doctree.traverse(pending_xref))
assert_refnode(refnodes[0], None, None, 'TopLevel', 'class')
assert_refnode(refnodes[1], None, None, 'top_level', 'meth')
assert_refnode(refnodes[2], None, 'NestedParentA', 'child_1', 'meth')
Expand All @@ -96,7 +96,7 @@ def assert_refnode(node, module_name, class_name, target, reftype=None,
assert len(refnodes) == 13

doctree = app.env.get_doctree('module')
refnodes = list(doctree.traverse(addnodes.pending_xref))
refnodes = list(doctree.traverse(pending_xref))
assert_refnode(refnodes[0], 'module_a.submodule', None,
'ModTopLevel', 'class')
assert_refnode(refnodes[1], 'module_a.submodule', 'ModTopLevel',
Expand Down Expand Up @@ -125,7 +125,7 @@ def assert_refnode(node, module_name, class_name, target, reftype=None,
assert len(refnodes) == 16

doctree = app.env.get_doctree('module_option')
refnodes = list(doctree.traverse(addnodes.pending_xref))
refnodes = list(doctree.traverse(pending_xref))
print(refnodes)
print(refnodes[0])
print(refnodes[1])
Expand Down Expand Up @@ -236,21 +236,52 @@ def test_get_full_qualified_name():
assert domain.get_full_qualified_name(node) == 'module1.Class.func'


def test_parse_annotation():
doctree = _parse_annotation("int")
assert_node(doctree, ([pending_xref, "int"],))

doctree = _parse_annotation("List[int]")
assert_node(doctree, ([pending_xref, "List"],
[desc_sig_punctuation, "["],
[pending_xref, "int"],
[desc_sig_punctuation, "]"]))

doctree = _parse_annotation("Tuple[int, int]")
assert_node(doctree, ([pending_xref, "Tuple"],
[desc_sig_punctuation, "["],
[pending_xref, "int"],
[desc_sig_punctuation, ", "],
[pending_xref, "int"],
[desc_sig_punctuation, "]"]))

doctree = _parse_annotation("Callable[[int, int], int]")
assert_node(doctree, ([pending_xref, "Callable"],
[desc_sig_punctuation, "["],
[desc_sig_punctuation, "["],
[pending_xref, "int"],
[desc_sig_punctuation, ", "],
[pending_xref, "int"],
[desc_sig_punctuation, "]"],
[desc_sig_punctuation, ", "],
[pending_xref, "int"],
[desc_sig_punctuation, "]"]))


def test_pyfunction_signature(app):
text = ".. py:function:: hello(name: str) -> str"
doctree = restructuredtext.parse(app, text)
assert_node(doctree, (addnodes.index,
[desc, ([desc_signature, ([desc_name, "hello"],
desc_parameterlist,
[desc_returns, "str"])],
[desc_returns, pending_xref, "str"])],
desc_content)]))
assert_node(doctree[1], addnodes.desc, desctype="function",
domain="py", objtype="function", noindex=False)
assert_node(doctree[1][0][1],
[desc_parameterlist, desc_parameter, ([desc_sig_name, "name"],
[desc_sig_punctuation, ":"],
" ",
[nodes.inline, "str"])])
[nodes.inline, pending_xref, "str"])])


def test_pyfunction_signature_full(app):
Expand All @@ -260,27 +291,27 @@ def test_pyfunction_signature_full(app):
assert_node(doctree, (addnodes.index,
[desc, ([desc_signature, ([desc_name, "hello"],
desc_parameterlist,
[desc_returns, "str"])],
[desc_returns, pending_xref, "str"])],
desc_content)]))
assert_node(doctree[1], addnodes.desc, desctype="function",
domain="py", objtype="function", noindex=False)
assert_node(doctree[1][0][1],
[desc_parameterlist, ([desc_parameter, ([desc_sig_name, "a"],
[desc_sig_punctuation, ":"],
" ",
[desc_sig_name, "str"])],
[desc_sig_name, pending_xref, "str"])],
[desc_parameter, ([desc_sig_name, "b"],
[desc_sig_operator, "="],
[nodes.inline, "1"])],
[desc_parameter, ([desc_sig_operator, "*"],
[desc_sig_name, "args"],
[desc_sig_punctuation, ":"],
" ",
[desc_sig_name, "str"])],
[desc_sig_name, pending_xref, "str"])],
[desc_parameter, ([desc_sig_name, "c"],
[desc_sig_punctuation, ":"],
" ",
[desc_sig_name, "bool"],
[desc_sig_name, pending_xref, "bool"],
" ",
[desc_sig_operator, "="],
" ",
Expand All @@ -289,7 +320,7 @@ def test_pyfunction_signature_full(app):
[desc_sig_name, "kwargs"],
[desc_sig_punctuation, ":"],
" ",
[desc_sig_name, "str"])])])
[desc_sig_name, pending_xref, "str"])])])


@pytest.mark.skipif(sys.version_info < (3, 8), reason='python 3.8+ is required.')
Expand Down Expand Up @@ -340,7 +371,7 @@ def test_optional_pyfunction_signature(app):
assert_node(doctree, (addnodes.index,
[desc, ([desc_signature, ([desc_name, "compile"],
desc_parameterlist,
[desc_returns, "ast object"])],
[desc_returns, pending_xref, "ast object"])],
desc_content)]))
assert_node(doctree[1], addnodes.desc, desctype="function",
domain="py", objtype="function", noindex=False)
Expand Down

0 comments on commit b0a6b3f

Please sign in to comment.