-
Notifications
You must be signed in to change notification settings - Fork 3
/
model_utils.ml
166 lines (141 loc) · 5.63 KB
/
model_utils.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
open Model_t
open Dog_t
(* Convert a feature description from its [Dog_t] representation to
   the corresponding [Model_t] representation.  Only the id, the
   optional name and (for categorical features) the categories and the
   anonymous category are carried over; every other field of the
   source record is ignored. *)
let feature_d_to_m feature =
  match feature with
  | `Ord { o_feature_id = id; o_feature_name_opt = name_opt; _ } ->
    `OrdinalFeature {
      of_feature_id = id;
      of_feature_name_opt = name_opt;
    }
  | `Cat {
      c_feature_id = id;
      c_feature_name_opt = name_opt;
      c_categories = categories;
      c_anonymous_category = anonymous;
      _
    } ->
    `CategoricalFeature {
      cf_feature_id = id;
      cf_feature_name_opt = name_opt;
      cf_categories = categories;
      cf_anonymous_category_index_opt = anonymous;
    }
(* Turn a list of breakpoints into a float array.  Integer breakpoints
   are converted with [float_of_int]: the model makes no distinction
   between the two, unlike the dog file, where the distinction has
   implications on compression.  An array is returned so that the
   breakpoint value corresponding to a split (index) can be fetched
   directly.  Element order is preserved. *)
let float_array_of_breakpoints breakpoints =
  match breakpoints with
  | `Float floats -> Array.of_list floats
  | `Int ints -> Array.map float_of_int (Array.of_list ints)
(* Build a map from the feature id of every ordinal feature in
   [id_to_feature] to its breakpoints, as a float array.  Categorical
   features contribute no entry. *)
let id_to_breakpoints id_to_feature =
  let add_if_ordinal feature_id feature acc =
    match feature with
    | `Cat _ -> acc
    | `Ord { o_breakpoints; _ } ->
      let breakpoints = float_array_of_breakpoints o_breakpoints in
      Utils.IntMap.add feature_id breakpoints acc
  in
  Utils.IntMap.fold add_if_ordinal id_to_feature Utils.IntMap.empty
(* Run-length encode an array of directions.  The array is handed (as
   a list) to [Rle.encode_dense]; only the third component of its
   result, a list of pairs whose second component is a direction, is
   used.  The direction of the first pair becomes
   [dr_first_direction] and the first components of all pairs, in
   order, become [dr_run_lengths] (presumably the run lengths --
   confirm against [Rle]).

   Fails with [assert false] when the encoded list is empty or when
   its first direction is neither [`Left] nor [`Right]; the original
   author notes that there must be at least two directions. *)
