Skip to content

Commit d35edcd

Browse files
committed
Create tests to check dispatch_table behavior
It seems that the fit method is not working properly when we have terminals of type ArrayXb. The sig_hash of the node is different from all of the available nodes in the dispatch_table, raising an error in dispatch_table.h:172. This commit introduces a simple test case to reproduce the error that I am getting. Ideally, We should fix this bug so this new test case does not get a core dump.
1 parent 0627058 commit d35edcd

File tree

2 files changed

+82
-1
lines changed

2 files changed

+82
-1
lines changed

src/variation.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -260,5 +260,5 @@ Program<T> cross(const Program<T>& root, const Program<T>& other)
260260

261261
return child;
262262
};
263-
} //namespace vary
263+
} //namespace variation
264264
#endif

tests/cpp/test_data.cpp

+81
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
#include "testsHeader.h"
2+
#include "../../src/search_space.h"
3+
#include "../../src/program/program.h"
4+
#include "../../src/program/dispatch_table.h"
5+
6+
TEST(Data, MixedVariableTypes)
7+
{
8+
// We need to set at least the mutation options (and respective
9+
// probabilities) in order to call PRG.predict()
10+
PARAMS["mutation_options"] = {
11+
{"point",0.25}, {"insert", 0.25}, {"delete", 0.25}, {"toggle_weight", 0.25}
12+
};
13+
14+
MatrixXf X(5,3);
15+
X << 0 , 1, 0 , // binary with integer values
16+
0.0, 1.0, 1.0, // binary with float values
17+
2 , 1.0, -3.0, // integer with float and negative values
18+
2 , 1 , 3 , // integer with integer values
19+
2.1, 3.7, -5.2; // float values
20+
21+
X.transposeInPlace();
22+
23+
ArrayXf y(5);
24+
25+
y << 6.1, 7.7, -4.2; // y = x_0 + x_1 + x_2
26+
27+
unordered_map<string, float> user_ops = {
28+
{"Add", 1},
29+
{"Sub", 1},
30+
{"SplitOn", 1}
31+
};
32+
33+
Dataset dt(X, y);
34+
SearchSpace SS;
35+
SS.init(dt, user_ops);
36+
37+
dt.print();
38+
SS.print();
39+
40+
for (int d = 1; d < 5; ++d)
41+
for (int s = 1; s < 5; ++s)
42+
{
43+
44+
PARAMS["max_size"] = s;
45+
PARAMS["max_depth"] = d;
46+
47+
RegressorProgram PRG = SS.make_regressor(d, s);
48+
fmt::print(
49+
"=================================================\n"
50+
"Tree model for depth = {}, size= {}: {}\n",
51+
d, s, PRG.get_model("compact", true)
52+
);
53+
54+
auto Child = PRG.mutate();
55+
fmt::print("Child model: {}\n", Child.get_model("compact", true));
56+
57+
std::for_each(PRG.Tree.begin(), PRG.Tree.end(),
58+
[](const auto& n) {
59+
fmt::print("Name {}, node {}, feature {}, sig_hash {}\n",
60+
n.name, n.node_type, n.get_feature(), n.sig_hash);
61+
});
62+
63+
std::cout << std::endl;
64+
65+
PRG.fit(dt);
66+
fmt::print( "PRG predict\n");
67+
ArrayXf y_pred = PRG.predict(dt);
68+
fmt::print( "y_pred: {}\n", y_pred);
69+
70+
Child.fit(dt);
71+
fmt::print( "Child predict\n");
72+
ArrayXf y_pred_child = Child.predict(dt);
73+
fmt::print( "y_pred: {}\n", y_pred);
74+
}
75+
76+
// Brush exports two DispatchTable structs named dtable_fit and dtable_predict.
77+
// These structures holds the mapping between nodes and its corresponding
78+
// operations, and are used to resolve the evaluation of an expression.
79+
// dtable_fit.print();
80+
// dtable_predict.print();
81+
}

0 commit comments

Comments
 (0)