diff --git a/harness/determined/cli/cli.py b/harness/determined/cli/cli.py index 25f21628d9a..bab6d13ef25 100644 --- a/harness/determined/cli/cli.py +++ b/harness/determined/cli/cli.py @@ -46,7 +46,7 @@ version, workspace, ) -from determined.common import api, util, yaml +from determined.common import api, util from determined.common.api import bindings, certs @@ -54,6 +54,9 @@ def _render_search_summary(resp: bindings.v1PreviewHPSearchResponse) -> str: output = [ termcolor.colored("Using search configuration:", "green"), ] + + # For mypy + assert resp.summary and resp.summary.config and resp.summary.runs config_str = render.format_object_as_yaml(resp.summary.config) output.append(config_str) headers = ["Runs", "Training Time"] @@ -68,18 +71,20 @@ def _render_search_summary(resp: bindings.v1PreviewHPSearchResponse) -> str: output.append(tabulate.tabulate(run_summaries, headers, tablefmt="presto")) return "\n".join(output) + def preview_search(args: argparse.Namespace) -> None: sess = cli.setup_session(args) experiment_config = util.safe_load_yaml_with_exceptions(args.config_file) args.config_file.close() if "searcher" not in experiment_config: - raise errors.CliError(f"Missing 'searcher' config section in experiment config.") + raise errors.CliError("Missing 'searcher' config section in experiment config.") resp = bindings.post_PreviewHPSearch( - session=sess, body=bindings.v1PreviewHPSearchRequest( + session=sess, + body=bindings.v1PreviewHPSearchRequest( config=experiment_config, - ) + ), ) print(_render_search_summary(resp=resp)) diff --git a/harness/tests/cli/test_cli.py b/harness/tests/cli/test_cli.py index 465cdf0e8b2..774b509364e 100644 --- a/harness/tests/cli/test_cli.py +++ b/harness/tests/cli/test_cli.py @@ -24,6 +24,7 @@ MINIMAL_CONFIG = '{"description": "test"}' MASTER_HOST = "http://localhost:8080" + def test_parse_config() -> None: assert ntsc.parse_config(None, [], [], []) == {} @@ -582,7 +583,7 @@ def test_preview_search(tmp_path: pathlib.Path) -> None: "name": "random", "metric": "loss", "max_trials": 10, - } + }, } conf_path = tmp_path / "config.yaml" with conf_path.open("w") as tmp_file: @@ -593,7 +594,7 @@ def test_preview_search(tmp_path: pathlib.Path) -> None: config=searcher_config, runs={ "1": bindings.v1SearchUnit(undefined=True), - } + }, ) ) with util.standard_cli_rsps() as rsps: @@ -609,6 +610,4 @@ def test_preview_search(tmp_path: pathlib.Path) -> None: ], json=mock_resp.to_json(), ) - cli.main( - ["preview-search", str(conf_path)] - ) \ No newline at end of file + cli.main(["preview-search", str(conf_path)]) diff --git a/master/internal/experiment.go b/master/internal/experiment.go index 8eef016741f..5269b0677b2 100644 --- a/master/internal/experiment.go +++ b/master/internal/experiment.go @@ -50,7 +50,7 @@ type ( SearcherState json.RawMessage `json:"searcher_state"` RunSearcherState map[int32]experiment.RunSearcherState `json:"run_searcher_state"` } - + internalExperiment struct { mu sync.Mutex diff --git a/master/internal/experiment/experiment_iface.go b/master/internal/experiment/experiment_iface.go index d7ad27de8f6..b1cb45b8265 100644 --- a/master/internal/experiment/experiment_iface.go +++ b/master/internal/experiment/experiment_iface.go @@ -49,7 +49,7 @@ type ( type Experiment interface { RunReportProgress(runID int32, msg RunReportProgress) error RunReportValidation(runID int32, metrics map[string]interface{}) error - //TrialGetSearcherState(runID int32) (RunSearcherState, error) + // TrialGetSearcherState(runID int32) (RunSearcherState, error) UserInitiatedEarlyRunExit(msg UserInitiatedEarlyRunExit) error PatchRunState(msg PatchRunState) error SetGroupMaxSlots(msg sproto.SetGroupMaxSlots) diff --git a/master/internal/restore.go b/master/internal/restore.go index 73dd1de86c6..3db78628f8d 100644 --- a/master/internal/restore.go +++ b/master/internal/restore.go @@ -5,6 +5,10 @@ import ( "database/sql" "encoding/json" "fmt" + + "github.com/pkg/errors" + log "github.com/sirupsen/logrus" + "github.com/determined-ai/determined/master/internal/db" "github.com/determined-ai/determined/master/internal/experiment" "github.com/determined-ai/determined/master/internal/rm" @@ -16,8 +20,6 @@ import ( "github.com/determined-ai/determined/master/pkg/ptrs" "github.com/determined-ai/determined/master/pkg/schemas" "github.com/determined-ai/determined/master/pkg/searcher" - "github.com/pkg/errors" - log "github.com/sirupsen/logrus" ) // The current experiment snapshot version. Once this is incremented, older versions should be @@ -448,7 +450,6 @@ func shimExperimentSnapshotV4(snapshot []byte) ([]byte, error) { // - `trial_searcher_state.Stop` -> dropped // - `trial_searcher_state.Closed` -> `run_searcher_state.Closed` func shimExperimentSnapshotV5(snapshot []byte) ([]byte, error) { - type v4SearcherState struct { TrialsRequested int `json:"trials_requested"` TrialsCreated map[model.RequestID]bool `json:"trials_created"` diff --git a/master/pkg/searcher/actions.go b/master/pkg/searcher/actions.go index f94bd979ba1..2e409d386d3 100644 --- a/master/pkg/searcher/actions.go +++ b/master/pkg/searcher/actions.go @@ -2,6 +2,7 @@ package searcher import ( "fmt" + "github.com/determined-ai/determined/master/pkg/nprand" ) diff --git a/master/pkg/searcher/adaptive_asha_test.go b/master/pkg/searcher/adaptive_asha_test.go index fcc5f5e4abc..f8ff8a85f01 100644 --- a/master/pkg/searcher/adaptive_asha_test.go +++ b/master/pkg/searcher/adaptive_asha_test.go @@ -2,9 +2,10 @@ package searcher import ( - "github.com/stretchr/testify/require" "testing" + "github.com/stretchr/testify/require" + "gotest.tools/assert" "github.com/determined-ai/determined/master/pkg/ptrs" diff --git a/master/pkg/searcher/asha_stopping.go b/master/pkg/searcher/asha_stopping.go index 26f28a7ffa5..cf46ac83bd4 100644 --- a/master/pkg/searcher/asha_stopping.go +++ b/master/pkg/searcher/asha_stopping.go @@ -3,10 +3,11 @@ package searcher import ( "encoding/json" "fmt" - "github.com/determined-ai/determined/master/pkg/ptrs" "math" "sort" + "github.com/determined-ai/determined/master/pkg/ptrs" + "github.com/determined-ai/determined/master/pkg/mathx" "github.com/determined-ai/determined/master/pkg/model" "github.com/determined-ai/determined/master/pkg/schemas/expconf" @@ -18,8 +19,6 @@ import ( // to stop or continue training the trial based on the ranking of the validation metric // compared to other trials in a particular rung. Once a trial has been stopped, it will not // be resumed later; this is why the algorithm does not require fault tolerance. -// The searcher state and config match that of AsyncHalvingSearch but we will only run -// the stopping based version if StopOnce is true. type asyncHalvingStoppingSearch struct { expconf.AsyncHalvingConfig SmallerIsBetter bool @@ -219,7 +218,6 @@ func (s *asyncHalvingStoppingSearch) stopRun( // If this is the top rung, close the run and exit. if r == s.NumRungs()-1 { actions = append(actions, NewStop(runID)) - //s.ClosedTrials[requestID] = true return actions } diff --git a/master/pkg/searcher/asha_stopping_test.go b/master/pkg/searcher/asha_stopping_test.go index dfabf688a45..a0441ba6ea5 100644 --- a/master/pkg/searcher/asha_stopping_test.go +++ b/master/pkg/searcher/asha_stopping_test.go @@ -2,12 +2,14 @@ package searcher import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/determined-ai/determined/master/pkg/model" "github.com/determined-ai/determined/master/pkg/ptrs" "github.com/determined-ai/determined/master/pkg/schemas" "github.com/determined-ai/determined/master/pkg/schemas/expconf" - "github.com/stretchr/testify/require" - "testing" ) func TestMakeRungs(t *testing.T) { diff --git a/master/pkg/searcher/random_test.go b/master/pkg/searcher/random_test.go index e0cb63f7dec..dfe95511bcd 100644 --- a/master/pkg/searcher/random_test.go +++ b/master/pkg/searcher/random_test.go @@ -2,9 +2,10 @@ package searcher import ( - "github.com/stretchr/testify/require" "testing" + "github.com/stretchr/testify/require" + "github.com/determined-ai/determined/master/pkg/ptrs" "github.com/determined-ai/determined/master/pkg/schemas" "github.com/determined-ai/determined/master/pkg/schemas/expconf" diff --git a/master/pkg/searcher/simulate.go b/master/pkg/searcher/simulate.go index 624481fdb17..4c9ea899d47 100644 --- a/master/pkg/searcher/simulate.go +++ b/master/pkg/searcher/simulate.go @@ -2,12 +2,14 @@ package searcher import ( "fmt" + "math/rand" + + "github.com/pkg/errors" + "github.com/determined-ai/determined/master/pkg/mathx" "github.com/determined-ai/determined/master/pkg/protoutils" "github.com/determined-ai/determined/master/pkg/schemas/expconf" "github.com/determined-ai/determined/proto/pkg/experimentv1" - "github.com/pkg/errors" - "math/rand" ) // ValidationFunction calculates the validation metric for the validation step. diff --git a/master/pkg/searcher/simulate_test.go b/master/pkg/searcher/simulate_test.go index 7c0fcf2c0e8..97086fd7a45 100644 --- a/master/pkg/searcher/simulate_test.go +++ b/master/pkg/searcher/simulate_test.go @@ -2,11 +2,13 @@ package searcher import ( "fmt" + "testing" + + "github.com/stretchr/testify/require" + "github.com/determined-ai/determined/master/pkg/ptrs" "github.com/determined-ai/determined/master/pkg/schemas" "github.com/determined-ai/determined/master/pkg/schemas/expconf" - "github.com/stretchr/testify/require" - "testing" ) func TestSimulate(t *testing.T) { diff --git a/master/pkg/searcher/tournament.go b/master/pkg/searcher/tournament.go index 6149ed04328..658f66fcbf8 100644 --- a/master/pkg/searcher/tournament.go +++ b/master/pkg/searcher/tournament.go @@ -2,6 +2,7 @@ package searcher import ( "encoding/json" + "github.com/pkg/errors" "github.com/determined-ai/determined/master/pkg/model"