Skip to content

Commit

Permalink
exp init: only ask for that are not provided in an interactive mode (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
skshetry authored Oct 4, 2021
1 parent a098a50 commit 0eda4bd
Showing 1 changed file with 28 additions and 22 deletions.
50 changes: 28 additions & 22 deletions dvc/repo/experiments/init.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
from typing import TYPE_CHECKING, Dict, List
from typing import TYPE_CHECKING, Dict, Iterable

from funcy import compact, post_processing
from funcy import compact, lremove, post_processing
from rich.prompt import Prompt

from dvc.types import OptStr
Expand Down Expand Up @@ -78,7 +78,7 @@ def make_prompt(self, default):


@post_processing(dict)
def _prompts(keys: List[str], defaults: Dict[str, OptStr]):
def _prompts(keys: Iterable[str], defaults: Dict[str, OptStr]):
for key in keys:
if key == "cmd":
prompt_cls = RequiredPrompt
Expand All @@ -91,19 +91,30 @@ def _prompts(keys: List[str], defaults: Dict[str, OptStr]):


def init_interactive(
defaults: Dict[str, str], show_heading: bool = False, live: bool = False
defaults: Dict[str, str],
provided: Iterable[str],
show_heading: bool = False,
live: bool = False,
) -> Dict[str, str]:
primary = lremove(provided, ["cmd", "code", "data", "models", "params"])
secondary = lremove(provided, ["live"] if live else ["metrics", "plots"])

if not (primary or secondary):
return {}

message = (
"This command will guide you to set up your first stage in "
"[green]dvc.yaml[/green].\n"
)
if show_heading:
ui.error_write(message, styled=True)

return {
**_prompts(["cmd", "code", "data", "models", "params"], defaults),
**_prompts(["live"] if live else ["metrics", "plots"], defaults),
}
return compact(
{
**_prompts(primary, defaults),
**_prompts(secondary, defaults),
}
)


def _check_stage_exists(
Expand Down Expand Up @@ -139,10 +150,11 @@ def init(

with_live = type == "live"
if interactive:
context = init_interactive(
defaults = init_interactive(
defaults=defaults or {},
show_heading=not dvcfile.exists(),
live=with_live,
provided=overrides.keys(),
)
else:
if with_live:
Expand All @@ -153,16 +165,9 @@ def init(
defaults.pop("params")
else:
defaults.pop("live") # suppress live otherwise
context = {**defaults, **overrides}

context: Dict[str, str] = {**defaults, **overrides}
assert "cmd" in context
command = context["cmd"]
code = context.get("code")
data = context.get("data")
models = context.get("models")
metrics = context.get("metrics")
plots = context.get("plots")
live = context.get("live")

params_kv = []
if context.get("params"):
Expand All @@ -174,14 +179,15 @@ def init(
params_kv = [{path: list(LOADERS[ext](path))}]

checkpoint_out = bool(context.get("live"))
models = context.get("models")
return repo.stage.add(
name=name,
cmd=command,
deps=compact([code, data]),
cmd=context["cmd"],
deps=compact([context.get("code"), context.get("data")]),
params=params_kv,
metrics_no_cache=compact([metrics]),
plots_no_cache=compact([plots]),
live=live,
metrics_no_cache=compact([context.get("metrics")]),
plots_no_cache=compact([context.get("plots")]),
live=context.get("live"),
force=force,
**{"checkpoints" if checkpoint_out else "outs": compact([models])},
)

0 comments on commit 0eda4bd

Please sign in to comment.