let rle_of_category_array directions =
  let _, _, runs = Rle.encode_dense (Array.to_list directions) in
  let first_direction =
    match runs with
    | (_, `Right) :: _ -> `Right
    | (_, `Left ) :: _ -> `Left
    | _ -> assert false (* must have at least two direction! *)
  in
  (* [rev_map] followed by [rev] keeps the original order *)
  let run_lengths =
    List.rev (List.rev_map (fun (run_length, _) -> run_length) runs)
  in
  { dr_first_direction = first_direction; dr_run_lengths = run_lengths }
(* Flip a direction: [`Left] becomes [`Right] and vice versa. *)
let opposite_direction direction =
  match direction with
  | `Left -> `Right
  | `Right -> `Left
(* Expand a run-length encoded sequence of directions into an array.
   Each run length is first paired with its direction: the first run
   gets [dr_first_direction] and the direction alternates from one run
   to the next.  The pairs are accumulated in reverse order and passed
   to [Rle.decode_rev]; the second component of its result is
   converted to an array (presumably the decoded directions in their
   original order -- confirm against [Rle]). *)
let category_array_of_rle { dr_first_direction; dr_run_lengths } =
  (* attach a direction to every run length, last run first *)
  let rec tag_runs direction tagged_rev = function
    | [] -> tagged_rev
    | run_length :: rest ->
      tag_runs
        (opposite_direction direction)
        ((run_length, direction) :: tagged_rev)
        rest
  in
  let runs_rev = tag_runs dr_first_direction [] dr_run_lengths in
  let _, directions = Rle.decode_rev runs_rev in
  Array.of_list directions
(* Rewrite a tree, node by node:
   - in an ordinal node, the split, an index into the breakpoints of
     the node's feature, is replaced by the breakpoint value found at
     that index;
   - in a categorical node, the array of category directions is
     replaced by its run-length encoding;
   - leaves are returned as they are.

   Raises whatever [Utils.IntMap.find] raises when an ordinal
   feature id has no entry in [id_to_breakpoints], and
   [Invalid_argument] when a split index is outside the breakpoint
   array. *)
let rec tree_l_to_c id_to_breakpoints tree =
  let convert = tree_l_to_c id_to_breakpoints in
  match tree with
  | `Leaf _ as leaf -> leaf

  | `OrdinalNode {
      on_feature_id;
      on_split = split_index;
      on_left_tree = left;
      on_right_tree = right
    } ->
    (* resolve the split before descending into the subtrees *)
    let breakpoints = Utils.IntMap.find on_feature_id id_to_breakpoints in
    let split_value = breakpoints.(split_index) in
    let left = convert left in
    let right = convert right in
    `OrdinalNode {
      on_feature_id;
      on_split = split_value;
      on_left_tree = left;
      on_right_tree = right
    }

  | `CategoricalNode {
      cn_feature_id;
      cn_category_directions = directions;
      cn_left_tree = left;
      cn_right_tree = right
    } ->
    let left = convert left in
    let right = convert right in
    let directions_rle = rle_of_category_array directions in
    `CategoricalNode {
      cn_feature_id;
      cn_category_directions = directions_rle;
      cn_left_tree = left;
      cn_right_tree = right
    }
(* Add to [map] a binding from feature id to feature for every
   feature referenced by a node of the tree.  Features are looked up
   in [feature_set] with [Feat_map.i_find_by_id].  A node's own
   feature is added first, then the features of its left subtree, then
   those of its right subtree.  Leaves add nothing. *)
let rec add_features_of_tree feature_set map tree =
  (* shared handling of both kinds of internal node *)
  let add_node feature_id left right =
    let feature = Feat_map.i_find_by_id feature_set feature_id in
    let with_node = Utils.IntMap.add feature_id feature map in
    let with_left = add_features_of_tree feature_set with_node left in
    add_features_of_tree feature_set with_left right
  in
  match tree with
  | `Leaf _ -> map
  | `OrdinalNode { on_feature_id; on_left_tree; on_right_tree; _ } ->
    add_node on_feature_id on_left_tree on_right_tree
  | `CategoricalNode { cn_feature_id; cn_left_tree; cn_right_tree; _ } ->
    add_node cn_feature_id cn_left_tree cn_right_tree
(* As a performance optimization, create a map containing only the
   features referenced by the trees; this is presumably a (much)
   smaller map than the (misnamed) [feature_set]. *)
let id_to_feature feature_set trees =
  List.fold_left (add_features_of_tree feature_set) Utils.IntMap.empty trees
(* Convert a list of trees with [tree_l_to_c] and collect the features
   they reference.  Returns the pair [(trees, features)]:
   - [trees] are the converted trees, in the same order as the input;
   - [features] holds one [Model_t] feature per feature referenced by
     at least one tree.  Each feature visited by [Utils.IntMap.fold]
     is consed onto the result, so the list is in the reverse of the
     fold's visiting order. *)
let l_to_c feature_set trees =
  let referenced = id_to_feature feature_set trees in
  let features =
    Utils.IntMap.fold
      (fun _feature_id feature acc -> feature_d_to_m feature :: acc)
      referenced
      []
  in
  let breakpoints_by_id = id_to_breakpoints referenced in
  let converted = List.map (tree_l_to_c breakpoints_by_id) trees in
  (converted, features)
(* Rewrite a tree so that the run-length encoded category directions
   of every categorical node are expanded into an array (see
   [category_array_of_rle]).  Ordinal nodes keep their feature id and
   split and only have their subtrees rewritten; leaves are returned
   as they are. *)
let rec tree_rle_to_array tree =
  match tree with
  | `OrdinalNode {
      on_feature_id;
      on_split;
      on_left_tree = left;
      on_right_tree = right
    } ->
    let left = tree_rle_to_array left in
    let right = tree_rle_to_array right in
    `OrdinalNode {
      on_feature_id;
      on_split;
      on_left_tree = left;
      on_right_tree = right
    }

  | `CategoricalNode {
      cn_feature_id;
      cn_category_directions = directions_rle;
      cn_left_tree = left;
      cn_right_tree = right
    } ->
    (* expand the directions before descending into the subtrees *)
    let directions = category_array_of_rle directions_rle in
    let left = tree_rle_to_array left in
    let right = tree_rle_to_array right in
    `CategoricalNode {
      cn_feature_id;
      cn_category_directions = directions;
      cn_left_tree = left;
      cn_right_tree = right
    }

  | `Leaf _ as leaf -> leaf
(* Apply [tree_rle_to_array] to every tree of the list.  The trees are
   converted from first to last, and the resulting list is in the
   REVERSE order of the input.
   NOTE(review): the other conversions in this file preserve order;
   confirm that callers do not depend on tree order here. *)
let rle_to_array trees =
  List.fold_left
    (fun converted_rev tree -> tree_rle_to_array tree :: converted_rev)
    []
    trees