diff --git a/dvc/repo/experiments/init.py b/dvc/repo/experiments/init.py index 2fce0f79c4..d060737263 100644 --- a/dvc/repo/experiments/init.py +++ b/dvc/repo/experiments/init.py @@ -1,7 +1,7 @@ import os -from typing import TYPE_CHECKING, Dict, List +from typing import TYPE_CHECKING, Dict, Iterable -from funcy import compact, post_processing +from funcy import compact, lremove, post_processing from rich.prompt import Prompt from dvc.types import OptStr @@ -78,7 +78,7 @@ def make_prompt(self, default): @post_processing(dict) -def _prompts(keys: List[str], defaults: Dict[str, OptStr]): +def _prompts(keys: Iterable[str], defaults: Dict[str, OptStr]): for key in keys: if key == "cmd": prompt_cls = RequiredPrompt @@ -91,8 +91,17 @@ def _prompts(keys: List[str], defaults: Dict[str, OptStr]): def init_interactive( - defaults: Dict[str, str], show_heading: bool = False, live: bool = False + defaults: Dict[str, str], + provided: Iterable[str], + show_heading: bool = False, + live: bool = False, ) -> Dict[str, str]: + primary = lremove(provided, ["cmd", "code", "data", "models", "params"]) + secondary = lremove(provided, ["live"] if live else ["metrics", "plots"]) + + if not (primary or secondary): + return {} + message = ( "This command will guide you to set up your first stage in " "[green]dvc.yaml[/green].\n" @@ -100,10 +109,12 @@ def init_interactive( if show_heading: ui.error_write(message, styled=True) - return { - **_prompts(["cmd", "code", "data", "models", "params"], defaults), - **_prompts(["live"] if live else ["metrics", "plots"], defaults), - } + return compact( + { + **_prompts(primary, defaults), + **_prompts(secondary, defaults), + } + ) def _check_stage_exists( @@ -139,10 +150,11 @@ def init( with_live = type == "live" if interactive: - context = init_interactive( + defaults = init_interactive( defaults=defaults or {}, show_heading=not dvcfile.exists(), live=with_live, + provided=overrides.keys(), ) else: if with_live: @@ -153,16 +165,9 @@ def init( defaults.pop("params") else: defaults.pop("live") # suppress live otherwise - context = {**defaults, **overrides} + context: Dict[str, str] = {**defaults, **overrides} assert "cmd" in context - command = context["cmd"] - code = context.get("code") - data = context.get("data") - models = context.get("models") - metrics = context.get("metrics") - plots = context.get("plots") - live = context.get("live") params_kv = [] if context.get("params"): @@ -174,14 +179,15 @@ def init( params_kv = [{path: list(LOADERS[ext](path))}] checkpoint_out = bool(context.get("live")) + models = context.get("models") return repo.stage.add( name=name, - cmd=command, - deps=compact([code, data]), + cmd=context["cmd"], + deps=compact([context.get("code"), context.get("data")]), params=params_kv, - metrics_no_cache=compact([metrics]), - plots_no_cache=compact([plots]), - live=live, + metrics_no_cache=compact([context.get("metrics")]), + plots_no_cache=compact([context.get("plots")]), + live=context.get("live"), force=force, **{"checkpoints" if checkpoint_out else "outs": compact([models])}, )