-
Notifications
You must be signed in to change notification settings - Fork 0
/
scalatree.scala
299 lines (253 loc) · 8.5 KB
/
scalatree.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
package com.scalatrees
package object stree {
/**
 * Wrapper around a tree-like source structure that also keeps the
 * parameters used to create the tree, so that mutation and crossbreeding
 * can generate new subtrees with the same settings.
 *
 * @param numParam  number of input parameters; parameter nodes get an index in [0, numParam)
 * @param funcList  functions that may be placed in function nodes
 * @param maxDepth  function nodes are only generated at depths below this value
 *                  (the root of a freshly generated tree is always a function node)
 * @param prFunc    probability that a non-root node below maxDepth is a function node
 * @param prParam   probability that a leaf is a parameter node rather than a constant
 * @param constFunc generator for the value stored in constant nodes
 * @param root      root node of the tree; when null a random tree is generated
 */
class Stree(
  val numParam: Int,
  val funcList: List[Tfunc],
  val maxDepth: Int = 5,
  val prFunc: Float = 0.6f,
  val prParam: Float = 0.5f,
  val constFunc: ()=>Any=()=>util.Random.nextInt(100),
  var root: Node = null ) {

  // null is the "no tree supplied" sentinel of the constructor.
  if (root == null) {
    root = random_tree()
  }

  /**
   * Generates a random tree using the given parameters.
   *
   * @param depth  depth of the node being generated
   * @param atroot when true a function node is always produced
   * @return the root of the generated subtree
   */
  def random_tree(depth: Int=0,atroot: Boolean=true): Node = {
    val funcRoll = util.Random.nextFloat()
    if (atroot || ((funcRoll < prFunc) && (depth < maxDepth))) {
      // Make a function node here and recursively create its children.
      val newfunc = choice(funcList)
      val children = List.fill(newfunc.numParam)(random_tree(depth + 1, false))
      new Fnode(newfunc, children)
    } else if (numParam > 0 && util.Random.nextFloat() < prParam) {
      // Leaf: parameter node. A fresh roll is drawn here; reusing the
      // function roll would make this branch depend on prFunc and, with
      // prParam <= prFunc, unreachable below maxDepth.
      new Pnode(util.Random.nextInt(numParam))
    } else {
      // Leaf: constant node (also the only possible leaf when numParam is 0).
      new Cnode(constFunc())
    }
  }

  /**
   * Recurses through the tree and randomly mutates its subtrees.
   * The tree held by this object is modified in place.
   *
   * @param probMut probability that any visited node is replaced
   */
  def mutate(probMut: Float=0.15f): Unit = {
    root = _mutate(root, probMut)
  }

  /**
   * Mutation worker: with probability probMut the subtree is replaced by a
   * freshly generated one, otherwise the children of a function node are
   * mutated recursively (the function node itself is updated in place).
   *
   * @return the node that should take the place of `subtree`
   */
  def _mutate(subtree: Node, probMut: Float=0.15f, depth: Int=0): Node = {
    if (util.Random.nextFloat() < probMut) {
      // Only force a function node when replacing the real root, so that
      // replacements deeper in the tree still respect maxDepth.
      random_tree(depth, depth == 0)
    } else {
      subtree match {
        case fnode: Fnode =>
          fnode.children = fnode.children.map(child => _mutate(child, probMut, depth + 1))
          fnode
        case leaf =>
          leaf
      }
    }
  }

  /**
   * Recurses through this and an 'other' tree, randomly replacing
   * subtrees of this tree with subtrees from the other. Neither input
   * tree is modified; the result is built from fresh function nodes.
   *
   * @param otherroot root of the tree to take subtrees from
   * @param probCross probability of taking the other tree's subtree at each non-root node
   * @return a new Stree with the same generation parameters as this one
   */
  def crossbreed(otherroot: Node, probCross: Float=0.15f): Stree = {
    val newroot = _crossbreed(root, otherroot, probCross)
    new Stree(numParam, funcList, maxDepth, prFunc, prParam, constFunc, newroot)
  }

  /**
   * Crossbreeding worker. Returns a new subtree; `thisroot` and
   * `otherroot` are left untouched.
   */
  def _crossbreed(thisroot: Node, otherroot: Node, probCross: Float=0.15f, atroot: Boolean=true): Node = {
    if ((!atroot) && (util.Random.nextFloat() < probCross)) {
      // Cross these trees: take (a copy of) the other subtree.
      copyTree(otherroot)
    } else {
      (thisroot, otherroot) match {
        // Both are function nodes: cross each child with a randomly
        // chosen child of the other node. The guard avoids choosing
        // from an empty child list.
        case (mine: Fnode, theirs: Fnode) if theirs.children.nonEmpty =>
          val crossed = mine.children.map(child =>
            _crossbreed(child, choice(theirs.children), probCross, false))
          new Fnode(mine.func, crossed)
        case _ =>
          copyTree(thisroot)
      }
    }
  }

  /**
   * Copies a subtree so the result shares no mutable node with the
   * original. Only function nodes carry mutable state (their child
   * list), so leaves are reused as they are.
   */
  private def copyTree(node: Node): Node = node match {
    case fnode: Fnode => new Fnode(fnode.func, fnode.children.map(copyTree))
    case leaf => leaf
  }

  /**
   * Evaluates this source tree against a list of lists containing parameters and their
   * expected output. Returns the mean score (integer division) over all rows.
   *
   * data should be of the form:
   * ((x11,x12,x13...y1),
   * (x21,x22... y2),
   * ...
   * (xn1, xn2... yn))
   *
   * An empty data list results in an ArithmeticException (division by zero).
   */
  def scoreAgainstData(data: List[List[Any]]):Int = {
    val scores = data.map(score)
    scores.sum / data.length
  }

  /**
   * Returns the absolute difference between the tree's evaluation
   * of some parameters and the expected result. The last element of `v`
   * is the expected result, everything before it is the parameter list.
   * Both the evaluation result and the expected result are cast to Int.
   */
  def score(v: List[Any]):Int= {
    val diff = evaluate(v.dropRight(1)).asInstanceOf[Int] - v.last.asInstanceOf[Int]
    math.abs(diff)
  }

  /** Prints the tree to standard output, evaluated against paramlist. */
  def printToString(paramlist: List[Any] = List()): Unit = {
    root.printToString(paramlist)
  }

  /** Evaluates the tree for the given parameter values. */
  def evaluate(paramlist: List[Any]): Any = {
    root.evaluate(paramlist)
  }

  /** Prints the tree for the given parameter values. */
  def test(paramlist: List[Any]): Unit = {
    printToString(paramlist)
  }
}
/**
 * Represents a particular object in the source tree: could be
 * a function, a parameter, or a constant.
 */
abstract class Node() {
  // Building blocks of the printed tree layout, used by subclasses.
  val spacer: String = " "
  val noder: String = "\\"
  val stemmer: String = " |"

  /**
   * Prints the indentation prefix for this node, without a line break.
   * Subclasses call this and then print their own label.
   *
   * @param paramlist parameter values (not used here; used by subclasses)
   * @param indent    text printed before the node label
   */
  def printToString(paramlist: List[Any], indent: String=" "): Unit = {
    print(indent)
  }

  /** Evaluates this node for the given parameter values. */
  def evaluate(paramlist: List[Any]): Any
}
/**
 * Bundles a function object with its display name and the number
 * of arguments it takes.
 *
 * @param name     name of the function
 * @param numParam number of parameters the function takes
 * @param function the function object, applied to a list of argument values
 */
class Tfunc(
  val name: String,
  val numParam: Int,
  val function: List[Any] => Any
)
/**
 * A tree node that contains a function.
 *
 * @param func     the wrapped function applied by this node
 * @param children argument subtrees; mutable so a subtree can be swapped in place
 */
class Fnode(val func: Tfunc, var children: List[Node]) extends Node() {
  val name: String = func.name
  val function: List[Any] => Any = func.function

  /**
   * Recursively evaluate the children of this function,
   * and apply this function to their results.
   */
  def evaluate(paramlist: List[Any]): Any =
    function(children.map(_.evaluate(paramlist)))

  /**
   * Prints the name of this function with its evaluated value and
   * recursively prints its children.
   */
  override def printToString(paramlist: List[Any], indent: String=" "): Unit = {
    super.printToString(paramlist, indent)
    println(noder + name + "=" + evaluate(paramlist))
    // A function without arguments has no children to print; guarding here
    // avoids calling `last` on an empty list.
    if (children.nonEmpty) {
      // Every child but the last gets a stem so the branch line continues.
      children.init.foreach(_.printToString(paramlist, indent + spacer * 2 + stemmer))
      children.last.printToString(paramlist, indent + spacer * 4)
    }
  }
}
/**
 * A tree node that holds a parameter: ie, the
 * index of the parameter, not its literal value.
 *
 * @param paramid index into the parameter list passed to evaluate
 */
class Pnode(paramid: Int) extends Node() {
  /**
   * Return the value of this parameter. Throws IndexOutOfBoundsException
   * when paramlist has no element at the stored index.
   */
  def evaluate(paramlist: List[Any]): Any =
    paramlist(paramid)

  /**
   * Prints the parameter index and its value.
   */
  override def printToString(paramlist: List[Any], indent: String=" "): Unit = {
    super.printToString(paramlist, indent)
    println(noder + paramToString(paramid) + "=" + evaluate(paramlist))
  }

  /** Formats a parameter index for display. */
  def paramToString(id: Int): String = {
    "p[" + id + "]"
  }
}
/**
 * A tree node that holds a constant value.
 *
 * @param value the constant returned by evaluate
 */
class Cnode(value: Any) extends Node() {
  /**
   * Return the value of this constant; paramlist is ignored.
   */
  def evaluate(paramlist: List[Any]): Any =
    value

  /**
   * Prints the constant.
   */
  override def printToString(paramlist: List[Any], indent: String=" "): Unit = {
    super.printToString(paramlist, indent)
    println(noder + value)
  }
}
/**
* Useful Utility Functions
*/
/**
 * Return a randomly chosen element
 * from the list of stuff.
 *
 * @param stuff non-empty list to pick from
 * @throws IllegalArgumentException if stuff is empty
 */
def choice[A](stuff: List[A]): A = {
  // Fail with a clear message instead of the random generator's
  // "bound must be positive" error.
  require(stuff.nonEmpty, "choice: cannot pick an element from an empty list")
  stuff(util.Random.nextInt(stuff.length))
}
/**
 * Builds a population of popsize independently generated random trees,
 * all created with the same generation parameters. A popsize of zero
 * or less gives an empty list.
 */
def makeForest(popsize: Int,
  numParam: Int,
  funcList: List[Tfunc],
  maxDepth: Int = 5,
  prFunc: Float = 0.6f,
  prParam: Float = 0.5f,
  constFunc: ()=>Any=()=>util.Random.nextInt(100)
): List[Stree] =
  // The by-name argument is re-evaluated for every element, so each
  // entry is a separately constructed tree.
  List.fill(popsize)(new Stree(numParam, funcList, maxDepth, prFunc, prParam, constFunc))
/**
 * Pairs every tree of the forest with its score against the given data,
 * keeping the order of the forest.
 */
def scoreForest(forest: List[Stree], data: List[List[Int]]): List[(Stree,Int)] =
  for (tree <- forest) yield {
    val treeScore = tree.scoreAgainstData(data)
    (tree, treeScore)
  }
/**
 * Given a population of trees and some data, make a new generation of this population by:
 * 1 - scoring each tree against this data
 * 2 - removing the worst-scoring proportion, propToPrune, of the population
 * 3 - crossbreeding the remaining trees randomly and mutating the offspring
 *
 * @param forest      current population
 * @param data        rows to score against, passed to scoreForest
 * @param propToPrune fraction of the population (highest scores) to discard
 * @param probCross   crossover probability passed to Stree.crossbreed
 * @param probMutate  mutation probability passed to Stree.mutate
 * @return the surviving trees followed by one offspring per survivor
 */
def generateGeneration(forest: List[Stree], data: List[List[Int]], propToPrune: Float=0.5f, probCross: Float=0.5f, probMutate: Float=0.25f): List[Stree] = {
  // Sort by increasing score: a lower score means a smaller error.
  val sortedTreeScores = scoreForest(forest, data).sortBy(_._2)
  // Drop the worst propToPrune fraction and keep the rest.
  val pruneCount = (propToPrune * sortedTreeScores.length).toInt
  val topTrees = sortedTreeScores.dropRight(pruneCount).map(_._1)
  // For each surviving tree, randomly pick a second survivor to cross with.
  val pairs = topTrees.map(tree => (tree, choice(topTrees)))
  // Cross the pairs of trees to produce kid trees.
  val kids = pairs.map { case (parent, mate) => parent.crossbreed(mate.root, probCross) }
  // mutate works by side effect, so foreach rather than map.
  kids.foreach(_.mutate(probMutate))
  // Join the parents with the kid trees and return that.
  topTrees ++ kids
}
}