Basic decision tree
Download and install Titus. This article was tested with Titus 0.7.1; newer versions should work with no modification. Python >= 2.6 and < 3.0 is required.
Launch a Python prompt and import the following:
Python 2.7.6
Type "help", "copyright", "credits" or "license" for more information.
>>> import csv
>>> import math
>>>
>>> import titus.prettypfa
>>> from titus.genpy import PFAEngine
Then download the Iris dataset and save it as iris.csv in the working directory.
The simplest algorithm for walking a decision tree is formed by combining model.tree.simpleWalk with model.tree.simpleTest. More complex variants handle non-binary trees, missing values, and compound predicates.
pfaDocument = titus.prettypfa.jsonNode('''
types:
  Input = record(Input,
                 sepal_length: double,
                 sepal_width: double,
                 petal_length: double,
                 petal_width: double);
  TreeNode = record(TreeNode,
                    field: enum([sepal_length, sepal_width, petal_length, petal_width]),
                    operator: string,
                    value: double,
                    pass: union(string, TreeNode),
                    fail: union(string, TreeNode))
input: Input
output: string
cells:
  // empty tree to satisfy type constraint; will be filled in later
  tree(TreeNode) = {field: "sepal_length", operator: "", value: 0.0, pass: {string: ""}, fail: {string: ""}}
action:
  model.tree.simpleWalk(input, tree, fcn(d: Input, t: TreeNode -> boolean) {
    model.tree.simpleTest(d, t)
  })
''')
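As a rough illustration of what the action above does at scoring time, the following plain-Python sketch mimics the walk over a tree stored in the same shape as the tree cell. It is not the Titus implementation: simpleTestSketch and simpleWalkSketch are hypothetical names, only a handful of comparison operators are handled, and the thresholds in exampleTree are made up for illustration rather than fitted values.

```python
import operator

# Assumption: only these comparisons are covered in this sketch.
COMPARISONS = {"<": operator.lt, "<=": operator.le,
               ">": operator.gt, ">=": operator.ge,
               "==": operator.eq, "!=": operator.ne}

def simpleTestSketch(datum, node):
    # Compare the datum's field named by the node against the node's value.
    compare = COMPARISONS[node["operator"]]
    return compare(datum[node["field"]], node["value"])

def simpleWalkSketch(datum, node, test):
    # Follow "pass" or "fail" until a leaf (a tagged string) is reached.
    while True:
        branch = node["pass"] if test(datum, node) else node["fail"]
        if "TreeNode" in branch:
            node = branch["TreeNode"]   # descend into the subtree
        else:
            return branch["string"]     # leaf: the predicted label

# Hand-written tree with illustrative (not fitted) thresholds.
exampleTree = {
    "field": "petal_length", "operator": "<", "value": 2.5,
    "pass": {"string": "setosa"},
    "fail": {"TreeNode": {
        "field": "petal_width", "operator": "<", "value": 1.75,
        "pass": {"string": "versicolor"},
        "fail": {"string": "virginica"}}}}

flower = {"sepal_length": 5.1, "sepal_width": 3.5,
          "petal_length": 1.4, "petal_width": 0.2}
label = simpleWalkSketch(flower, exampleTree, simpleTestSketch)
```

Passing the test as a function argument mirrors the PFA action, where the walk is generic and the predicate is supplied by the caller.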
Titus has an implementation of CART in titus.producer.cart with many features (numerical and categorical splits, different impurity criteria, etc.), but we'll use a simple implementation here and reuse most of the code for the random forest example.
class Datum(object):
    def __init__(self, sepal_length, sepal_width, petal_length, petal_width, truth):
        self.sepal_length = float(sepal_length)
        self.sepal_width = float(sepal_width)
        self.petal_length = float(petal_length)
        self.petal_width = float(petal_width)
        self.truth = truth.replace("Iris-", "")
def attemptSplit(field, data):
    bestGain = None
    bestSplit = None
    # candidate thresholds: midpoints between consecutive distinct values
    values = sorted(set(getattr(x, field) for x in data))
    splits = [(low + high)/2.0 for low, high in zip(values[:-1], values[1:])]
    for split in splits:
        # entropy (in bits) of the class labels on each side of the split
        gainPass = 0.0
        gainFail = 0.0
        for truth in "setosa", "versicolor", "virginica":
            numer = sum(x.truth == truth for x in data if getattr(x, field) < split)
            denom = sum(getattr(x, field) < split for x in data)
            if numer > 0.0:
                p = numer / float(denom)
                gainPass -= p * math.log(p, 2)
            numer = sum(x.truth == truth for x in data if getattr(x, field) >= split)
            denom = sum(getattr(x, field) >= split for x in data)
            if numer > 0.0:
                p = numer / float(denom)
                gainFail -= p * math.log(p, 2)
        gain = gainPass + gainFail
        # keep the split with the largest summed score
        if bestGain is None or gain > bestGain:
            bestGain = gain
            bestSplit = split
    return bestGain, bestSplit
def predominant(data):
    # most common label; ties fall through to "virginica"
    setosa = sum(x.truth == "setosa" for x in data)
    versicolor = sum(x.truth == "versicolor" for x in data)
    virginica = sum(x.truth == "virginica" for x in data)
    if setosa > versicolor and setosa > virginica:
        return "setosa"
    elif versicolor > setosa and versicolor > virginica:
        return "versicolor"
    else:
        return "virginica"
def ispure(data):
    # pure means at most one of the three classes is present
    setosa = sum(x.truth == "setosa" for x in data)
    versicolor = sum(x.truth == "versicolor" for x in data)
    virginica = sum(x.truth == "virginica" for x in data)
    if setosa == 0 and versicolor == 0:
        return True
    elif setosa == 0 and virginica == 0:
        return True
    elif versicolor == 0 and virginica == 0:
        return True
    else:
        return False
def makeTree(data):
    bestGain = None
    bestField = None
    bestSplit = None
    # pick the field and threshold with the best score
    for field in "sepal_length", "sepal_width", "petal_length", "petal_width":
        gain, split = attemptSplit(field, data)
        if bestGain is None or gain > bestGain:
            bestGain = gain
            bestField = field
            bestSplit = split
    dataPass = [x for x in data if getattr(x, bestField) < bestSplit]
    dataFail = [x for x in data if getattr(x, bestField) >= bestSplit]
    # pure subsets become leaves; impure subsets are split recursively
    if ispure(dataPass):
        branchPass = {"string": predominant(dataPass)}
    else:
        branchPass = makeTree(dataPass)
    if ispure(dataFail):
        branchFail = {"string": predominant(dataFail)}
    else:
        branchFail = makeTree(dataFail)
    return {"TreeNode": {"field": bestField,
                         "operator": "<",
                         "value": bestSplit,
                         "pass": branchPass,
                         "fail": branchFail}}
iris = csv.reader(open("iris.csv"))
iris.next()   # skip the header row
dataset = [Datum(*x) for x in iris]
tree = makeTree(dataset)["TreeNode"]
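Note that attemptSplit scores a threshold by adding the Shannon entropies of the two sides and keeps the largest sum. For comparison, a more conventional impurity criterion is information gain: the entropy of the parent set minus the size-weighted entropies of the two subsets. The sketch below is an assumption about how that could be written, not part of Titus or of the code above; entropy and informationGain are hypothetical helper names that operate on plain lists of labels.

```python
import math

def entropy(labels):
    # Shannon entropy, in bits, of a list of class labels.
    counts = {}
    for label in labels:
        counts[label] = counts.get(label, 0) + 1
    total = float(len(labels))
    result = 0.0
    for n in counts.values():
        p = n / total
        result -= p * math.log(p, 2)
    return result

def informationGain(parent, passed, failed):
    # Parent entropy minus the size-weighted entropy of the two subsets.
    total = float(len(parent))
    weighted = (len(passed) / total) * entropy(passed) \
             + (len(failed) / total) * entropy(failed)
    return entropy(parent) - weighted

labels = ["setosa", "setosa", "virginica", "virginica"]
cleanSplit = informationGain(labels, labels[:2], labels[2:])
uselessSplit = informationGain(labels,
                               ["setosa", "virginica"],
                               ["setosa", "virginica"])
```

With this criterion a split that separates the classes perfectly scores highest and a split that leaves both sides as mixed as the parent scores zero, so the tree builder would keep the split with the largest gain.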
The only thing here that depends on PFA is the form of the return value of the makeTree function.
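The form in question is the Avro JSON encoding of a union: because pass and fail are declared as union(string, TreeNode), each branch is wrapped in a single-key dictionary naming its type, either {"string": label} or {"TreeNode": subtree}. The tree cell itself is a plain TreeNode rather than a union, which is why the outermost wrapper is stripped with ["TreeNode"]. The sketch below renders a small hand-written branch as text to make the nesting visible; renderBranch is a hypothetical helper and the thresholds are illustrative, not fitted values.

```python
def renderBranch(branch, indent=0):
    # Return a list of text lines describing a tagged branch.
    pad = "  " * indent
    if "string" in branch:
        return [pad + "-> " + branch["string"]]
    node = branch["TreeNode"]
    lines = [pad + "if %s %s %s:" % (node["field"], node["operator"], node["value"])]
    lines += renderBranch(node["pass"], indent + 1)
    lines.append(pad + "else:")
    lines += renderBranch(node["fail"], indent + 1)
    return lines

# Same shape as the value returned by makeTree (illustrative thresholds).
exampleBranch = {"TreeNode": {
    "field": "petal_length", "operator": "<", "value": 2.5,
    "pass": {"string": "setosa"},
    "fail": {"TreeNode": {
        "field": "petal_width", "operator": "<", "value": 1.75,
        "pass": {"string": "versicolor"},
        "fail": {"string": "virginica"}}}}}

rendered = renderBranch(exampleBranch)
```

Joining rendered with newlines gives a readable outline of the tree, which can help when checking a generated tree by eye before attaching it to the PFA document.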
This tree is already correctly formatted for PFA, so just attach it.
pfaDocument["cells"]["tree"]["init"] = tree
The tree-building algorithm above has no pruning and continues until all leaves are pure. Hence, it correctly labels 100% of the training dataset (because it is overfitted to that dataset). Compare this with the ruleset example (95%) and the random forest example (typically 95%).
engine, = PFAEngine.fromJson(pfaDocument)
iris = csv.reader(open("iris.csv"))
iris.next()   # skip the header row
numer = 0
denom = 0
for sepal_length, sepal_width, petal_length, petal_width, truth in iris:
    # csv yields strings, so convert to float to match the double fields of Input
    result = engine.action({"sepal_length": float(sepal_length),
                            "sepal_width": float(sepal_width),
                            "petal_length": float(petal_length),
                            "petal_width": float(petal_width)})
    print result, "vs", truth.replace("Iris-", "")
    if result != truth.replace("Iris-", ""):
        numer += 1
    denom += 1
print "Fraction mislabeled:", float(numer) / float(denom)
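The fraction printed above is measured on the same rows that built the tree, so it says little about unseen data. One common way to get a fairer estimate is to hold out part of the dataset. The following is a minimal sketch under stated assumptions: splitTrainTest and fractionMislabeled are hypothetical helpers, the 30% test fraction and the seed are arbitrary choices, and rows are taken to be (features, truth) pairs.

```python
import random

def splitTrainTest(rows, testFraction=0.3, seed=12345):
    # Shuffle a copy with a fixed seed, then cut it into train and test parts.
    shuffled = list(rows)
    random.Random(seed).shuffle(shuffled)
    numTest = int(round(len(shuffled) * testFraction))
    return shuffled[numTest:], shuffled[:numTest]

def fractionMislabeled(predict, rows):
    # rows are (features, truth) pairs; predict maps features to a label.
    wrong = sum(1 for features, truth in rows if predict(features) != truth)
    return wrong / float(len(rows))

# Toy data: a single numeric feature, with the label determined by its sign.
toyRows = [(x, "neg" if x < 0 else "pos") for x in range(-5, 5)]
train, test = splitTrainTest(toyRows)
error = fractionMislabeled(lambda x: "neg" if x < 0 else "pos", test)
```

Applied to this article, the tree would be built with makeTree on the training part only, and the scoring loop would run over the held-out part.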
Licensed under the Hadrian Personal Use and Evaluation License (PUEL).