# regtrees-validation.R
# (forked from iandurbach/ml-for-ecology)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#### Regression trees II: Validation
# - dividing your data into training, validation, and test sets
# - model validation (accuracy on test set)
# - overfitting
library(tree)
load("data/aloe.RData")
head(aloe)

# make training, validation, test datasets (60/20/20 split)
# shuffle rows; set.seed() makes the shuffle reproducible. seq_len() yields
# the same permutation as 1:nrow() would, but cannot turn an empty data set
# into the bogus index vector c(1, 0).
set.seed(123)
aloe <- aloe[sample(seq_len(nrow(aloe))), ]

# get numbers in train, valid, test sets; the test set takes whatever is left
# after rounding, so the three counts always add up to nrow(aloe)
ntrain <- round(0.6 * nrow(aloe))
nvalid <- round(0.2 * nrow(aloe))
ntest <- nrow(aloe) - ntrain - nvalid

# allocate data to train (1), valid (2), test (3) sets; because the rows were
# shuffled above, consecutive runs of rows form random subsets
aloe$train_id <- rep(c(1, 2, 3), times = c(ntrain, nvalid, ntest))
#### fit data on train + valid, assess on test
# Since we're not doing any fine tuning (e.g. pruning), we don't need a
# separate validation set (see later): the tree is fit on the training and
# validation rows together (train_id 1 and 2) and assessed on the held-out
# test rows (train_id 3).
fit_data <- subset(aloe, train_id != 3)
test_data <- subset(aloe, train_id == 3)

# build tree: log(tottrees) as a function of the site coordinates
tree_aloe <- tree(log(tottrees) ~ Latitude + Longitude,
                  data = fit_data,
                  split = "deviance")

# plot the tree
plot(tree_aloe)
text(tree_aloe, cex = 0.9)

# assess training accuracy: mean squared error of the fitted values.
# predict() without newdata returns predictions for the fitting data, and
# tree_aloe$y is the (log-scale) response the tree was fit to.
pred_aloe <- predict(tree_aloe)
mean((tree_aloe$y - pred_aloe)^2)

# assess *test* accuracy on rows the tree never saw.
# [["tottrees"]] always returns a plain vector, even if aloe is a tibble
# (where aloe[rows, "tottrees"] would give a one-column tibble and make the
# mean() below return NA). Using test_data for both the predictions and the
# observations keeps the two vectors row-aligned.
pred_aloe <- predict(tree_aloe, newdata = test_data)
observed <- log(test_data[["tottrees"]])
mean((observed - pred_aloe)^2)
## try again with the overfitted tree
# Same model and data as above, but the extra arguments are passed through
# tree() to tree.control() to relax its stopping rules (see ?tree.control):
# mincut = 1 and minsize = 2 allow leaves with a single observation, and
# mindev = 0 drops the minimum within-node deviance required to split.
# Comparing training vs test MSE with the default tree shows overfitting.
fit_data <- subset(aloe, train_id != 3)
test_data <- subset(aloe, train_id == 3)

# build tree
tree_aloe <- tree(log(tottrees) ~ Latitude + Longitude,
                  data = fit_data,
                  split = "deviance",
                  mincut = 1,
                  minsize = 2,
                  mindev = 0)

# plot the tree
plot(tree_aloe)
text(tree_aloe, cex = 0.9)

# assess training accuracy (fitted values vs the log-scale response)
pred_aloe <- predict(tree_aloe)
mean((tree_aloe$y - pred_aloe)^2)

# assess *test* accuracy; [["tottrees"]] guarantees a plain vector even if
# aloe is a tibble, so mean() does not silently return NA
pred_aloe <- predict(tree_aloe, newdata = test_data)
observed <- log(test_data[["tottrees"]])
mean((observed - pred_aloe)^2)