-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.go
163 lines (149 loc) · 5.79 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
package main
import (
crand "crypto/rand"
"flag"
"fmt"
"math/big"
"math/rand"
"sync"
"github.com/franciscobonand/symb-regr-gp/datasets"
"github.com/franciscobonand/symb-regr-gp/operator"
pop "github.com/franciscobonand/symb-regr-gp/population"
"github.com/franciscobonand/symb-regr-gp/stats"
)
// Run configuration, filled in from command-line flags by initializeFlags.
var (
// Sizes and counts: population size, tournament size, threads used when
// evaluating, number of generations, and number of elite members.
popSize, tournamentSize, threads, generations, nElitism int
// Input CSV path, selection method name ('rol', 'tour', 'lex' or 'rand'),
// and the optional stats output file (empty means no stats file).
file, sel, statsfile string
// Crossover and mutation probabilities; main requires both in [0.0, 1.0].
crossProb, mutProb float64
// Base seed handed to setSeed (offset by the run index in main).
seed int64
)
// main runs the symbolic-regression genetic program configured through the
// command-line flags registered in initializeFlags.
//
// Without -statsfile it performs a single run and prints one CSV line of
// statistics per generation to stdout, followed by the best individual.
// With -statsfile it repeats the run several times (each with a different
// seed), averages the per-generation statistics across the runs and writes
// the averages to the given file.
//
// It panics on invalid flag values or when the dataset cannot be read.
func main() {
	// Example invocation:
	// ./symb-regr-gp -popsize 20 -selector tour -toursize 2 -gens 20 -threads 1 -file "abcd.csv" -cxprob 0.9 -mutprob 0.05 -elitism 0 -seed 4132 -statsfile stats.csv
	initializeFlags()
	if !allPositiveInts(popSize, threads, generations) {
		panic("Invalid value for popsize, gens or threads, must be a positive integer")
	}
	if nElitism < 0 {
		panic("Elitism size must be at least 0")
	}
	if sel == "tour" && tournamentSize < 2 {
		panic("Tournament size must be at least 2")
	}
	if crossProb < 0.0 || mutProb < 0.0 || crossProb > 1.0 || mutProb > 1.0 {
		panic("Genetic operators probability must be between 0.0 and 1.0")
	}

	ds, err := dataset.Read(file)
	if err != nil {
		panic(err.Error())
	}

	// Number of repetitions used when collecting averaged statistics.
	const statsRuns = 30

	getstats := statsfile != ""
	var runqnt int64 = 1
	if getstats {
		runqnt = statsRuns
	}

	// rundata collects one stats row per generation (including generation 0)
	// for every run, laid out run after run.
	var rundata [][]float64
	for run := int64(0); run < runqnt; run++ {
		// Every run gets its own seed derived from the base seed.
		// NOTE(review): setSeed only randomizes when its argument is <= 0, so
		// with a non-positive base seed the later runs (seed+run > 0) become
		// deterministic — confirm this is intended.
		setSeed(seed + run)

		opset := operator.CreateOpSet(ds.Variables...)
		gen := pop.NewRampedGenerator(opset, 1, 6)
		rmse := pop.RMSE{DS: ds}

		// Create initial population
		p := pop.CreatePopulation(popSize, gen)

		// Define selection method and genetic operators
		var selector pop.Selector
		switch sel {
		case "rol":
			selector = pop.RouletteSelector(nElitism, rmse)
		case "tour":
			selector = pop.TournamentSelector(nElitism, tournamentSize, threads, rmse)
		case "lex":
			selector = pop.LexicaseSelector(nElitism, threads, rmse, ds.Copy())
		default:
			// "rand" and any unrecognized name use random selection.
			selector = pop.RandomSelector(nElitism)
		}
		mut := pop.MutationOp(gen, rmse)
		cross := pop.CrossoverOp(rmse)

		// Run Fitness for initial population
		p, e := p.Evaluate(rmse, threads)

		var betterCxChild, worseCxChild float64
		var wg sync.WaitGroup
		if !getstats {
			// One printer goroutine per generation plus one for generation 0.
			// Presumably PrintRunStats calls wg.Done when finished — verify
			// against the stats package.
			wg.Add(generations + 1)
			fmt.Println("gen,evals,repeated,bestfit,worstfit,meanfit,maxsize,minsize,meansize,betterCxChild,worseCxChild")
			go stats.PrintRunStats(&wg, 0, float64(e), betterCxChild, worseCxChild, p, rmse)
		}
		rundata = append(rundata, stats.GetRunStats(0.0, float64(e), betterCxChild, worseCxChild, p, rmse))

		for g := 1; g <= generations; g++ {
			// Selects new population
			children := selector.Select(p, len(p))
			// Applies genetic operators
			p, betterCxChild, worseCxChild = pop.ApplyGeneticOps(children, cross, mut, crossProb, mutProb)
			p, e = p.Evaluate(rmse, threads)

			// Record or print the new population stats
			genIdx := float64(g)
			if getstats {
				rundata = append(rundata, stats.GetRunStats(genIdx, float64(e), betterCxChild, worseCxChild, p, rmse))
			} else {
				go stats.PrintRunStats(&wg, genIdx, float64(e), betterCxChild, worseCxChild, p, rmse)
			}
		}
		if !getstats {
			// Let every stats line be printed before the best individual.
			wg.Wait()
		}

		best := p.Best(rmse)
		fmt.Println(best)
	}

	if getstats {
		fmt.Println("Writing stats to file...")

		// Each run contributed generations+1 consecutive rows to rundata, so
		// the rows for generation k sit at k, k+rowsPerRun, k+2*rowsPerRun, ...
		rowsPerRun := generations + 1
		columns := len(rundata[0])
		output := make([][]float64, 0, rowsPerRun)
		for k := 0; k < rowsPerRun; k++ {
			currgen := make([]float64, 0, columns)
			// Column 0 is skipped — presumably the generation index; confirm
			// against stats.GetRunStats.
			for col := 1; col < columns; col++ {
				acc := 0.0
				for lin := k; lin < len(rundata); lin += rowsPerRun {
					acc += rundata[lin][col]
				}
				// Average over the number of runs actually performed.
				currgen = append(currgen, acc/float64(runqnt))
			}
			output = append(output, currgen)
		}

		if err := dataset.Write(statsfile, output); err != nil {
			fmt.Println("(ERROR) failed to write stats file:", err.Error())
		} else {
			fmt.Printf("Stats file available in 'analysis/%s'\n", statsfile)
		}
	}
}
// initializeFlags registers every command-line option on the default flag set,
// binding each one to its package-level variable, and then parses the
// process arguments.
func initializeFlags() {
	// Integer options.
	intFlags := []struct {
		target *int
		name   string
		def    int
		usage  string
	}{
		{&popSize, "popsize", 20, "population size"},
		{&nElitism, "elitism", 0, "number of best members of elitism"},
		{&tournamentSize, "toursize", 2, "tournament size"},
		{&generations, "gens", 10, "number of generations to run"},
		{&threads, "threads", 1, "quantity of threads to be used when evaluating"},
	}
	for _, f := range intFlags {
		flag.IntVar(f.target, f.name, f.def, f.usage)
	}

	// String options.
	stringFlags := []struct {
		target *string
		name   string
		def    string
		usage  string
	}{
		{&sel, "selector", "tour", "defines the selection method ('rol', 'tour', 'lex', or 'rand')"},
		{&statsfile, "statsfile", "", "generate stats and saves into given file"},
		{&file, "file", "datasets/synth1/synth1-train.csv", "csv file containing data to be processed"},
	}
	for _, f := range stringFlags {
		flag.StringVar(f.target, f.name, f.def, f.usage)
	}

	// Probability options.
	floatFlags := []struct {
		target *float64
		name   string
		def    float64
		usage  string
	}{
		{&crossProb, "cxprob", 0.9, "crossover probability"},
		{&mutProb, "mutprob", 0.05, "mutation probability"},
	}
	for _, f := range floatFlags {
		flag.Float64Var(f.target, f.name, f.def, f.usage)
	}

	// The seed is the only 64-bit integer option.
	flag.Int64Var(&seed, "seed", 1, "seed for generating the initial population")

	flag.Parse()
}
// setSeed seeds the global math/rand generator and returns the seed that was
// actually used.
//
// A positive seed is used as given. A seed <= 0 requests a random one, drawn
// from crypto/rand in the range [1, 2^32-1]. The drawn seed is always positive
// so that passing the returned value back to setSeed replays the same sequence
// instead of triggering another random draw.
//
// It panics if the system random source cannot be read.
func setSeed(seed int64) int64 {
	if seed <= 0 {
		// crand.Int yields a value in [0, limit); shifting it by one gives
		// [1, limit], which never collides with the "pick randomly" range.
		limit := big.NewInt(1<<32 - 1)
		rseed, err := crand.Int(crand.Reader, limit)
		if err != nil {
			// rseed is nil here; fail with a clear message rather than a
			// nil-pointer dereference.
			panic(fmt.Sprintf("generating random seed: %v", err))
		}
		seed = rseed.Int64() + 1
	}
	rand.Seed(seed)
	return seed
}
// allPositiveInts reports whether every given number is strictly greater than
// zero. It returns true when called with no arguments.
func allPositiveInts(nums ...int) bool {
	for i := 0; i < len(nums); i++ {
		if nums[i] < 1 {
			return false
		}
	}
	return true
}