diff --git a/tests/integration_tests/function_wrapper/test_single_function.py b/tests/integration_tests/function_wrapper/test_single_function.py index 8eecf996..b81982e1 100644 --- a/tests/integration_tests/function_wrapper/test_single_function.py +++ b/tests/integration_tests/function_wrapper/test_single_function.py @@ -115,3 +115,13 @@ def example_func_with_outs_dict(cfg: NodeConfig): def test_example_func_with_outs_dict(proj_path): example_func_with_outs_dict(run=True) assert pathlib.Path("data.txt").read_text() == "Hello World" + + +@nodify() +def function_no_args(cfg: NodeConfig): + # not sure why this would ever be used, but it is possible + pass + + +def test_function_no_args(proj_path): + function_no_args(run=True) diff --git a/zntrack/core/functions/decorator.py b/zntrack/core/functions/decorator.py index 1f97f1c3..11fd078a 100644 --- a/zntrack/core/functions/decorator.py +++ b/zntrack/core/functions/decorator.py @@ -24,6 +24,11 @@ log = logging.getLogger(__name__) +UnionListOrStrAndPath = typing.Union[typing.List[StrOrPath], StrOrPath] +UnionDictListOfStrPath = typing.Union[ + typing.List[StrOrPath], typing.Dict[str, StrOrPath], StrOrPath +] + @dataclasses.dataclass class NodeConfig: @@ -36,15 +41,15 @@ class NodeConfig: """ params: typing.Union[dot4dict.dotdict, dict] = dataclasses.field(default_factory=dict) - outs: typing.Union[StrOrPath, typing.List[StrOrPath]] = None - outs_no_cache: typing.Union[StrOrPath, typing.List[StrOrPath]] = None - outs_persist: typing.Union[StrOrPath, typing.List[StrOrPath]] = None - outs_persist_no_cache: typing.Union[StrOrPath, typing.List[StrOrPath]] = None - metrics: typing.Union[StrOrPath, typing.List[StrOrPath]] = None - metrics_no_cache: typing.Union[StrOrPath, typing.List[StrOrPath]] = None - deps: typing.Union[StrOrPath, typing.List[StrOrPath]] = None - plots: typing.Union[StrOrPath, typing.List[StrOrPath]] = None - plots_no_cache: typing.Union[StrOrPath, typing.List[StrOrPath]] = None + outs: UnionDictListOfStrPath = None + outs_no_cache: UnionDictListOfStrPath = None + outs_persist: UnionDictListOfStrPath = None + outs_persist_no_cache: UnionDictListOfStrPath = None + metrics: UnionDictListOfStrPath = None + metrics_no_cache: UnionDictListOfStrPath = None + deps: UnionDictListOfStrPath = None + plots: UnionDictListOfStrPath = None + plots_no_cache: UnionDictListOfStrPath = None def __post_init__(self): for option_name in self.__dataclass_fields__: @@ -118,7 +123,6 @@ def write_dvc_command(self, node_name: str) -> list: AnyOrNodeConfig = typing.Union[typing.Any, NodeConfig] -UnionListOrStrAndPath = typing.Union[typing.List[StrOrPath], StrOrPath] def execute_function_call(func): @@ -180,15 +184,15 @@ def save_node_config_to_files(cfg: NodeConfig, node_name: str): def nodify( *, params: dict = None, - outs: UnionListOrStrAndPath = None, - outs_no_cache: UnionListOrStrAndPath = None, - outs_persist: UnionListOrStrAndPath = None, - outs_persist_no_cache: UnionListOrStrAndPath = None, - metrics: UnionListOrStrAndPath = None, - metrics_no_cache: UnionListOrStrAndPath = None, - deps: UnionListOrStrAndPath = None, - plots: UnionListOrStrAndPath = None, - plots_no_cache: UnionListOrStrAndPath = None, + outs: UnionDictListOfStrPath = None, + outs_no_cache: UnionDictListOfStrPath = None, + outs_persist: UnionDictListOfStrPath = None, + outs_persist_no_cache: UnionDictListOfStrPath = None, + metrics: UnionDictListOfStrPath = None, + metrics_no_cache: UnionDictListOfStrPath = None, + deps: UnionDictListOfStrPath = None, + plots: UnionDictListOfStrPath = None, + plots_no_cache: UnionDictListOfStrPath = None, ): """Main wrapper Function to convert a function into a DVC Stage