This repository has been archived by the owner on Jul 16, 2018. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path main.cpp
121 lines (100 loc) · 2.8 KB
/
main.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
#include "random-forest.h"
#include "data-reader.h"
#include <cerrno>
#include <climits>
#include <cstdio>
#include <cstdlib>
#include <string>
#include <unordered_map>
#include <signal.h>
// Command-line options, keyed by flag (e.g. "-d") with the raw value string.
using ArgsTable = std::unordered_map<std::string, std::string>;

/// Overwrites |val| with the integer stored under |key| in |table|.
///
/// |val| is left untouched when the key is absent, when its value does not
/// start with a decimal integer, or when that integer does not fit in an int.
/// As with the "%d" conversion this replaces, leading whitespace and an
/// optional sign are accepted and characters after the digits are ignored
/// (so "12abc" yields 12).
void TableValToInt(const ArgsTable &table, const std::string &key, int &val) {
  const auto it = table.find(key);
  if (it == table.end()) {
    return;
  }
  const char *begin = it->second.c_str();
  char *end = nullptr;
  errno = 0;
  const long parsed = std::strtol(begin, &end, 10);
  // sscanf("%d") has undefined behaviour on out-of-range input; strtol lets
  // us detect "no digits" (end == begin) and overflow (ERANGE / int range).
  if (end == begin || errno == ERANGE || parsed < INT_MIN || parsed > INT_MAX) {
    return;
  }
  val = static_cast<int>(parsed);
}
/// Prints command-line usage to stdout.
///
/// Lists every mode main() accepts and the `-flag value` options it reads;
/// options must come after the data file. The defaults shown mirror the
/// values hard-coded in main(), so keep the two in sync.
void ShowHint() {
  printf(
      "Use rf train|test|print data_file [-flag value]...\n"
      "Modes:\n"
      "  train  build trees from data_file and save them to tree.bin\n"
      "  test   load tree.bin, test on data_file, save results to test_res.csv\n"
      "  print  load tree.bin and run TryTree() on every tree\n"
      "Options:\n"
      "  -p <int>            threading setting for RandomForest (default -1)\n"
      "  -v <int>            logger verbosity (default 1)\n"
      "  -d <int>            train: max tree depth (default 10)\n"
      "  -min-split <int>    train: min samples to split (default 2)\n"
      "  -c <int>            train: number of trees (default 100)\n"
      "  -sample-size <int>  train: sample size per tree (default 1000)\n");
}
// File names SaveAll() passes to the forest: serialized trees and test results.
constexpr char kTreeBinFile[]{"tree.bin"};
constexpr char kTestResFile[]{"test_res.csv"};
// The RandomForest main() is currently working with; nullptr until main()
// stores one here. SaveAll() (reached from the signal handler) reads it.
// NOTE(review): only lock-free atomics / volatile sig_atomic_t are guaranteed
// safe to share with a signal handler; this plain pointer is neither.
RandomForest *p_rf{nullptr};
/// Persists the forest registered in p_rf by calling SaveTest(kTestResFile)
/// followed by SaveTreesToFile(kTreeBinFile).
/// If no forest has been registered yet, only prints a notice.
void SaveAll() {
  if (p_rf == nullptr) {
    printf("No RandomForest instance yet..");
    return;
  }
  printf("Saving all the things...");
  p_rf->SaveTest(kTestResFile);
  p_rf->SaveTreesToFile(kTreeBinFile);
}
/// Signal handler that main() installs for SIGINT and for signal number 30.
///
/// SIGINT: saves via SaveAll() and lets the process keep running.
/// 30:     saves via SaveAll() and terminates with exit status 1.
/// Any other signal number is ignored.
///
/// NOTE(review): 30 is SIGUSR1 on macOS/BSD but SIGPWR on Linux x86/ARM;
/// confirm the intended platform. printf(), exit() and whatever SaveAll()
/// does (presumably file I/O) are not async-signal-safe, so this relies on
/// the signal arriving at a benign moment.
void HandleSignal(int sig) {
  switch (sig) {
    case SIGINT:
      SaveAll();
      printf("Saving done!\nUse `kill -30` to kill this process\n");
      break;
    case 30:
      SaveAll();
      exit(1);
    default:
      break;
  }
}
int main(int argc, char *args[]) {
if (argc <= 2) {
ShowHint();
return 1;
}
signal(SIGINT, HandleSignal);
signal(30, HandleSignal);
ArgsTable table;
std::string first, second;
for (int i = 3; i < argc; ++i) {
if (i % 2 == 1) {
first = args[i];
} else {
second = args[i];
table[first] = second;
}
}
// printf("ArgTable:\n");
// for (auto &pair : table) {
// printf("%s: %s\n", pair.first.c_str(), pair.second.c_str());
// }
std::string arg1 = args[1];
std::string arg2 = args[2];
DataReader reader(arg2);
int threading = -1;
TableValToInt(table, "-p", threading);
int verbose = 1;
TableValToInt(table, "-v", verbose);
Logger logger(verbose, false);
if (arg1 == "train") {
// TreeInfo
int max_depth = 10;
int min_samples_split = 2;
int tree_count = 100;
int one_sample_size = 1000;
TableValToInt(table, "-d", max_depth);
TableValToInt(table, "-min-split", min_samples_split);
TableValToInt(table, "-c", tree_count);
TableValToInt(table, "-sample-size", one_sample_size);
DecisionTreeInfo info;
info.max_depth = max_depth;
info.min_samples_split = min_samples_split;
RandomForest rf(201, reader.samples, threading, info, tree_count, one_sample_size, logger);
p_rf = &rf;
rf.CalcTrees();
rf.SaveTreesToFile("tree.bin");
} else if (arg1 == "test") {
RandomForest rf(201, reader.samples, threading, DecisionTreeInfo(), 100, 1000, logger);
p_rf = &rf;
rf.LoadTreesFromFile("tree.bin");
rf.TestAndSave("test_res.csv");
} else if (arg1 == "print") {
RandomForest rf(201, reader.samples, threading, DecisionTreeInfo(), 100, 1000, logger);
p_rf = &rf;
rf.LoadTreesFromFile("tree.bin");
auto &trees = rf.trees;
for (auto &tree : trees) {
tree.TryTree();
}
} else {
ShowHint();
}
return 0;
}