From bf29d7a731f606b42bcf87227e027048edf619be Mon Sep 17 00:00:00 2001 From: Aliaksandr Ivanou Date: Fri, 15 Oct 2021 15:05:53 -0700 Subject: [PATCH] Make docstring optional (#259) Summary: Pull Request resolved: https://github.com/pytorch/torchx/pull/259 * Refactor docstring functions: combines two functions that retrieve docstring into one * Make docstring optional * Remove docstring validator Git issue: https://github.com/pytorch/torchx/issues/253 Reviewed By: kiukchung Differential Revision: D31671125 fbshipit-source-id: 938faa5c0979d2b5e8c159384ebecc0274e95fff --- docs/source/component_best_practices.rst | 8 ++ torchx/specs/api.py | 30 ++-- torchx/specs/file_linter.py | 137 ++++++++---------- torchx/specs/finder.py | 11 +- torchx/specs/test/api_test.py | 50 ++++++- torchx/specs/test/file_linter_test.py | 175 ++++++++--------------- torchx/specs/test/finder_test.py | 23 ++- 7 files changed, 210 insertions(+), 224 deletions(-) diff --git a/docs/source/component_best_practices.rst b/docs/source/component_best_practices.rst index 99272120f..b6239895e 100644 --- a/docs/source/component_best_practices.rst +++ b/docs/source/component_best_practices.rst @@ -74,6 +74,14 @@ others to understand how to use it. return AppDef(roles=[Role(..., num_replicas=num_replicas)]) +Documentation +^^^^^^^^^^^^^^^^^^^^^ + +The documentation is optional, but it is the best practice to keep component functions documented, +especially if you want to share your components. See :ref:Component Authoring +for more details. + + Named Resources ----------------- diff --git a/torchx/specs/api.py b/torchx/specs/api.py index d068ae010..bc9138007 100644 --- a/torchx/specs/api.py +++ b/torchx/specs/api.py @@ -19,7 +19,6 @@ Generic, Iterator, List, - Mapping, Optional, Tuple, Type, @@ -28,8 +27,7 @@ ) import yaml -from pyre_extensions import none_throws -from torchx.specs.file_linter import parse_fn_docstring +from torchx.specs.file_linter import get_fn_docstring, TorchXArgumentHelpFormatter from torchx.util.types import decode_from_string, decode_optional, is_bool, is_primitive @@ -748,22 +746,21 @@ def get_argparse_param_type(parameter: inspect.Parameter) -> Callable[[str], obj return str -def _create_args_parser( - fn_name: str, - parameters: Mapping[str, inspect.Parameter], - function_desc: str, - args_desc: Dict[str, str], -) -> argparse.ArgumentParser: +def _create_args_parser(app_fn: Callable[..., AppDef]) -> argparse.ArgumentParser: + parameters = inspect.signature(app_fn).parameters + function_desc, args_desc = get_fn_docstring(app_fn) script_parser = argparse.ArgumentParser( - prog=f"torchx run ...torchx_params... {fn_name} ", - description=f"App spec: {function_desc}", + prog=f"torchx run <> {app_fn.__name__} ", + description=f"AppDef for {function_desc}", + formatter_class=TorchXArgumentHelpFormatter, ) remainder_arg = [] for param_name, parameter in parameters.items(): + param_desc = args_desc[parameter.name] args: Dict[str, Any] = { - "help": args_desc[param_name], + "help": param_desc, "type": get_argparse_param_type(parameter), } if parameter.default != inspect.Parameter.empty: @@ -788,13 +785,7 @@ def _create_args_parser( def _get_function_args( app_fn: Callable[..., AppDef], app_args: List[str] ) -> Tuple[List[object], List[str], Dict[str, object]]: - docstring = none_throws(inspect.getdoc(app_fn)) - function_desc, args_desc = parse_fn_docstring(docstring) - - parameters = inspect.signature(app_fn).parameters - script_parser = _create_args_parser( - app_fn.__name__, parameters, function_desc, args_desc - ) + script_parser = _create_args_parser(app_fn) parsed_args = script_parser.parse_args(app_args) @@ -802,6 +793,7 @@ def _get_function_args( var_arg = [] kwargs = {} + parameters = inspect.signature(app_fn).parameters for param_name, parameter in parameters.items(): arg_value = getattr(parsed_args, param_name) parameter_type = parameter.annotation diff --git a/torchx/specs/file_linter.py b/torchx/specs/file_linter.py index 385c71987..a626c02b8 100644 --- a/torchx/specs/file_linter.py +++ b/torchx/specs/file_linter.py @@ -6,9 +6,11 @@ # LICENSE file in the root directory of this source tree. import abc +import argparse import ast +import inspect from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple, cast +from typing import Dict, List, Optional, Tuple, cast, Callable from docstring_parser import parse from pyre_extensions import none_throws @@ -18,53 +20,66 @@ # pyre-ignore-all-errors[16] -def get_arg_names(app_specs_func_def: ast.FunctionDef) -> List[str]: - arg_names = [] - fn_args = app_specs_func_def.args - for arg_def in fn_args.args: - arg_names.append(arg_def.arg) - if fn_args.vararg: - arg_names.append(fn_args.vararg.arg) - for arg in fn_args.kwonlyargs: - arg_names.append(arg.arg) - return arg_names +def _get_default_arguments_descriptions(fn: Callable[..., object]) -> Dict[str, str]: + parameters = inspect.signature(fn).parameters + args_decs = {} + for parameter_name in parameters.keys(): + # The None or Empty string values getting ignored during help command by argparse + args_decs[parameter_name] = " " + return args_decs -def parse_fn_docstring(func_description: str) -> Tuple[str, Dict[str, str]]: +class TorchXArgumentHelpFormatter(argparse.HelpFormatter): + """Help message formatter which adds default values and required to argument help. + + If the argument is required, the class appends `(required)` at the end of the help message. + If the argument has default value, the class appends `(default: $DEFAULT)` at the end. + The formatter is designed to be used only for the torchx components functions. + These functions do not have both required and default arguments. """ - Given a docstring in a google-style format, returns the function description and - description of all arguments. - See: https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html + + def _get_help_string(self, action: argparse.Action) -> str: + help = action.help or "" + # Only `--help` will have be SUPPRESS, so we ignore it + if action.default is argparse.SUPPRESS: + return help + if action.required: + help += " (required)" + else: + help += f" (default: {action.default})" + return help + + +def get_fn_docstring(fn: Callable[..., object]) -> Tuple[str, Dict[str, str]]: """ - args_description = {} + Parses the function and arguments description from the provided function. Docstring should be in + `google-style format `_ + + If function has no docstring, the function description will be the name of the function, TIP + on how to improve the help message and arguments descriptions will be names of the arguments. + + The arguments that are not present in the docstring will contain default/required information + + Args: + fn: Function with or without docstring + + Returns: + function description, arguments description where key is the name of the argument and value + if the description + """ + default_fn_desc = f"""{fn.__name__} TIP: improve this help string by adding a docstring +to your component (see: https://pytorch.org/torchx/latest/component_best_practices.html)""" + args_description = _get_default_arguments_descriptions(fn) + func_description = inspect.getdoc(fn) + if not func_description: + return default_fn_desc, args_description docstring = parse(func_description) for param in docstring.params: args_description[param.arg_name] = param.description - short_func_description = docstring.short_description - return (short_func_description or "", args_description) - - -def _get_fn_docstring( - source: str, function_name: str -) -> Optional[Tuple[str, Dict[str, str]]]: - module = ast.parse(source) - for expr in module.body: - if type(expr) == ast.FunctionDef: - func_def = cast(ast.FunctionDef, expr) - if func_def.name == function_name: - docstring = ast.get_docstring(func_def) - if not docstring: - return None - return parse_fn_docstring(docstring) - return None - - -def get_short_fn_description(path: str, function_name: str) -> Optional[str]: - source = read_conf_file(path) - docstring = _get_fn_docstring(source, function_name) - if not docstring: - return None - return docstring[0] + short_func_description = docstring.short_description or default_fn_desc + if docstring.long_description: + short_func_description += " ..." + return (short_func_description or default_fn_desc, args_description) @dataclass @@ -91,38 +106,6 @@ def _gen_linter_message(self, description: str, lineno: int) -> LinterMessage: ) -class TorchxDocstringValidator(TorchxFunctionValidator): - def validate(self, app_specs_func_def: ast.FunctionDef) -> List[LinterMessage]: - """ - Validates the docstring of the `get_app_spec` function. Criteria: - * There mast be google-style docstring - * If there are more than zero arguments, there mast be a `Args:` section defined - with all arguments included. - """ - docsting = ast.get_docstring(app_specs_func_def) - lineno = app_specs_func_def.lineno - if not docsting: - desc = ( - f"`{app_specs_func_def.name}` is missing a Google Style docstring, please add one. " - "For more information on the docstring format see: " - "https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html" - ) - return [self._gen_linter_message(desc, lineno)] - - arg_names = get_arg_names(app_specs_func_def) - _, docstring_arg_defs = parse_fn_docstring(docsting) - missing_args = [ - arg_name for arg_name in arg_names if arg_name not in docstring_arg_defs - ] - if len(missing_args) > 0: - desc = ( - f"`{app_specs_func_def.name}` not all function arguments are present" - f" in the docstring. Missing args: {missing_args}" - ) - return [self._gen_linter_message(desc, lineno)] - return [] - - class TorchxFunctionArgsValidator(TorchxFunctionValidator): def validate(self, app_specs_func_def: ast.FunctionDef) -> List[LinterMessage]: linter_errors = [] @@ -149,7 +132,6 @@ def _validate_arg_def( ) ] if isinstance(arg_def.annotation, ast.Name): - # TODO(aivanou): add support for primitive type check return [] complex_type_def = cast(ast.Subscript, none_throws(arg_def.annotation)) if complex_type_def.value.id == "Optional": @@ -239,12 +221,6 @@ class TorchFunctionVisitor(ast.NodeVisitor): Visitor that finds the component_function and runs registered validators on it. Current registered validators: - * TorchxDocstringValidator - validates the docstring of the function. - Criteria: - * There format should be google-python - * If there are more than zero arguments defined, there - should be obligatory `Args:` section that describes each argument on a new line. - * TorchxFunctionArgsValidator - validates arguments of the function. Criteria: * Each argument should be annotated with the type @@ -260,7 +236,6 @@ class TorchFunctionVisitor(ast.NodeVisitor): def __init__(self, component_function_name: str) -> None: self.validators = [ - TorchxDocstringValidator(), TorchxFunctionArgsValidator(), TorchxReturnValidator(), ] diff --git a/torchx/specs/finder.py b/torchx/specs/finder.py index 3dc2a74bd..18c7e2f74 100644 --- a/torchx/specs/finder.py +++ b/torchx/specs/finder.py @@ -17,7 +17,7 @@ from pyre_extensions import none_throws from torchx.specs import AppDef -from torchx.specs.file_linter import get_short_fn_description, validate +from torchx.specs.file_linter import get_fn_docstring, validate from torchx.util import entrypoints from torchx.util.io import read_conf_file @@ -40,14 +40,15 @@ class _Component: Args: name: The name of the component, which usually MODULE_PATH.FN_NAME description: The description of the component, taken from the desrciption - of the function that creates component + of the function that creates component. In case of no docstring, description + will be the same as name fn_name: Function name that creates component fn: Function that creates component validation_errors: Validation errors """ name: str - description: Optional[str] + description: str fn_name: str fn: Callable[..., AppDef] validation_errors: List[str] @@ -150,7 +151,7 @@ def _get_components_from_module( module_path = os.path.abspath(module.__file__) for function_name, function in functions: linter_errors = validate(module_path, function_name) - component_desc = get_short_fn_description(module_path, function_name) + component_desc, _ = get_fn_docstring(function) component_def = _Component( name=self._get_component_name( base_module, module.__name__, function_name @@ -197,7 +198,6 @@ def find(self) -> List[_Component]: validation_errors = self._get_validation_errors( self._filepath, self._function_name ) - fn_desc = get_short_fn_description(self._filepath, self._function_name) file_source = read_conf_file(self._filepath) namespace = globals() @@ -207,6 +207,7 @@ def find(self) -> List[_Component]: f"Function {self._function_name} does not exist in file {self._filepath}" ) app_fn = namespace[self._function_name] + fn_desc, _ = get_fn_docstring(app_fn) return [ _Component( name=f"{self._filepath}:{self._function_name}", diff --git a/torchx/specs/test/api_test.py b/torchx/specs/test/api_test.py index 232d7be31..d21c8cc52 100644 --- a/torchx/specs/test/api_test.py +++ b/torchx/specs/test/api_test.py @@ -5,13 +5,15 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import argparse import sys import unittest from dataclasses import asdict -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union, Any from unittest.mock import MagicMock, patch import torchx.specs.named_resources_aws as named_resources_aws +from pyre_extensions import none_throws from torchx.specs import named_resources from torchx.specs.api import ( _TERMINAL_STATES, @@ -33,6 +35,7 @@ make_app_handle, parse_app_handle, runopts, + _create_args_parser, ) @@ -463,11 +466,6 @@ def _test_complex_fn( app_name: AppDef name containers: List of containers roles_scripts: Dict role_name -> role_script - num_cpus: List of cpus per role - num_gpus: Dict role_name -> gpus used for role - nnodes: Num replicas per role - first_arg: First argument to the user script - roles_args: Roles args """ num_roles = len(roles_scripts) if not num_cpus: @@ -710,3 +708,43 @@ def test_varargs_only_arg_first(self) -> None: _TEST_VAR_ARGS_FIRST, (("fooval", "--foo", "barval", "arg1", "arg2"), "asdf"), ) + + # pyre-ignore[3] + def _get_argument_help( + self, parser: argparse.ArgumentParser, name: str + ) -> Optional[Tuple[str, Any]]: + actions = parser._actions + for action in actions: + if action.dest == name: + return action.help or "", action.default + return None + + def test_argparster_complex_fn_partial(self) -> None: + parser = _create_args_parser(_test_complex_fn) + self.assertTupleEqual( + ("AppDef name", None), + none_throws(self._get_argument_help(parser, "app_name")), + ) + self.assertTupleEqual( + ("List of containers", None), + none_throws(self._get_argument_help(parser, "containers")), + ) + self.assertTupleEqual( + ("Dict role_name -> role_script", None), + none_throws(self._get_argument_help(parser, "roles_scripts")), + ) + self.assertTupleEqual( + (" ", None), none_throws(self._get_argument_help(parser, "num_cpus")) + ) + self.assertTupleEqual( + (" ", None), none_throws(self._get_argument_help(parser, "num_gpus")) + ) + self.assertTupleEqual( + (" ", 4), none_throws(self._get_argument_help(parser, "nnodes")) + ) + self.assertTupleEqual( + (" ", None), none_throws(self._get_argument_help(parser, "first_arg")) + ) + self.assertTupleEqual( + (" ", None), none_throws(self._get_argument_help(parser, "roles_args")) + ) diff --git a/torchx/specs/test/file_linter_test.py b/torchx/specs/test/file_linter_test.py index 08aa23450..ca650f4c1 100644 --- a/torchx/specs/test/file_linter_test.py +++ b/torchx/specs/test/file_linter_test.py @@ -4,18 +4,16 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import ast +import argparse import os import unittest -from typing import Dict, List, Optional, cast +from typing import Dict, List, Optional from unittest.mock import patch -from pyre_extensions import none_throws from torchx.specs.file_linter import ( - get_short_fn_description, - _get_fn_docstring, - parse_fn_docstring, + get_fn_docstring, validate, + TorchXArgumentHelpFormatter, ) @@ -41,35 +39,24 @@ def _test_fn_return_int() -> int: return 0 -def _test_docstring_empty(arg: str) -> "AppDef": - """ """ - pass +def _test_docstring(arg0: str, arg1: int, arg2: Dict[int, str]) -> "AppDef": + """Short Test description + Long funct description -def _test_docstring_func_desc() -> "AppDef": - """ - Function description + Args: + arg0: arg0 desc + arg1: arg1 desc """ pass -def _test_docstring_no_args(arg: str) -> "AppDef": - """ - Test description - """ +def _test_docstring_short() -> "AppDef": + """Short Test description""" pass -def _test_docstring_correct(arg0: str, arg1: int, arg2: Dict[int, str]) -> "AppDef": - """Short Test description - - Long funct description - - Args: - arg0: arg0 desc - arg1: arg1 desc - arg2: arg2 desc - """ +def _test_without_docstring(arg0: str) -> "AppDef": pass @@ -129,10 +116,6 @@ def setUp(self) -> None: source = fp.read() self._file_content = source - def test_validate_docstring_func_desc(self) -> None: - linter_errors = validate(self._path, "_test_docstring_func_desc") - self.assertEqual(0, len(linter_errors)) - def test_syntax_error(self) -> None: content = "!!foo====bar" with patch("torchx.specs.file_linter.read_conf_file") as read_conf_file_mock: @@ -146,10 +129,9 @@ def test_validate_varargs_kwargs_fn(self) -> None: self._path, "_test_invalid_fn_with_varags_and_kwargs", ) - self.assertEqual(2, len(linter_errors)) - self.assertTrue("Missing args: ['id']" in linter_errors[0].description) + self.assertEqual(1, len(linter_errors)) self.assertTrue( - "Arg args missing type annotation", linter_errors[1].description + "Arg args missing type annotation", linter_errors[0].description ) def test_validate_no_return(self) -> None: @@ -172,45 +154,6 @@ def test_validate_incorrect_return(self) -> None: def test_validate_empty_fn(self) -> None: linter_errors = validate(self._path, "_test_empty_fn") - self.assertEqual(1, len(linter_errors)) - linter_error = linter_errors[0] - self.assertEqual("TorchxFunctionValidator", linter_error.name) - - expected_desc = ( - "`_test_empty_fn` is missing a Google Style docstring, please add one. " - "For more information on the docstring format see: " - "https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html" - ) - self.assertEquals(expected_desc, linter_error.description) - # TODO(aivanou): change this test to validate fn from another file to avoid changing lineno - # on each file change - self.assertEqual(24, linter_error.line) - - def test_validate_docstring_empty(self) -> None: - linter_errors = validate(self._path, "_test_docstring_empty") - self.assertEqual(1, len(linter_errors)) - linter_error = linter_errors[0] - self.assertEqual("TorchxFunctionValidator", linter_error.name) - expected_desc = ( - "`_test_docstring_empty` is missing a Google Style docstring, please add one. " - "For more information on the docstring format see: " - "https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html" - ) - self.assertEquals(expected_desc, linter_error.description) - - def test_validate_docstring_no_args(self) -> None: - linter_errors = validate(self._path, "_test_docstring_no_args") - self.assertEqual(1, len(linter_errors)) - linter_error = linter_errors[0] - self.assertEqual("TorchxFunctionValidator", linter_error.name) - expected_desc = ( - "`_test_docstring_no_args` not all function arguments" - " are present in the docstring. Missing args: ['arg']" - ) - self.assertEqual(expected_desc, linter_error.description) - - def test_validate_docstring_correct(self) -> None: - linter_errors = validate(self._path, "_test_docstring_correct") self.assertEqual(0, len(linter_errors)) def test_validate_args_no_type_defs(self) -> None: @@ -229,66 +172,74 @@ def test_validate_args_no_type_defs_complex(self) -> None: self._path, "_test_args_dict_list_complex_types", ) - self.assertEqual(6, len(linter_errors)) - expected_desc = ( - "`_test_args_dict_list_complex_types` not all function arguments" - " are present in the docstring. Missing args: ['arg4']" - ) + self.assertEqual(5, len(linter_errors)) self.assertEqual( - expected_desc, - linter_errors[0].description, - ) - self.assertEqual( - "Arg arg0 missing type annotation", linter_errors[1].description + "Arg arg0 missing type annotation", linter_errors[0].description ) self.assertEqual( - "Arg arg1 missing type annotation", linter_errors[2].description + "Arg arg1 missing type annotation", linter_errors[1].description ) self.assertEqual( - "Dict can only have primitive types", linter_errors[3].description + "Dict can only have primitive types", linter_errors[2].description ) self.assertEqual( - "List can only have primitive types", linter_errors[4].description + "List can only have primitive types", linter_errors[3].description ) self.assertEqual( "`_test_args_dict_list_complex_types` allows only Dict, List as complex types.Argument `arg4` has: Optional", - linter_errors[5].description, + linter_errors[4].description, ) - def _get_function_def(self, function_name: str) -> ast.FunctionDef: - module: ast.Module = ast.parse(self._file_content) - for expr in module.body: - if type(expr) == ast.FunctionDef: - func_def = cast(ast.FunctionDef, expr) - if func_def.name == function_name: - return func_def - raise RuntimeError(f"No function found: {function_name}") - - def test_validate_docstring_full(self) -> None: - func_def = self._get_function_def("_test_docstring_correct") - docstring = none_throws(ast.get_docstring(func_def)) - - func_desc, param_desc = parse_fn_docstring(docstring) - self.assertEqual("Short Test description", func_desc) + def test_validate_docstring(self) -> None: + func_desc, param_desc = get_fn_docstring(_test_docstring) + self.assertEqual("Short Test description ...", func_desc) self.assertEqual("arg0 desc", param_desc["arg0"]) self.assertEqual("arg1 desc", param_desc["arg1"]) - self.assertEqual("arg2 desc", param_desc["arg2"]) + self.assertEqual(" ", param_desc["arg2"]) - def test_get_fn_docstring(self) -> None: - function_desc, _ = none_throws( - _get_fn_docstring(self._file_content, "_test_args_dict_list_complex_types") - ) - self.assertEqual("Test description", function_desc) + def test_validate_docstring_short(self) -> None: + func_desc, param_desc = get_fn_docstring(_test_docstring_short) + self.assertEqual("Short Test description", func_desc) + + def test_validate_docstring_no_docs(self) -> None: + func_desc, param_desc = get_fn_docstring(_test_without_docstring) + expected_fn_desc = """_test_without_docstring TIP: improve this help string by adding a docstring +to your component (see: https://pytorch.org/torchx/latest/component_best_practices.html)""" + self.assertEqual(expected_fn_desc, func_desc) + self.assertEqual(" ", param_desc["arg0"]) - def test_unknown_function(self) -> None: + def test_validate_unknown_function(self) -> None: linter_errors = validate(self._path, "unknown_function") self.assertEqual(1, len(linter_errors)) self.assertEqual( "Function unknown_function not found", linter_errors[0].description ) - def test_get_short_fn_description(self) -> None: - fn_short_desc = none_throws( - get_short_fn_description(self._path, "_test_args_dict_list_complex_types") + def test_formatter(self) -> None: + parser = argparse.ArgumentParser( + prog="test prog", + description="test desc", + ) + parser.add_argument( + "--foo", + type=int, + required=True, + help="foo", + ) + parser.add_argument( + "--bar", + type=int, + help="bar", + default=1, + ) + formatter = TorchXArgumentHelpFormatter(prog="test") + self.assertEqual( + "show this help message and exit", + formatter._get_help_string(parser._actions[0]), + ) + self.assertEqual( + "foo (required)", formatter._get_help_string(parser._actions[1]) + ) + self.assertEqual( + "bar (default: 1)", formatter._get_help_string(parser._actions[2]) ) - self.assertEqual("Test description", fn_short_desc) diff --git a/torchx/specs/test/finder_test.py b/torchx/specs/test/finder_test.py index ea5664afd..7c9c7f519 100644 --- a/torchx/specs/test/finder_test.py +++ b/torchx/specs/test/finder_test.py @@ -40,6 +40,12 @@ def _test_component(name: str, role_name: str = "worker") -> AppDef: ) +def _test_component_without_docstring(name: str, role_name: str = "worker") -> AppDef: + return AppDef( + name, roles=[Role(name=role_name, image="test_image", entrypoint="main.py")] + ) + + # pyre-ignore[2] def invalid_component(name, role_name: str = "worker") -> AppDef: return AppDef( @@ -87,7 +93,7 @@ def test_get_invalid_component(self) -> None: entrypoints_mock.load_group.return_value = test_torchx_group components = _load_components() foobar_component = components["foobar.finder_test.invalid_component"] - self.assertEqual(2, len(foobar_component.validation_errors)) + self.assertEqual(1, len(foobar_component.validation_errors)) def test_get_entrypoints_components(self) -> None: test_torchx_group = {"foobar": sys.modules[__name__]} @@ -151,6 +157,21 @@ def test_find_components(self) -> None: self.assertEqual("_test_component", component.fn_name) self.assertListEqual([], component.validation_errors) + def test_find_components_without_docstring(self) -> None: + components = CustomComponentsFinder( + current_file_path(), "_test_component_without_docstring" + ).find() + self.assertEqual(1, len(components)) + component = components[0] + self.assertEqual( + f"{current_file_path()}:_test_component_without_docstring", component.name + ) + exprected_desc = """_test_component_without_docstring TIP: improve this help string by adding a docstring +to your component (see: https://pytorch.org/torchx/latest/component_best_practices.html)""" + self.assertEqual(exprected_desc, component.description) + self.assertEqual("_test_component_without_docstring", component.fn_name) + self.assertListEqual([], component.validation_errors) + def test_get_component(self) -> None: component = get_component(f"{current_file_path()}:_test_component") self.assertEqual(f"{current_file_path()}:_test_component", component.name)