Skip to content

Commit

Permalink
add target_opset to test cases
Browse files Browse the repository at this point in the history
Signed-off-by: BowenBao <bowbao@microsoft.com>
  • Loading branch information
BowenBao committed Dec 6, 2021
1 parent 38e9371 commit badb290
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions tests/xgboost/test_xgboost_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,9 @@ def test_xgb_best_tree_limit(self):
initial_type = [('float_input', FloatTensorType([None, 4]))]
bst_original.save_model('model.json')

onx_loaded = convert_xgboost(bst_original, initial_types=initial_type)
onx_loaded = convert_xgboost(
bst_original, initial_types=initial_type,
target_opset=TARGET_OPSET)
sess = InferenceSession(onx_loaded.SerializeToString())
res = sess.run(None, {'float_input': X_test.astype(np.float32)})
assert_almost_equal(bst_original.predict(dtest, output_margin=True), res[1], decimal=5)
Expand All @@ -331,7 +333,9 @@ def test_xgb_best_tree_limit(self):
bst_original.predict(dtest, output_margin=True), decimal=5)
assert_almost_equal(bst_loaded.predict(dtest), bst_original.predict(dtest))

onx_loaded = convert_xgboost(bst_loaded, initial_types=initial_type)
onx_loaded = convert_xgboost(
bst_loaded, initial_types=initial_type,
target_opset=TARGET_OPSET)
sess = InferenceSession(onx_loaded.SerializeToString())
res = sess.run(None, {'float_input': X_test.astype(np.float32)})
assert_almost_equal(bst_loaded.predict(dtest, output_margin=True), res[1], decimal=5)
Expand Down

0 comments on commit badb290

Please sign in to comment.