Skip to content

Commit

Permalink
Infer num_class and n_estimators through tree_info
Browse files Browse the repository at this point in the history
Signed-off-by: BowenBao <bowbao@microsoft.com>
  • Loading branch information
BowenBao committed Dec 6, 2021
1 parent 9829d7d commit 38e9371
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 28 deletions.
41 changes: 19 additions & 22 deletions onnxmltools/convert/xgboost/_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,23 @@ def _append_covers(node):


def _get_attributes(booster):
# num_class
state = booster.__getstate__()
bstate = bytes(state['handle'])
reg = re.compile(b'("tree_info":\\[[0-9,]*\\])')
objs = list(set(reg.findall(bstate)))
assert len(objs) == 1, 'Missing required property "tree_info".'
tree_info = json.loads("{{{}}}".format(objs[0].decode('ascii')))['tree_info']
num_class = len(set(tree_info))

atts = booster.attributes()
dp = booster.get_dump(dump_format='json', with_stats=True)
res = [json.loads(d) for d in dp]
trees = len(res)
try:
ntrees = booster.best_ntree_limit
except AttributeError:
ntrees = trees
ntrees = trees // num_class if num_class > 0 else trees
kwargs = atts.copy()
kwargs['feature_names'] = booster.feature_names
kwargs['n_estimators'] = ntrees
Expand All @@ -46,34 +55,22 @@ def _get_attributes(booster):

if all(map(lambda x: int(x) == x, set(covs))):
# regression
kwargs['num_target'] = num_class
kwargs['num_class'] = 0
if trees > ntrees > 0:
kwargs['num_target'] = trees // ntrees
kwargs["objective"] = "reg:squarederror"
else:
kwargs['num_target'] = 1
kwargs["objective"] = "reg:squarederror"
kwargs["objective"] = "reg:squarederror"
else:
# classification
kwargs['num_target'] = 0
if trees > ntrees > 0:
state = booster.__getstate__()
bstate = bytes(state['handle'])
kwargs['num_class'] = num_class
if num_class != 1:
reg = re.compile(b'(multi:[a-z]{1,15})')
objs = list(set(reg.findall(bstate)))
if len(objs) != 1:
if '"name":"binary:logistic"' in str(bstate):
kwargs['num_class'] = 1
kwargs["objective"] = "binary:logistic"
else:
raise RuntimeError(
"Unable to guess objective in %r (trees=%r, ntrees=%r)"
"." % (objs, trees, ntrees))
else:
kwargs['num_class'] = trees // ntrees
if len(objs) == 1:
kwargs["objective"] = objs[0].decode('ascii')
else:
raise RuntimeError(
"Unable to guess objective in %r (trees=%r, ntrees=%r, num_class=%r)"
"." % (objs, trees, ntrees, kwargs['num_class']))
else:
kwargs['num_class'] = 1
kwargs["objective"] = "binary:logistic"

if 'base_score' not in kwargs:
Expand Down
8 changes: 2 additions & 6 deletions tests/xgboost/test_xgboost_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def test_xgboost_booster_classifier_multiclass_softmax(self):
random_state=42, n_informative=3)
x_train, x_test, y_train, _ = train_test_split(x, y, test_size=0.5,
random_state=42)

data = DMatrix(x_train, label=y_train)
model = train({'objective': 'multi:softmax',
'n_estimators': 3, 'min_child_samples': 1,
Expand Down Expand Up @@ -317,15 +317,13 @@ def test_xgb_best_tree_limit(self):
bst_original.save_model('model.json')

onx_loaded = convert_xgboost(bst_original, initial_types=initial_type)
# with open("model.onnx", "wb") as f:
# f.write(onx_loaded.SerializeToString())
sess = InferenceSession(onx_loaded.SerializeToString())
res = sess.run(None, {'float_input': X_test.astype(np.float32)})
assert_almost_equal(bst_original.predict(dtest, output_margin=True), res[1], decimal=5)
assert_almost_equal(bst_original.predict(dtest), res[0])

# After being restored, the loaded booster is not exactly the same
# in memory and the conversion fails to find the objective.
# in memory. `best_ntree_limit` is not saved during `save_model`.
bst_loaded = Booster()
bst_loaded.load_model('model.json')
bst_loaded.save_model('model2.json')
Expand All @@ -334,8 +332,6 @@ def test_xgb_best_tree_limit(self):
assert_almost_equal(bst_loaded.predict(dtest), bst_original.predict(dtest))

onx_loaded = convert_xgboost(bst_loaded, initial_types=initial_type)
# with open("model2.onnx", "wb") as f:
# f.write(onx_loaded.SerializeToString())
sess = InferenceSession(onx_loaded.SerializeToString())
res = sess.run(None, {'float_input': X_test.astype(np.float32)})
assert_almost_equal(bst_loaded.predict(dtest, output_margin=True), res[1], decimal=5)
Expand Down

0 comments on commit 38e9371

Please sign in to comment.