From d45ad7ac5543f0e6490745b10bde4857d7f55f93 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Thu, 14 Sep 2023 17:52:30 +0200 Subject: [PATCH] update with main branch Signed-off-by: Xavier Dupre --- skl2onnx/algebra/automation.py | 5 ---- .../operator_converters/one_hot_encoder.py | 2 +- tests/test_issue_shape_inference.py | 26 ++++++++++--------- 3 files changed, 15 insertions(+), 18 deletions(-) diff --git a/skl2onnx/algebra/automation.py b/skl2onnx/algebra/automation.py index c667884ef..6f8c47a74 100644 --- a/skl2onnx/algebra/automation.py +++ b/skl2onnx/algebra/automation.py @@ -152,11 +152,6 @@ def format_name_with_domain(sch): return "{} ({})".format(sch.name, sch.domain) return sch.name - def get_type_str(obj): - if hasattr(obj, "type_str"): - return obj.type_str - return obj.typeStr - def get_is_homogeneous(obj): try: return obj.is_homogeneous diff --git a/skl2onnx/operator_converters/one_hot_encoder.py b/skl2onnx/operator_converters/one_hot_encoder.py index f57f4360e..6ddead657 100644 --- a/skl2onnx/operator_converters/one_hot_encoder.py +++ b/skl2onnx/operator_converters/one_hot_encoder.py @@ -97,7 +97,7 @@ def convert_sklearn_one_hot_encoder( if afeat: index_name = scope.get_unique_variable_name(name + str(index_in)) container.add_initializer( - index_name, onnx_proto.TensorProto.INT64, [], [index_in] + index_name, onnx_proto.TensorProto.INT64, [1], [index_in] ) out_name = scope.get_unique_variable_name(name + str(index_in)) container.add_node( diff --git a/tests/test_issue_shape_inference.py b/tests/test_issue_shape_inference.py index 1af8d5093..6fe22c992 100644 --- a/tests/test_issue_shape_inference.py +++ b/tests/test_issue_shape_inference.py @@ -89,20 +89,22 @@ def test_shape_inference(self): ) # ReferenceEvaluator - ref = ReferenceEvaluator(model_onnx, verbose=9) - res = ref.run(None, feeds) - self.assertEqual(1, len(res)) - self.assertEqual(expected.shape, res[0].shape) - assert_almost_equal(expected, res[0]) + with self.subTest(engine="onnx"): + ref = ReferenceEvaluator(model_onnx, verbose=9) + res = ref.run(None, feeds) + self.assertEqual(1, len(res)) + self.assertEqual(expected.shape, res[0].shape) + assert_almost_equal(expected, res[0]) # onnxruntime - sess = rt.InferenceSession( - model_onnx.SerializeToString(), providers=["CPUExecutionProvider"] - ) - res = sess.run(None, feeds) - self.assertEqual(1, len(res)) - self.assertEqual(expected.shape, res[0].shape) - assert_almost_equal(expected, res[0]) + with self.subTest(engine="onnxruntime"): + sess = rt.InferenceSession( + model_onnx.SerializeToString(), providers=["CPUExecutionProvider"] + ) + res = sess.run(None, feeds) + self.assertEqual(1, len(res)) + self.assertEqual(expected.shape, res[0].shape) + assert_almost_equal(expected, res[0]) if __name__ == "__main__":