From 0fdd7253b59ffa73a8e6022ad851b75b3d26602b Mon Sep 17 00:00:00 2001 From: Francesco Salvetti Date: Tue, 29 Nov 2022 09:41:54 -0500 Subject: [PATCH 1/2] Add support to tf.strings.reduce_join --- tf2onnx/custom_opsets/string_ops.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/tf2onnx/custom_opsets/string_ops.py b/tf2onnx/custom_opsets/string_ops.py index 303fcd94b..2c5767e33 100644 --- a/tf2onnx/custom_opsets/string_ops.py +++ b/tf2onnx/custom_opsets/string_ops.py @@ -5,6 +5,7 @@ import json import logging import numpy as np +from onnx.numpy_helper import to_array from onnx.onnx_pb import TensorProto from onnx.helper import make_attribute from tf2onnx import constants, handler @@ -86,6 +87,32 @@ def version_1(cls, ctx, node, **kwargs): stack_node = ctx.make_node("Concat", unsqueezes, attr={'axis': 0}) ctx.replace_inputs(node, [stack_node.output[0], separator_node.output[0], axis_node.output[0]]) +@tf_op("ReduceJoin", domain=constants.CONTRIB_OPS_DOMAIN) +class ReduceJoin: + @classmethod + def version_1(cls, ctx, node, **kwargs): + node.domain = constants.CONTRIB_OPS_DOMAIN + node.type = "StringJoin" + + axis_node = ctx.get_node_by_output(node.input[1]) + axis = axis_node.get_attr_value('value') + if axis.dims not in [[], [1]]: + raise TypeError("Onnx ReduceJoin operation supports a single axis, only.") + axis = to_array(axis) + new_axis_node = ctx.make_const(utils.make_name("axis"), np.array(axis, np.int64).reshape((1))) + + separator = node.get_attr_value("separator") + if isinstance(separator, bytes): + separator = separator.decode() + separator_node = ctx.make_const(utils.make_name("separator"), np.array([separator], object)) + + ctx.replace_inputs(node, [node.input[0], separator_node.output[0], new_axis_node.output[0]]) + + keep_dims = node.get_attr_value("keep_dims") + if keep_dims: + unsqueeze_node = GraphBuilder(ctx).make_unsqueeze({'data': node.output[0], 'axes': [-1]}, name=node.name + '/Unsqueeze') + ctx.insert_node_on_output(ctx.get_node_by_output(unsqueeze_node)) + @tf_op(["Equal", "NotEqual"], domain=constants.CONTRIB_OPS_DOMAIN) class StringEqual: @classmethod From 8612b1ac5fb94fa43e3d75288519e153430df6da Mon Sep 17 00:00:00 2001 From: Francesco Salvetti Date: Tue, 29 Nov 2022 10:05:39 -0500 Subject: [PATCH 2/2] Support for ReduceJoin op --- tf2onnx/custom_opsets/string_ops.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tf2onnx/custom_opsets/string_ops.py b/tf2onnx/custom_opsets/string_ops.py index 2c5767e33..bd3b27f48 100644 --- a/tf2onnx/custom_opsets/string_ops.py +++ b/tf2onnx/custom_opsets/string_ops.py @@ -96,8 +96,7 @@ def version_1(cls, ctx, node, **kwargs): axis_node = ctx.get_node_by_output(node.input[1]) axis = axis_node.get_attr_value('value') - if axis.dims not in [[], [1]]: - raise TypeError("Onnx ReduceJoin operation supports a single axis, only.") + utils.make_sure(axis.dims in [[], [1]], "Only a single axis is supported for ReduceJoin node") axis = to_array(axis) new_axis_node = ctx.make_const(utils.make_name("axis"), np.array(axis, np.int64).reshape((1)))