From 4a2a332324fe165065ec51e30546ec012a1c2217 Mon Sep 17 00:00:00 2001 From: bammari Date: Mon, 17 Jun 2024 17:28:05 +0300 Subject: [PATCH 1/3] added lgbm reader --- src/omlt/linear_tree/__init__.py | 1 + src/omlt/linear_tree/gblt_model.py | 291 +++++++++++++++++++++++++++++ 2 files changed, 292 insertions(+) create mode 100644 src/omlt/linear_tree/gblt_model.py diff --git a/src/omlt/linear_tree/__init__.py b/src/omlt/linear_tree/__init__.py index 2f89a669..36de6bc6 100644 --- a/src/omlt/linear_tree/__init__.py +++ b/src/omlt/linear_tree/__init__.py @@ -23,3 +23,4 @@ LinearTreeGDPFormulation, LinearTreeHybridBigMFormulation, ) +from omlt.linear_tree.gblt_model import EnsembleDefinition diff --git a/src/omlt/linear_tree/gblt_model.py b/src/omlt/linear_tree/gblt_model.py new file mode 100644 index 00000000..c5f779ca --- /dev/null +++ b/src/omlt/linear_tree/gblt_model.py @@ -0,0 +1,291 @@ +import lightgbm as lgb +import numpy as np + + +class EnsembleDefinition: + def __init__(self, gblt_model, scaling_object=None, scaled_input_bounds=None): + """ + Create a network definition object used to create the gradient-boosted trees + formulation in Pyomo + + Args: + onnx_model : ONNX Model + An ONNX model that is generated by the ONNX convert function for + lightgbm. + scaling_object : ScalingInterface or None + A scaling object to specify the scaling parameters for the + tree ensemble inputs and outputs. If None, then no + scaling is performed. + scaled_input_bounds : dict or None + A dict that contains the bounds on the scaled variables (the + direct inputs to the tree ensemble). If None, then no bounds + are specified or they are generated using unscaled bounds. + """ + self._model = gblt_model + n_inputs = find_n_inputs(gblt_model) + self.n_inputs = n_inputs + self.n_outputs = 1 + self.splits, self.leaves, self.thresholds =\ + parse_model(gblt_model, scaled_input_bounds, n_inputs) + self.scaling_object = scaling_object + self.scaled_input_bounds = scaled_input_bounds + + +def find_all_children_splits(split, splits_dict): + """ + This helper function finds all multigeneration children splits for an + argument split. + + Arguments: + split --The split for which you are trying to find children splits + splits_dict -- A dictionary of all the splits in the tree + + Returns: + A list containing the Node IDs of all children splits + """ + all_splits = [] + + # Check if the immediate left child of the argument split is also a split. + # If so append to the list then use recursion to generate the remainder + left_child = splits_dict[split]['children'][0] + if left_child in splits_dict: + all_splits.append(left_child) + all_splits.extend(find_all_children_splits(left_child, splits_dict)) + + # Same as above but with right child + right_child = splits_dict[split]['children'][1] + if right_child in splits_dict: + all_splits.append(right_child) + all_splits.extend(find_all_children_splits(right_child, splits_dict)) + + return all_splits + +def _reassign_none_bounds(leaves, input_bounds, n_inputs): + """ + This helper function reassigns bounds that are None to the bounds + input by the user + + Arguments: + leaves -- The dictionary of leaf information. Attribute of the + LinearTreeDefinition object + input_bounds -- The nested dictionary + + Returns: + The modified leaves dict without any bounds that are listed as None + """ + leaf_indices = np.array(list(leaves.keys())) + features = np.arange(0, n_inputs) + + for leaf in leaf_indices: + slopes = leaves[leaf]["slope"] + if len(slopes) == 0: + leaves[leaf]["slope"] = list(np.zeros(n_inputs)) + for feat in features: + if leaves[leaf]["bounds"][feat][0] is None: + leaves[leaf]["bounds"][feat][0] = input_bounds[feat][0] + if leaves[leaf]["bounds"][feat][1] is None: + leaves[leaf]["bounds"][feat][1] = input_bounds[feat][1] + + return leaves + +def find_all_children_leaves(split, splits_dict, leaves_dict): + """ + This helper function finds all multigeneration children leaves for an + argument split. + + Arguments: + split -- The split for which you are trying to find children leaves + splits_dict -- A dictionary of all the split info in the tree + leaves_dict -- A dictionary of all the leaf info in the tree + + Returns: + A list containing all the Node IDs of all children leaves + """ + all_leaves = [] + + # Find all the splits that are children of the relevant split + all_splits = find_all_children_splits(split, splits_dict) + + # Ensure the current split is included + if split not in all_splits: + all_splits.append(split) + + # For each leaf, check if the parents appear in the list of children + # splits (all_splits). If so, it must be a leaf of the argument split + for leaf in leaves_dict: + if leaves_dict[leaf]['parent'] in all_splits: + all_leaves.append(leaf) + + return all_leaves + + +def find_n_inputs(model): + if str(type(model)) == "": + n_inputs = model.num_feature() + else: + n_inputs = len(model["feature_names"]) + + return n_inputs + +def parse_model(model, input_bounds, n_inputs): + if str(type(model)) == "": + whole_model = model.dump_model() + # import pprint as pp + # pp.pprint(whole_model) + else: + whole_model=model + + tree = {} + for i in range(whole_model['tree_info'][-1]['tree_index']+1): + + node = whole_model['tree_info'][i]["tree_structure"] + + queue = [node] + splits = {} + + # the very first node + splits["split"+str(queue[0]["split_index"])] = {'th': queue[0]["threshold"], + 'col': queue[0]["split_feature"] } + + # flow though the tree + while queue: + + # left child + if "left_child" in queue[0].keys(): + queue.append(queue[0]["left_child"]) + # child is a split + if "split_index" in queue[0]["left_child"].keys(): + splits["split"+str(queue[0]["left_child"]["split_index"])] = {'parent': "split"+str(queue[0]["split_index"]), + 'direction': 'left', + 'th': queue[0]["left_child"]["threshold"], + 'col': queue[0]["left_child"]["split_feature"]} + # child is a leaf + else: + splits["leaf"+str(queue[0]["left_child"]["leaf_index"])] = {'parent': "split"+str(queue[0]["split_index"]), + 'direction': 'left', + 'intercept': queue[0]["left_child"]["leaf_const"], + 'slope': list(np.zeros(n_inputs))} + + for idx, val in zip(queue[0]["left_child"]["leaf_features"], queue[0]["left_child"]["leaf_coeff"]): + splits["leaf"+str(queue[0]["left_child"]["leaf_index"])]["slope"][idx] = val + + # right child + if "right_child" in queue[0].keys(): + queue.append(queue[0]["right_child"]) + # child is a split + if "split_index" in queue[0]["right_child"].keys(): + splits["split"+str(queue[0]["right_child"]["split_index"])] = {'parent': "split"+str(queue[0]["split_index"]), + 'direction': 'right', + 'th': queue[0]["right_child"]["threshold"], + 'col': queue[0]["right_child"]["split_feature"]} + # child is a leaf + else: + splits["leaf"+str(queue[0]["right_child"]["leaf_index"])] = {'parent': "split"+str(queue[0]["split_index"]), + 'direction': 'right', + 'intercept': queue[0]["right_child"]["leaf_const"], + 'slope': list(np.zeros(n_inputs))} + + for idx, val in zip(queue[0]["right_child"]["leaf_features"], queue[0]["right_child"]["leaf_coeff"]): + splits["leaf"+str(queue[0]["right_child"]["leaf_index"])]["slope"][idx] = val + # delet the first node + queue.pop(0) + + tree['tree'+str(i)] = splits + + nested_splits = {} + nested_leaves = {} + nested_thresholds = {} + + for index in tree: + + splits = tree[index] + for i in splits: + # print(i) + if 'parent' in splits[i].keys(): + splits[splits[i]['parent']]['children'] = [] + + for i in splits: + # print(i) + if 'parent' in splits[i].keys(): + if splits[i]['direction'] == 'left': + splits[splits[i]['parent']]['children'].insert(0,i) + if splits[i]['direction'] == 'right': + splits[splits[i]['parent']]['children'].insert(11,i) + + leaves = {} + for i in splits.keys(): + if i[0] == 'l': + leaves[i] = splits[i] + + for leaf in leaves: + del splits[leaf] + + for split in splits: + # print("split:" + str(split)) + left_child = splits[split]['children'][0] + right_child = splits[split]['children'][1] + + if left_child in splits: + # means left_child is split + splits[split]['left_leaves'] = find_all_children_leaves( + left_child, splits, leaves + ) + else: + # means left_child is leaf + splits[split]['left_leaves'] = [left_child] + # print("left_child" + str(left_child)) + + if right_child in splits: + splits[split]['right_leaves'] = find_all_children_leaves( + right_child, splits, leaves + ) + else: + splits[split]['right_leaves'] = [right_child] + # print("right_child" + str(right_child)) + + splitting_thresholds = {} + for split in splits: + var = splits[split]['col'] + splitting_thresholds[var] = {} + for split in splits: + var = splits[split]['col'] + splitting_thresholds[var][split] = splits[split]['th'] + + for var in splitting_thresholds: + splitting_thresholds[var] = dict(sorted(splitting_thresholds[var].items(), key=lambda x: x[1])) + + for split in splits: + var = splits[split]['col'] + splits[split]['y_index'] = [] + splits[split]['y_index'].append(var) + splits[split]['y_index'].append( + list(splitting_thresholds[var]).index(split) + ) + + features = np.arange(0,n_inputs) + + for leaf in leaves: + leaves[leaf]['bounds'] = {} + for th in features: + for leaf in leaves: + leaves[leaf]['bounds'][th] = [None, None] + + # import pprint + # pp = pprint.PrettyPrinter(indent=4) + # pp.pprint(splits) + # pp.pprint(leaves) + for split in splits: + var = splits[split]['col'] + for leaf in splits[split]['left_leaves']: + leaves[leaf]['bounds'][var][1] = splits[split]['th'] + + for leaf in splits[split]['right_leaves']: + leaves[leaf]['bounds'][var][0] = splits[split]['th'] + + leaves = _reassign_none_bounds(leaves, input_bounds, n_inputs) + + nested_splits[str(index)] = splits + nested_leaves[str(index)] = leaves + nested_thresholds[str(index)] = splitting_thresholds + + return nested_splits, nested_leaves, nested_thresholds \ No newline at end of file From 1e02245e7d74633d140be64946a83cc287557683 Mon Sep 17 00:00:00 2001 From: bammari Date: Wed, 19 Jun 2024 17:23:19 +0300 Subject: [PATCH 2/3] Updating code for repository --- src/omlt/linear_tree/gblt_model.py | 107 +++++++++++++++++++++++------ 1 file changed, 87 insertions(+), 20 deletions(-) diff --git a/src/omlt/linear_tree/gblt_model.py b/src/omlt/linear_tree/gblt_model.py index c5f779ca..194362e0 100644 --- a/src/omlt/linear_tree/gblt_model.py +++ b/src/omlt/linear_tree/gblt_model.py @@ -3,7 +3,13 @@ class EnsembleDefinition: - def __init__(self, gblt_model, scaling_object=None, scaled_input_bounds=None): + def __init__( + self, + gblt_model, + scaling_object=None, + unscaled_input_bounds=None, + scaled_input_bounds=None + ): """ Create a network definition object used to create the gradient-boosted trees formulation in Pyomo @@ -21,17 +27,78 @@ def __init__(self, gblt_model, scaling_object=None, scaled_input_bounds=None): direct inputs to the tree ensemble). If None, then no bounds are specified or they are generated using unscaled bounds. """ - self._model = gblt_model - n_inputs = find_n_inputs(gblt_model) - self.n_inputs = n_inputs - self.n_outputs = 1 - self.splits, self.leaves, self.thresholds =\ - parse_model(gblt_model, scaled_input_bounds, n_inputs) - self.scaling_object = scaling_object - self.scaled_input_bounds = scaled_input_bounds - - -def find_all_children_splits(split, splits_dict): + self.__model = gblt_model + self.__scaling_object = scaling_object + + # Process input bounds to insure scaled input bounds exist for formulations + if scaled_input_bounds is None: + if unscaled_input_bounds is not None and scaling_object is not None: + lbs = scaling_object.get_scaled_input_expressions( + {k: t[0] for k, t in unscaled_input_bounds.items()} + ) + ubs = scaling_object.get_scaled_input_expressions( + {k: t[1] for k, t in unscaled_input_bounds.items()} + ) + + scaled_input_bounds = { + k: (lbs[k], ubs[k]) for k in unscaled_input_bounds.keys() + } + + # If unscaled input bounds provided and no scaler provided, scaled + # input bounds = unscaled input bounds + elif unscaled_input_bounds is not None and scaling_object is None: + scaled_input_bounds = unscaled_input_bounds + elif unscaled_input_bounds is None: + raise ValueError( + "Input Bounds needed to represent linear trees as MIPs" + ) + + self.__unscaled_input_bounds = unscaled_input_bounds + self.__scaled_input_bounds = scaled_input_bounds + + n_inputs = _find_n_inputs(gblt_model) + self.__n_inputs = n_inputs + self.__n_outputs = 1 + self.__splits, self.__leaves, self.__thresholds =\ + _parse_model(gblt_model, scaled_input_bounds, n_inputs) + + @property + def scaling_object(self): + """Returns scaling object""" + return self.__scaling_object + + @property + def scaled_input_bounds(self): + """Returns dict containing scaled input bounds""" + return self.__scaled_input_bounds + + @property + def splits(self): + """Returns dict containing split information""" + return self.__splits + + @property + def leaves(self): + """Returns dict containing leaf information""" + return self.__leaves + + @property + def thresholds(self): + """Returns dict containing threshold information""" + return self.__thresholds + + @property + def n_inputs(self): + """Returns number of inputs to the linear tree""" + return self.__n_inputs + + @property + def n_outputs(self): + """Returns number of outputs to the linear tree""" + return self.__n_outputs + + +def _find_all_children_splits(split, splits_dict): """ This helper function finds all multigeneration children splits for an argument split. @@ -50,13 +117,13 @@ def find_all_children_splits(split, splits_dict): left_child = splits_dict[split]['children'][0] if left_child in splits_dict: all_splits.append(left_child) - all_splits.extend(find_all_children_splits(left_child, splits_dict)) + all_splits.extend(_find_all_children_splits(left_child, splits_dict)) # Same as above but with right child right_child = splits_dict[split]['children'][1] if right_child in splits_dict: all_splits.append(right_child) - all_splits.extend(find_all_children_splits(right_child, splits_dict)) + all_splits.extend(_find_all_children_splits(right_child, splits_dict)) return all_splits @@ -88,7 +155,7 @@ def _reassign_none_bounds(leaves, input_bounds, n_inputs): return leaves -def find_all_children_leaves(split, splits_dict, leaves_dict): +def _find_all_children_leaves(split, splits_dict, leaves_dict): """ This helper function finds all multigeneration children leaves for an argument split. @@ -104,7 +171,7 @@ def find_all_children_leaves(split, splits_dict, leaves_dict): all_leaves = [] # Find all the splits that are children of the relevant split - all_splits = find_all_children_splits(split, splits_dict) + all_splits = _find_all_children_splits(split, splits_dict) # Ensure the current split is included if split not in all_splits: @@ -119,7 +186,7 @@ def find_all_children_leaves(split, splits_dict, leaves_dict): return all_leaves -def find_n_inputs(model): +def _find_n_inputs(model): if str(type(model)) == "": n_inputs = model.num_feature() else: @@ -127,7 +194,7 @@ def find_n_inputs(model): return n_inputs -def parse_model(model, input_bounds, n_inputs): +def _parse_model(model, input_bounds, n_inputs): if str(type(model)) == "": whole_model = model.dump_model() # import pprint as pp @@ -227,7 +294,7 @@ def parse_model(model, input_bounds, n_inputs): if left_child in splits: # means left_child is split - splits[split]['left_leaves'] = find_all_children_leaves( + splits[split]['left_leaves'] = _find_all_children_leaves( left_child, splits, leaves ) else: @@ -236,7 +303,7 @@ def parse_model(model, input_bounds, n_inputs): # print("left_child" + str(left_child)) if right_child in splits: - splits[split]['right_leaves'] = find_all_children_leaves( + splits[split]['right_leaves'] = _find_all_children_leaves( right_child, splits, leaves ) else: From c555d62b1bf5f5f3a22decd2d63db897c74d1dcf Mon Sep 17 00:00:00 2001 From: bammari Date: Wed, 19 Jun 2024 17:24:58 +0300 Subject: [PATCH 3/3] Update docs --- src/omlt/linear_tree/gblt_model.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/omlt/linear_tree/gblt_model.py b/src/omlt/linear_tree/gblt_model.py index 194362e0..6a4924e0 100644 --- a/src/omlt/linear_tree/gblt_model.py +++ b/src/omlt/linear_tree/gblt_model.py @@ -197,8 +197,6 @@ def _find_n_inputs(model): def _parse_model(model, input_bounds, n_inputs): if str(type(model)) == "": whole_model = model.dump_model() - # import pprint as pp - # pp.pprint(whole_model) else: whole_model=model