Skip to content

Commit

Permalink
update with main branch
Browse files Browse the repository at this point in the history
Signed-off-by: Xavier Dupre <xadupre@microsoft.com>
  • Loading branch information
xadupre committed Sep 14, 2023
1 parent a281f01 commit d45ad7a
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 18 deletions.
5 changes: 0 additions & 5 deletions skl2onnx/algebra/automation.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,11 +152,6 @@ def format_name_with_domain(sch):
return "{} ({})".format(sch.name, sch.domain)
return sch.name

def get_type_str(obj):
if hasattr(obj, "type_str"):
return obj.type_str
return obj.typeStr

def get_is_homogeneous(obj):
try:
return obj.is_homogeneous
Expand Down
2 changes: 1 addition & 1 deletion skl2onnx/operator_converters/one_hot_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def convert_sklearn_one_hot_encoder(
if afeat:
index_name = scope.get_unique_variable_name(name + str(index_in))
container.add_initializer(
index_name, onnx_proto.TensorProto.INT64, [], [index_in]
index_name, onnx_proto.TensorProto.INT64, [1], [index_in]
)
out_name = scope.get_unique_variable_name(name + str(index_in))
container.add_node(
Expand Down
26 changes: 14 additions & 12 deletions tests/test_issue_shape_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,20 +89,22 @@ def test_shape_inference(self):
)

# ReferenceEvaluator
ref = ReferenceEvaluator(model_onnx, verbose=9)
res = ref.run(None, feeds)
self.assertEqual(1, len(res))
self.assertEqual(expected.shape, res[0].shape)
assert_almost_equal(expected, res[0])
with self.subTest(engine="onnx"):
ref = ReferenceEvaluator(model_onnx, verbose=9)
res = ref.run(None, feeds)
self.assertEqual(1, len(res))
self.assertEqual(expected.shape, res[0].shape)
assert_almost_equal(expected, res[0])

# onnxruntime
sess = rt.InferenceSession(
model_onnx.SerializeToString(), providers=["CPUExecutionProvider"]
)
res = sess.run(None, feeds)
self.assertEqual(1, len(res))
self.assertEqual(expected.shape, res[0].shape)
assert_almost_equal(expected, res[0])
with self.subTest(engine="onnxruntime"):
sess = rt.InferenceSession(
model_onnx.SerializeToString(), providers=["CPUExecutionProvider"]
)
res = sess.run(None, feeds)
self.assertEqual(1, len(res))
self.assertEqual(expected.shape, res[0].shape)
assert_almost_equal(expected, res[0])


if __name__ == "__main__":
Expand Down

0 comments on commit d45ad7a

Please sign in to comment.