Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add BlackHoleStorage towards 100k+ evaluations #105

Merged
merged 14 commits into from
Apr 13, 2020
2 changes: 2 additions & 0 deletions .github/workflows/run-examples.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ jobs:
GO111MODULE: on
run: |
make build
./bin/cmaes
./bin/cmaes_blackhole
./bin/simple_tpe
./bin/concurrency
./bin/trialnotify
5 changes: 5 additions & 0 deletions _examples/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@ DIR=$(cd $(dirname $0); pwd)
BIN_DIR=$(cd $(dirname $(dirname $0)); pwd)/bin

mkdir -p ${BIN_DIR}

set -ex

go build -o ${BIN_DIR}/cmaes ${DIR}/cmaes/main.go
go build -o ${BIN_DIR}/cmaes_blackhole ${DIR}/cmaes/blackhole/main.go
go build -o ${BIN_DIR}/concurrency ${DIR}/concurrency/main.go
go build -o ${BIN_DIR}/enqueue_trial ${DIR}/enqueue_trial/main.go
go build -o ${BIN_DIR}/trialnotify ${DIR}/trialnotify/main.go
Expand Down
59 changes: 59 additions & 0 deletions _examples/cmaes/blackhole/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package main

import (
"log"
"math"

"github.com/c-bata/goptuna"
"github.com/c-bata/goptuna/cmaes"
)

func objective(trial goptuna.Trial) (float64, error) {
x1, err := trial.SuggestFloat("x1", -10, 10)
if err != nil {
return -1, err
}
x2, err := trial.SuggestFloat("x2", -10, 10)
if err != nil {
return -1, err
}
return math.Pow(x1-2, 2) + math.Pow(x2+5, 2), nil
}

func main() {
relativeSampler := cmaes.NewSampler(
cmaes.SamplerOptionNStartupTrials(0))
study, err := goptuna.CreateStudy(
"goptuna-example",
goptuna.StudyOptionStorage(goptuna.NewBlackHoleStorage(20)),
goptuna.StudyOptionRelativeSampler(relativeSampler),
goptuna.StudyOptionDefineSearchSpace(map[string]interface{}{
"x1": goptuna.UniformDistribution{
High: 10,
Low: -10,
},
"x2": goptuna.UniformDistribution{
High: 10,
Low: -10,
},
}),
)
if err != nil {
log.Fatal("failed to create study:", err)
}

if err = study.Optimize(objective, 10000); err != nil {
log.Fatal("failed to optimize:", err)
}

v, err := study.GetBestValue()
if err != nil {
log.Fatal("failed to get best value:", err)
}
params, err := study.GetBestParams()
if err != nil {
log.Fatal("failed to get best params:", err)
}
log.Printf("Best evaluation=%f (x1=%f, x2=%f)",
v, params["x1"].(float64), params["x2"].(float64))
}
15 changes: 13 additions & 2 deletions cmaes/sampler.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func (s *Sampler) SampleRelative(
sort.Strings(orderedKeys)

trials, err := study.GetTrials()
if err != nil {
if err != nil && err != goptuna.ErrTrialsPartiallyDeleted {
return nil, err
}
completed := make([]goptuna.FrozenTrial, 0, len(trials))
Expand All @@ -57,8 +57,19 @@ func (s *Sampler) SampleRelative(
}
}
if len(completed) < s.nStartUpTrials {
return nil, nil
// If catch ErrTrialsPartiallyDeleted, nStartUpTrials should be smaller than len(completed).
study.GetLogger().Error("Your BlackHoleStorage buffer is too small.",
fmt.Sprintf("nStartUpTrials:%d", s.nStartUpTrials))
return nil, err
}
if err == goptuna.ErrTrialsPartiallyDeleted && s.optimizer != nil &&
len(completed) < s.optimizer.PopulationSize() {
// If catch ErrTrialsPartiallyDeleted, population size should be smaller than len(completed).
study.GetLogger().Error("Your BlackHoleStorage buffer is too small.",
fmt.Sprintf("popsize:%d", s.optimizer.PopulationSize()))
return nil, err
}
err = nil

if s.optimizer == nil {
s.optimizer, err = s.initOptimizer(searchSpace, orderedKeys)
Expand Down
6 changes: 5 additions & 1 deletion sampler.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,11 @@ func IntersectionSearchSpace(study *Study) (map[string]interface{}, error) {
var searchSpace map[string]interface{}

trials, err := study.GetTrials()
if err != nil {
if err == ErrTrialsPartiallyDeleted {
study.logger.Warn("Some trials are not used to calculate intersection of search spaces." +
" Please use `goptuna.StudyOptionDefineSearchSpace` option.")
err = nil
} else if err != nil {
return nil, err
}

Expand Down
Loading