Skip to content

Commit

Permalink
Use new suggest API
Browse files Browse the repository at this point in the history
  • Loading branch information
c-bata committed Mar 31, 2020
1 parent 07b30fa commit 84e1c1d
Show file tree
Hide file tree
Showing 12 changed files with 64 additions and 58 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ import (
// Define an objective function we want to minimize.
func objective(trial goptuna.Trial) (float64, error) {
// Define a search space of the input values.
x1, _ := trial.SuggestUniform("x1", -10, 10)
x2, _ := trial.SuggestUniform("x2", -10, 10)
x1, _ := trial.SuggestFloat("x1", -10, 10)
x2, _ := trial.SuggestFloat("x2", -10, 10)

// Here is a two-dimensional quadratic function.
// F(x1, x2) = (x1 - 2)^2 + (x2 + 5)^2
Expand Down
4 changes: 2 additions & 2 deletions _examples/cmaes/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ import (
)

func objective(trial goptuna.Trial) (float64, error) {
x1, _ := trial.SuggestUniform("x1", -10, 10)
x2, _ := trial.SuggestUniform("x2", -10, 10)
x1, _ := trial.SuggestFloat("x1", -10, 10)
x2, _ := trial.SuggestFloat("x2", -10, 10)
return math.Pow(x1-2, 2) + math.Pow(x2+5, 2), nil
}

Expand Down
4 changes: 2 additions & 2 deletions _examples/concurrency/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ import (
)

func objective(trial goptuna.Trial) (float64, error) {
x1, _ := trial.SuggestUniform("x1", -10, 10)
x2, _ := trial.SuggestUniform("x2", -10, 10)
x1, _ := trial.SuggestFloat("x1", -10, 10)
x2, _ := trial.SuggestFloat("x2", -10, 10)
return math.Pow(x1-2, 2) + math.Pow(x2+5, 2), nil
}

Expand Down
4 changes: 2 additions & 2 deletions _examples/enqueue_trial/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ import (
)

func objective(trial goptuna.Trial) (float64, error) {
x1, _ := trial.SuggestUniform("x1", -10, 10)
x2, _ := trial.SuggestUniform("x2", -10, 10)
x1, _ := trial.SuggestFloat("x1", -10, 10)
x2, _ := trial.SuggestFloat("x2", -10, 10)
return math.Pow(x1-2, 2) + math.Pow(x2+5, 2), nil
}

Expand Down
4 changes: 2 additions & 2 deletions _examples/signalhandling/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ import (
func objective(trial goptuna.Trial) (float64, error) {
ctx := trial.GetContext()

x1, _ := trial.SuggestUniform("x1", -10, 10)
x2, _ := trial.SuggestUniform("x2", -10, 10)
x1, _ := trial.SuggestFloat("x1", -10, 10)
x2, _ := trial.SuggestFloat("x2", -10, 10)

cmd := exec.CommandContext(ctx, "sleep", "1")
err := cmd.Run()
Expand Down
4 changes: 2 additions & 2 deletions _examples/simple_rdb/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ import (
)

func objective(trial goptuna.Trial) (float64, error) {
x1, _ := trial.SuggestUniform("x1", -10, 10)
x2, _ := trial.SuggestUniform("x2", -10, 10)
x1, _ := trial.SuggestFloat("x1", -10, 10)
x2, _ := trial.SuggestFloat("x2", -10, 10)
return math.Pow(x1-2, 2) + math.Pow(x2+5, 2), nil
}

Expand Down
4 changes: 2 additions & 2 deletions _examples/simple_tpe/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ import (
)

func objective(trial goptuna.Trial) (float64, error) {
x1, _ := trial.SuggestUniform("x1", -10, 10)
x2, _ := trial.SuggestUniform("x2", -10, 10)
x1, _ := trial.SuggestFloat("x1", -10, 10)
x2, _ := trial.SuggestFloat("x2", -10, 10)
return math.Pow(x1-2, 2) + math.Pow(x2+5, 2), nil
}

Expand Down
4 changes: 2 additions & 2 deletions _examples/trialnotify/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ import (
)

func objective(trial goptuna.Trial) (float64, error) {
x1, _ := trial.SuggestUniform("x1", -10, 10)
x2, _ := trial.SuggestUniform("x2", -10, 10)
x1, _ := trial.SuggestFloat("x1", -10, 10)
x2, _ := trial.SuggestFloat("x2", -10, 10)
return math.Pow(x1-2, 2) + math.Pow(x2+5, 2), nil
}

Expand Down
34 changes: 17 additions & 17 deletions sampler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,22 +208,22 @@ func TestRelativeSampler(t *testing.T) {

// First trial cannot trigger relative sampler.
err = study.Optimize(func(trial goptuna.Trial) (f float64, e error) {
_, _ = trial.SuggestUniform("uniform", -10, 10)
_, _ = trial.SuggestLogUniform("log_uniform", 1e-10, 1e10)
_, _ = trial.SuggestFloat("uniform", -10, 10)
_, _ = trial.SuggestLogFloat("log_uniform", 1e-10, 1e10)
_, _ = trial.SuggestInt("int", -10, 10)
_, _ = trial.SuggestDiscreteUniform("discrete", -10, 10, 0.5)
_, _ = trial.SuggestDiscreteFloat("discrete", -10, 10, 0.5)
_, _ = trial.SuggestCategorical("categorical", []string{"choice1", "choice2", "choice3"})
return 0.0, nil
}, 1)

// Second trial call relative sampler.
err = study.Optimize(func(trial goptuna.Trial) (f float64, e error) {
uniformParam, _ := trial.SuggestUniform("uniform", -10, 10)
uniformParam, _ := trial.SuggestFloat("uniform", -10, 10)
if uniformParam != 3 {
t.Errorf("should be 3, but got %f", uniformParam)
}

logUniformParam, _ := trial.SuggestLogUniform("log_uniform", 1e-10, 1e10)
logUniformParam, _ := trial.SuggestLogFloat("log_uniform", 1e-10, 1e10)
if logUniformParam != 100 {
t.Errorf("should be 100, but got %f", logUniformParam)
}
Expand All @@ -233,7 +233,7 @@ func TestRelativeSampler(t *testing.T) {
t.Errorf("should be 7, but got %d", intParam)
}

discreteParam, _ := trial.SuggestDiscreteUniform("discrete", -10, 10, 0.5)
discreteParam, _ := trial.SuggestDiscreteFloat("discrete", -10, 10, 0.5)
if discreteParam != 5.5 {
t.Errorf("should be 5.5, but got %f", discreteParam)
}
Expand Down Expand Up @@ -274,8 +274,8 @@ func TestRelativeSampler_UnsupportedSearchSpace(t *testing.T) {

// First trial cannot trigger relative sampler.
err = study.Optimize(func(trial goptuna.Trial) (f float64, e error) {
_, _ = trial.SuggestUniform("x1", -10, 10)
_, _ = trial.SuggestLogUniform("x2", 1e-10, 1e10)
_, _ = trial.SuggestFloat("x1", -10, 10)
_, _ = trial.SuggestLogFloat("x2", 1e-10, 1e10)
return 0.0, nil
}, 1)
if err != nil {
Expand All @@ -285,11 +285,11 @@ func TestRelativeSampler_UnsupportedSearchSpace(t *testing.T) {

// Second trial. RelativeSampler return ErrUnsupportedSearchSpace.
err = study.Optimize(func(trial goptuna.Trial) (f float64, e error) {
_, e = trial.SuggestUniform("x1", -10, 10)
_, e = trial.SuggestFloat("x1", -10, 10)
if e != nil {
t.Errorf("err should be nil, but got %s", e)
}
_, e = trial.SuggestLogUniform("x2", 1e-10, 1e10)
_, e = trial.SuggestLogFloat("x2", 1e-10, 1e10)
if e != nil {
t.Errorf("err should be nil, but got %s", e)
}
Expand Down Expand Up @@ -331,7 +331,7 @@ func TestIntersectionSearchSpace(t *testing.T) {

if err = study.Optimize(func(trial goptuna.Trial) (v float64, e error) {
x, _ := trial.SuggestInt("x", 0, 10)
y, _ := trial.SuggestUniform("y", -3, 3)
y, _ := trial.SuggestFloat("y", -3, 3)
return float64(x) + y, nil
}, 1); err != nil {
panic(err)
Expand Down Expand Up @@ -361,15 +361,15 @@ func TestIntersectionSearchSpace(t *testing.T) {
// First Trial
if err = study.Optimize(func(trial goptuna.Trial) (v float64, e error) {
x, _ := trial.SuggestInt("x", 0, 10)
y, _ := trial.SuggestUniform("y", -3, 3)
y, _ := trial.SuggestFloat("y", -3, 3)
return float64(x) + y, nil
}, 1); err != nil {
panic(err)
}

// Second Trial
if err = study.Optimize(func(trial goptuna.Trial) (v float64, e error) {
y, _ := trial.SuggestUniform("y", -3, 3)
y, _ := trial.SuggestFloat("y", -3, 3)
return y, nil
}, 1); err != nil {
panic(err)
Expand All @@ -395,28 +395,28 @@ func TestIntersectionSearchSpace(t *testing.T) {
// First Trial
if err = study.Optimize(func(trial goptuna.Trial) (v float64, e error) {
x, _ := trial.SuggestInt("x", 0, 10)
y, _ := trial.SuggestUniform("y", -3, 3)
y, _ := trial.SuggestFloat("y", -3, 3)
return float64(x) + y, nil
}, 1); err != nil {
panic(err)
}

// Second Trial
if err = study.Optimize(func(trial goptuna.Trial) (v float64, e error) {
y, _ := trial.SuggestUniform("y", -3, 3)
y, _ := trial.SuggestFloat("y", -3, 3)
return y, nil
}, 1); err != nil {
panic(err)
}

// Failed trial (ignore error)
_ = study.Optimize(func(trial goptuna.Trial) (v float64, e error) {
_, _ = trial.SuggestUniform("y", -3, 3)
_, _ = trial.SuggestFloat("y", -3, 3)
return 0.0, errors.New("something error")
}, 1)
// Pruned trial
if err = study.Optimize(func(trial goptuna.Trial) (v float64, e error) {
_, _ = trial.SuggestUniform("y", -3, 3)
_, _ = trial.SuggestFloat("y", -3, 3)
return 0.0, goptuna.ErrTrialPruned
}, 1); err != nil {
panic(err)
Expand Down
12 changes: 6 additions & 6 deletions study_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ func ExampleStudy_Optimize() {
)

objective := func(trial goptuna.Trial) (float64, error) {
x1, _ := trial.SuggestUniform("x1", -10, 10)
x2, _ := trial.SuggestUniform("x2", -10, 10)
x1, _ := trial.SuggestFloat("x1", -10, 10)
x2, _ := trial.SuggestFloat("x2", -10, 10)
return math.Pow(x1-2, 2) + math.Pow(x2+5, 2), nil
}

Expand Down Expand Up @@ -76,8 +76,8 @@ func ExampleStudy_EnqueueTrial() {
)

objective := func(trial goptuna.Trial) (float64, error) {
x1, _ := trial.SuggestUniform("x1", -10, 10)
x2, _ := trial.SuggestUniform("x2", -10, 10)
x1, _ := trial.SuggestFloat("x1", -10, 10)
x2, _ := trial.SuggestFloat("x2", -10, 10)
return math.Pow(x1-2, 2) + math.Pow(x2+5, 2), nil
}

Expand Down Expand Up @@ -108,8 +108,8 @@ func TestStudy_EnqueueTrial_WithUnfixedParameter(t *testing.T) {
)

objective := func(trial goptuna.Trial) (float64, error) {
x1, _ := trial.SuggestUniform("x1", -10, 10)
x2, _ := trial.SuggestUniform("x2", -10, 10)
x1, _ := trial.SuggestFloat("x1", -10, 10)
x2, _ := trial.SuggestFloat("x2", -10, 10)
return math.Pow(x1-2, 2) + math.Pow(x2+5, 2), nil
}

Expand Down
32 changes: 19 additions & 13 deletions trial.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,12 @@ func (t *Trial) SuggestLogUniform(name string, low, high float64) (float64, erro
return t.SuggestLogFloat(name, low, high)
}

// SuggestDiscreteUniform suggests a value from a discrete uniform distribution.
// Deprecated: Please use SuggestDiscreteFloat method.
func (t *Trial) SuggestDiscreteUniform(name string, low, high, q float64) (float64, error) {
return t.SuggestDiscreteFloat(name, low, high, q)
}

// SuggestFloat suggests a value for the floating point parameter.
func (t *Trial) SuggestFloat(name string, low, high float64) (float64, error) {
if low > high {
Expand All @@ -194,19 +200,8 @@ func (t *Trial) SuggestLogFloat(name string, low, high float64) (float64, error)
})
}

// SuggestInt suggests an integer parameter.
func (t *Trial) SuggestInt(name string, low, high int) (int, error) {
if low > high {
return 0, errors.New("'low' must be smaller than or equal to the 'high'")
}
v, err := t.suggest(name, IntUniformDistribution{
High: high, Low: low,
})
return int(v), err
}

// SuggestDiscreteUniform suggests a value from a discrete uniform distribution.
func (t *Trial) SuggestDiscreteUniform(name string, low, high, q float64) (float64, error) {
// SuggestDiscreteFloat suggests a value for the discrete floating point parameter.
func (t *Trial) SuggestDiscreteFloat(name string, low, high, q float64) (float64, error) {
if low > high {
return 0, errors.New("'low' must be smaller than or equal to the 'high'")
}
Expand All @@ -220,6 +215,17 @@ func (t *Trial) SuggestDiscreteUniform(name string, low, high, q float64) (float
return d.ToExternalRepr(ir).(float64), err
}

// SuggestInt suggests an integer parameter.
func (t *Trial) SuggestInt(name string, low, high int) (int, error) {
if low > high {
return 0, errors.New("'low' must be smaller than or equal to the 'high'")
}
v, err := t.suggest(name, IntUniformDistribution{
High: high, Low: low,
})
return int(v), err
}

// SuggestCategorical suggests an categorical parameter.
func (t *Trial) SuggestCategorical(name string, choices []string) (string, error) {
if len(choices) == 0 {
Expand Down
12 changes: 6 additions & 6 deletions trial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func TestTrial_Suggest(t *testing.T) {
name: "SuggestUniform",
objective: func(trial goptuna.Trial) (float64, error) {
// low is larger than high
x1, err := trial.SuggestUniform("x", -10, 10)
x1, err := trial.SuggestFloat("x", -10, 10)
if err != nil {
return -1, err
}
Expand All @@ -29,7 +29,7 @@ func TestTrial_Suggest(t *testing.T) {
name: "SuggestUniform: low is larger than high",
objective: func(trial goptuna.Trial) (float64, error) {
// low is larger than high
x1, err := trial.SuggestUniform("x", 10, -10)
x1, err := trial.SuggestFloat("x", 10, -10)
if err != nil {
return -1, err
}
Expand All @@ -40,7 +40,7 @@ func TestTrial_Suggest(t *testing.T) {
{
name: "SuggestLogUniform",
objective: func(trial goptuna.Trial) (float64, error) {
x1, err := trial.SuggestLogUniform("x", 1e5, 1e10)
x1, err := trial.SuggestLogFloat("x", 1e5, 1e10)
if err != nil {
return -1, err
}
Expand All @@ -51,7 +51,7 @@ func TestTrial_Suggest(t *testing.T) {
{
name: "SuggestLogUniform: low is larger than high",
objective: func(trial goptuna.Trial) (float64, error) {
x1, err := trial.SuggestLogUniform("x", 1e10, 1e5)
x1, err := trial.SuggestLogFloat("x", 1e10, 1e5)
if err != nil {
return -1, err
}
Expand All @@ -62,7 +62,7 @@ func TestTrial_Suggest(t *testing.T) {
{
name: "SuggestDiscreteUniform",
objective: func(trial goptuna.Trial) (float64, error) {
x1, err := trial.SuggestDiscreteUniform("x", -10, 10, 0.5)
x1, err := trial.SuggestDiscreteFloat("x", -10, 10, 0.5)
if err != nil {
return -1, err
}
Expand All @@ -73,7 +73,7 @@ func TestTrial_Suggest(t *testing.T) {
{
name: "SuggestDiscreteUniform: low is larger than high",
objective: func(trial goptuna.Trial) (float64, error) {
x1, err := trial.SuggestDiscreteUniform("x", 10, -10, 0.5)
x1, err := trial.SuggestDiscreteFloat("x", 10, -10, 0.5)
if err != nil {
return -1, err
}
Expand Down

0 comments on commit 84e1c1d

Please sign in to comment.