Skip to content

Commit

Permalink
fix: Fix detection of optional and default in Numpydoc-style parameters
Browse files Browse the repository at this point in the history
Issue #165: #165
  • Loading branch information
pawamoy committed Jun 19, 2023
1 parent 53827c8 commit 3509106
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 22 deletions.
56 changes: 34 additions & 22 deletions src/griffe/docstrings/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,26 +193,30 @@ def _read_block(docstring: Docstring, *, offset: int) -> tuple[str, int]:
_RE_OB: str = r"\{" # opening bracket
_RE_CB: str = r"\}" # closing bracket
_RE_NAME: str = r"\*{0,2}[_a-z][_a-z0-9]*"
_RE_TYPE: str = r"[_a-z0-9 \[\]|().,'\"-]*"
_RE_RETURNS: Pattern = re.compile(rf"^(?:(?P<name>{_RE_NAME})\s*:\s*)?(?P<type>{_RE_TYPE})", re.IGNORECASE)
_RE_TYPE: str = r".+"
_RE_RETURNS: Pattern = re.compile(
rf"""
(?:
(?P<nt_name>{_RE_NAME})\s*:\s*(?P<nt_type>{_RE_TYPE}) # name and type
| # or
(?P<name>{_RE_NAME})\s*:\s* # just name
| # or
(?P<type>{_RE_TYPE})\s* # just type
)
""",
re.IGNORECASE | re.VERBOSE,
)
_RE_YIELDS: Pattern = _RE_RETURNS
_RE_RECEIVES: Pattern = _RE_YIELDS
_RE_RECEIVES: Pattern = _RE_RETURNS
_RE_PARAMETER: Pattern = re.compile(
rf"""
(?P<names>{_RE_NAME}(?:,\s{_RE_NAME})*)
(?:
\s:\s
(?:
(?:{_RE_OB}(?P<choices>.+){_RE_CB})|
(?:
(?P<type>{_RE_TYPE})
(?:,\soptional)?
(?:
,\sdefault\s*[:=]\s*
(?P<default>.+)
)?
)
)
(?P<type>{_RE_TYPE})
)?
)?
""",
re.IGNORECASE | re.VERBOSE,
Expand All @@ -239,14 +243,19 @@ def _read_parameters(
continue

names = match.group("names").split(", ")
annotation = match.group("type")
annotation = annotation or None
annotation = match.group("type") or None
choices = match.group("choices")
default = None
if choices:
choices = choices.split(", ", 1)
default = choices[0]
else:
default = match.group("default")
elif annotation:
match = re.match(r"^(?P<annotation>.+),\s+default(?: |: |=)(?P<default>.+)$", annotation)
if match:
default = match.group("default")
annotation = match.group("annotation")
if annotation and annotation.endswith(", optional"):
annotation = annotation[:-10]
description = "\n".join(item[1:]).rstrip() if len(item) > 1 else ""

if annotation is None:
Expand Down Expand Up @@ -359,8 +368,9 @@ def _read_returns_section(
_warn(docstring, new_offset, f"Could not parse line '{item[0]}'")
continue

name, annotation = match.groups()
annotation = annotation or None
groups = match.groupdict()
name = groups["nt_name"] or groups["name"]
annotation = groups["nt_type"] or groups["type"]
text = dedent("\n".join(item[1:]))
if annotation is None:
# try to retrieve the annotation from the docstring parent
Expand Down Expand Up @@ -415,8 +425,9 @@ def _read_yields_section(
_warn(docstring, new_offset, f"Could not parse line '{item[0]}'")
continue

name, annotation = match.groups()
annotation = annotation or None
groups = match.groupdict()
name = groups["nt_name"] or groups["name"]
annotation = groups["nt_type"] or groups["type"]
text = dedent("\n".join(item[1:]))
if annotation is None:
# try to retrieve the annotation from the docstring parent
Expand Down Expand Up @@ -462,8 +473,9 @@ def _read_receives_section(
_warn(docstring, new_offset, f"Could not parse line '{item[0]}'")
continue

name, annotation = match.groups()
annotation = annotation or None
groups = match.groupdict()
name = groups["nt_name"] or groups["name"]
annotation = groups["nt_type"] or groups["type"]
text = dedent("\n".join(item[1:]))
if annotation is None:
# try to retrieve the annotation from the docstring parent
Expand Down
22 changes: 22 additions & 0 deletions tests/test_docstrings/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,28 @@ def test_class_uses_init_parameters(parse_numpy: ParserType) -> None:
assert argx.description == "X value."


def test_detect_optional_flag(parse_numpy: ParserType) -> None:
"""Detect the optional part of a parameter docstring.
Parameters:
parse_numpy: Fixture parser.
"""
docstring = """
Parameters
----------
a : str, optional
g, h : bytes, optional, default=b''
"""

sections, _ = parse_numpy(docstring)
assert len(sections) == 1
assert sections[0].value[0].annotation == "str"
assert sections[0].value[1].annotation == "bytes"
assert sections[0].value[1].default == "b''"
assert sections[0].value[2].annotation == "bytes"
assert sections[0].value[2].default == "b''"


# =============================================================================================
# Yields sections
@pytest.mark.parametrize(
Expand Down

0 comments on commit 3509106

Please sign in to comment.