Skip to content

Commit

Permalink
fmt / cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
azhou-determined committed Oct 7, 2024
1 parent 163d1d4 commit f987818
Show file tree
Hide file tree
Showing 13 changed files with 39 additions and 26 deletions.
13 changes: 9 additions & 4 deletions harness/determined/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,17 @@
version,
workspace,
)
from determined.common import api, util, yaml
from determined.common import api, util
from determined.common.api import bindings, certs


def _render_search_summary(resp: bindings.v1PreviewHPSearchResponse) -> str:
output = [
termcolor.colored("Using search configuration:", "green"),
]

# For mypy
assert resp.summary and resp.summary.config and resp.summary.runs
config_str = render.format_object_as_yaml(resp.summary.config)
output.append(config_str)
headers = ["Runs", "Training Time"]
Expand All @@ -68,18 +71,20 @@ def _render_search_summary(resp: bindings.v1PreviewHPSearchResponse) -> str:
output.append(tabulate.tabulate(run_summaries, headers, tablefmt="presto"))
return "\n".join(output)


def preview_search(args: argparse.Namespace) -> None:
sess = cli.setup_session(args)
experiment_config = util.safe_load_yaml_with_exceptions(args.config_file)
args.config_file.close()

if "searcher" not in experiment_config:
raise errors.CliError(f"Missing 'searcher' config section in experiment config.")
raise errors.CliError("Missing 'searcher' config section in experiment config.")

resp = bindings.post_PreviewHPSearch(
session=sess, body=bindings.v1PreviewHPSearchRequest(
session=sess,
body=bindings.v1PreviewHPSearchRequest(
config=experiment_config,
)
),
)
print(_render_search_summary(resp=resp))

Expand Down
9 changes: 4 additions & 5 deletions harness/tests/cli/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
MINIMAL_CONFIG = '{"description": "test"}'
MASTER_HOST = "http://localhost:8080"


def test_parse_config() -> None:
assert ntsc.parse_config(None, [], [], []) == {}

Expand Down Expand Up @@ -582,7 +583,7 @@ def test_preview_search(tmp_path: pathlib.Path) -> None:
"name": "random",
"metric": "loss",
"max_trials": 10,
}
},
}
conf_path = tmp_path / "config.yaml"
with conf_path.open("w") as tmp_file:
Expand All @@ -593,7 +594,7 @@ def test_preview_search(tmp_path: pathlib.Path) -> None:
config=searcher_config,
runs={
"1": bindings.v1SearchUnit(undefined=True),
}
},
)
)
with util.standard_cli_rsps() as rsps:
Expand All @@ -609,6 +610,4 @@ def test_preview_search(tmp_path: pathlib.Path) -> None:
],
json=mock_resp.to_json(),
)
cli.main(
["preview-search", str(conf_path)]
)
cli.main(["preview-search", str(conf_path)])
2 changes: 1 addition & 1 deletion master/internal/experiment.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ type (
SearcherState json.RawMessage `json:"searcher_state"`
RunSearcherState map[int32]experiment.RunSearcherState `json:"run_searcher_state"`
}

