From d32b12a72e894a95c93de7a71f278d2b4913619b Mon Sep 17 00:00:00 2001 From: Jeremy Maitin-Shepard Date: Thu, 31 Oct 2024 21:29:19 -0700 Subject: [PATCH] Fix C++ apigen page naming for variable template specializations (#396) Previously, the page names for variable template specializations incorrectly included the template arguments within angle brackets, unlike class template specializations where such arguments were stripped. With this commit, template arguments are stripped and the entities are distinguished based on the user-specified `id`. --- sphinx_immaterial/apidoc/cpp/api_parser.py | 6 ++- tests/cpp_api_parser_test.py | 43 +++++++++++++++++++++- 2 files changed, 45 insertions(+), 4 deletions(-) diff --git a/sphinx_immaterial/apidoc/cpp/api_parser.py b/sphinx_immaterial/apidoc/cpp/api_parser.py index 44aadd1d..3d7152f2 100644 --- a/sphinx_immaterial/apidoc/cpp/api_parser.py +++ b/sphinx_immaterial/apidoc/cpp/api_parser.py @@ -112,7 +112,7 @@ def _make_replacement_pattern( ) SPECIAL_GROUP_COMMAND_PATTERN = re.compile( - r"^(?:\\|@)(ingroup|relates|membergroup|id)\s+(.+[^\s])\s*$", re.MULTILINE + r"^(?:\\|@)(ingroup|relates|membergroup|id)\s+(.*[^\s])\s*$", re.MULTILINE ) @@ -2206,7 +2206,9 @@ def _format_template_arguments(entity: CppApiEntity) -> str: def _get_entity_base_page_name_component(entity: CppApiEntity) -> str: base_name = entity["name"] - if entity["kind"] == "class" and entity.get("specializes"): + if (entity["kind"] == "class" or entity["kind"] == "var") and entity.get( + "specializes" + ): # Strip any template arguments base_name = re.sub("([^<]*).*", r"\1", base_name) elif entity["kind"] == "conversion_function": diff --git a/tests/cpp_api_parser_test.py b/tests/cpp_api_parser_test.py index 5b6fdb1f..acbbcd51 100644 --- a/tests/cpp_api_parser_test.py +++ b/tests/cpp_api_parser_test.py @@ -8,7 +8,7 @@ def test_basic(): config = api_parser.Config( input_path="a.cpp", - input_content=b""" + input_content=rb""" /// This is the doc. int foo(bool x, int y); @@ -17,6 +17,7 @@ def test_basic(): ) output = api_parser.generate_output(config) + assert not output.get("errors") entities = list(output["entities"].values()) assert len(entities) == 1 @@ -45,6 +46,7 @@ class DimensionIdentifier { ) output = api_parser.generate_output(config) + assert not output.get("errors") print(output) entities = list(output["entities"].values()) assert len(entities) == 2 @@ -62,6 +64,8 @@ def test_enable_if_transform(): } /// This is the doc. +/// +/// \ingroup X template >> int foo(T x); @@ -76,7 +80,8 @@ def test_enable_if_transform(): ) output = api_parser.generate_output(config) - + assert not output.get("errors") + assert not output.get("warnings") entities = list(output["entities"].values()) assert len(entities) == 1 requires = entities[0].get("requires") @@ -170,6 +175,7 @@ def test_comment_styles(doc_str: bytes, expected: str): input_content=doc_str, ) output = api_parser.generate_output(config) + assert not output.get("errors") doc_strings = [ cast(api_parser.JsonDocComment, v["doc"])["text"] for v in output.get("entities", {}).values() @@ -198,6 +204,7 @@ def test_function_fields(): ) output = api_parser.generate_output(config) + assert not output.get("errors") entities = output.get("entities", {}) doc_str = "" for entity in entities.values(): @@ -238,9 +245,41 @@ def test_unnamed_template_parameter(): ) output = api_parser.generate_output(config) + assert not output.get("errors") + assert not output.get("warnings") entities = output.get("entities", {}) assert len(entities) == 1 entity = list(entities.values())[0] tparams = entity["template_parameters"] assert tparams is not None assert tparams[0]["name"] == "" + + +def test_variable_template_specialization(): + config = api_parser.Config( + input_path="a.cpp", + compiler_flags=["-std=c++17", "-x", "c++"], + input_content=rb""" +/// Check if it has A. +/// +/// \ingroup Array +template +constexpr inline bool HasA = false; + +/// Specializes HasA for int. +/// \ingroup Array +/// \id int +template <> +constexpr inline bool HasA = true; +""", + ) + + output = api_parser.generate_output(config) + assert not output.get("errors") + assert not output.get("warnings") + entities = output.get("entities", {}) + assert len(entities) == 2 + assert sorted([entity["page_name"] for entity in entities.values()]) == [ + "HasA", + "HasA-int", + ]