Skip to content

Commit

Permalink
Fix discovering classifier objective (#480)
Browse files Browse the repository at this point in the history
* Fix scoring classifier objective
  • Loading branch information
xadupre authored Aug 19, 2021
1 parent d981dac commit c1651b1
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions onnxmltools/convert/xgboost/_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,16 @@ def _get_attributes(booster):
reg = re.compile(b'(multi:[a-z]{1,15})')
objs = list(set(reg.findall(bstate)))
if len(objs) != 1:
raise RuntimeError(
"Unable to guess objective in {}.".format(objs))
kwargs['num_class'] = trees // ntrees
kwargs["objective"] = objs[0].decode('ascii')
if '"name":"binary:logistic"' in str(bstate):
kwargs['num_class'] = 1
kwargs["objective"] = "binary:logistic"
else:
raise RuntimeError(
"Unable to guess objective in %r (trees=%r, ntrees=%r)"
"." % (objs, trees, ntrees))
else:
kwargs['num_class'] = trees // ntrees
kwargs["objective"] = objs[0].decode('ascii')
else:
kwargs['num_class'] = 1
kwargs["objective"] = "binary:logistic"
Expand Down

0 comments on commit c1651b1

Please sign in to comment.