internalExperiment struct {
mu sync.Mutex

Expand Down
2 changes: 1 addition & 1 deletion master/internal/experiment/experiment_iface.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ type (
type Experiment interface {
RunReportProgress(runID int32, msg RunReportProgress) error
RunReportValidation(runID int32, metrics map[string]interface{}) error
//TrialGetSearcherState(runID int32) (RunSearcherState, error)
// TrialGetSearcherState(runID int32) (RunSearcherState, error)
UserInitiatedEarlyRunExit(msg UserInitiatedEarlyRunExit) error
PatchRunState(msg PatchRunState) error
SetGroupMaxSlots(msg sproto.SetGroupMaxSlots)
Expand Down
7 changes: 4 additions & 3 deletions master/internal/restore.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ import (
"database/sql"
"encoding/json"
"fmt"

"github.com/pkg/errors"
log "github.com/sirupsen/logrus"

"github.com/determined-ai/determined/master/internal/db"
"github.com/determined-ai/determined/master/internal/experiment"
"github.com/determined-ai/determined/master/internal/rm"
Expand All @@ -16,8 +20,6 @@ import (
"github.com/determined-ai/determined/master/pkg/ptrs"
"github.com/determined-ai/determined/master/pkg/schemas"
"github.com/determined-ai/determined/master/pkg/searcher"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
)

// The current experiment snapshot version. Once this is incremented, older versions should be
Expand Down Expand Up @@ -448,7 +450,6 @@ func shimExperimentSnapshotV4(snapshot []byte) ([]byte, error) {
// - `trial_searcher_state.Stop` -> dropped
// - `trial_searcher_state.Closed` -> `run_searcher_state.Closed`
func shimExperimentSnapshotV5(snapshot []byte) ([]byte, error) {

type v4SearcherState struct {
TrialsRequested int `json:"trials_requested"`
TrialsCreated map[model.RequestID]bool `json:"trials_created"`
Expand Down
1 change: 1 addition & 0 deletions master/pkg/searcher/actions.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package searcher

import (
"fmt"

"github.com/determined-ai/determined/master/pkg/nprand"
)

Expand Down
3 changes: 2 additions & 1 deletion master/pkg/searcher/adaptive_asha_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
package searcher

import (
"github.com/stretchr/testify/require"
"testing"

"github.com/stretchr/testify/require"

"gotest.tools/assert"

"github.com/determined-ai/determined/master/pkg/ptrs"
Expand Down
6 changes: 2 additions & 4 deletions master/pkg/searcher/asha_stopping.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@ package searcher
import (
"encoding/json"
"fmt"
"github.com/determined-ai/determined/master/pkg/ptrs"
"math"
"sort"

"github.com/determined-ai/determined/master/pkg/ptrs"

"github.com/determined-ai/determined/master/pkg/mathx"
"github.com/determined-ai/determined/master/pkg/model"
"github.com/determined-ai/determined/master/pkg/schemas/expconf"
Expand All @@ -18,8 +19,6 @@ import (
// to stop or continue training the trial based on the ranking of the validation metric
// compared to other trials in a particular rung. Once a trial has been stopped, it will not
// be resumed later; this is why the algorithm does not require fault tolerance.
// The searcher state and config match that of AsyncHalvingSearch but we will only run
// the stopping based version if StopOnce is true.
type asyncHalvingStoppingSearch struct {
expconf.AsyncHalvingConfig
SmallerIsBetter bool
Expand Down Expand Up @@ -219,7 +218,6 @@ func (s *asyncHalvingStoppingSearch) stopRun(
// If this is the top rung, close the run and exit.
if r == s.NumRungs()-1 {
actions = append(actions, NewStop(runID))
//s.ClosedTrials[requestID] = true
return actions
}

Expand Down
6 changes: 4 additions & 2 deletions master/pkg/searcher/asha_stopping_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
package searcher

import (
"testing"

"github.com/stretchr/testify/require"

"github.com/determined-ai/determined/master/pkg/model"
"github.com/determined-ai/determined/master/pkg/ptrs"
"github.com/determined-ai/determined/master/pkg/schemas"
"github.com/determined-ai/determined/master/pkg/schemas/expconf"
"github.com/stretchr/testify/require"
"testing"
)

func TestMakeRungs(t *testing.T) {
Expand Down
3 changes: 2 additions & 1 deletion master/pkg/searcher/random_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
package searcher

import (
"github.com/stretchr/testify/require"
"testing"

"github.com/stretchr/testify/require"

"github.com/determined-ai/determined/master/pkg/ptrs"
"github.com/determined-ai/determined/master/pkg/schemas"
"github.com/determined-ai/determined/master/pkg/schemas/expconf"
Expand Down
6 changes: 4 additions & 2 deletions master/pkg/searcher/simulate.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@ package searcher

import (
"fmt"
"math/rand"

"github.com/pkg/errors"

"github.com/determined-ai/determined/master/pkg/mathx"
"github.com/determined-ai/determined/master/pkg/protoutils"
"github.com/determined-ai/determined/master/pkg/schemas/expconf"
"github.com/determined-ai/determined/proto/pkg/experimentv1"
"github.com/pkg/errors"
"math/rand"
)

// ValidationFunction calculates the validation metric for the validation step.
Expand Down
6 changes: 4 additions & 2 deletions master/pkg/searcher/simulate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@ package searcher

import (
"fmt"
"testing"

"github.com/stretchr/testify/require"

"github.com/determined-ai/determined/master/pkg/ptrs"
"github.com/determined-ai/determined/master/pkg/schemas"
"github.com/determined-ai/determined/master/pkg/schemas/expconf"
"github.com/stretchr/testify/require"
"testing"
)

func TestSimulate(t *testing.T) {
Expand Down
1 change: 1 addition & 0 deletions master/pkg/searcher/tournament.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package searcher

import (
"encoding/json"

"github.com/pkg/errors"

"github.com/determined-ai/determined/master/pkg/model"
Expand Down

0 comments on commit f987818

Please sign in to comment.