Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding tf.strings.reduce_join mapping #2091

Merged
merged 13 commits into from
Mar 16, 2023
8 changes: 8 additions & 0 deletions tests/test_string_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,14 @@ def func(text1, text2, text3):
return tf.identity(x_, name=_TFOUTPUT)
self._run_test_case(func, [_OUTPUT], {_INPUT: text_val1, _INPUT1: text_val2, _INPUT2: text_val3})

@requires_custom_ops("ReduceJoin")
def test_reduce_join(self):
text_val = np.array([["a", "Test 1 2 3"], ["b", "test test"], ["c", "Hi there Test"]], dtype=np.str)
def func(text):
x_ = tf.strings.reduce_join(text, axis=1, separator="±")
return tf.identity(x_, name=_TFOUTPUT)
self._run_test_case(func, [_OUTPUT], {_INPUT: text_val})

@requires_custom_ops("StringSplit")
@check_tf_min_version("2.0", "result is sparse not ragged in tf1")
def test_string_split(self):
Expand Down
25 changes: 25 additions & 0 deletions tf2onnx/custom_opsets/string_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import json
import logging
import numpy as np
from onnx.numpy_helper import to_array
from onnx.onnx_pb import TensorProto
from onnx.helper import make_attribute
from tf2onnx import constants, handler
Expand Down Expand Up @@ -86,6 +87,30 @@ def version_1(cls, ctx, node, **kwargs):
stack_node = ctx.make_node("Concat", unsqueezes, attr={'axis': 0})
ctx.replace_inputs(node, [stack_node.output[0], separator_node.output[0], axis_node.output[0]])

@tf_op("ReduceJoin", domain=constants.CONTRIB_OPS_DOMAIN)
class ReduceJoin:
@classmethod
def version_1(cls, ctx, node, **kwargs):
node.domain = constants.CONTRIB_OPS_DOMAIN
node.type = "StringJoin"
axis_node = ctx.get_node_by_output(node.input[1])
axis = axis_node.get_attr_value('value')
utils.make_sure(axis.dims in [[], [1]], "Only a single axis is supported for ReduceJoin node")
axis = to_array(axis)
new_axis_node = ctx.make_const(utils.make_name("axis"), np.array(axis, np.int64).reshape((1)))
separator = node.get_attr_value("separator")
if isinstance(separator, bytes):
separator = separator.decode()
separator_node = ctx.make_const(utils.make_name("separator"), np.array([separator], object))
ctx.replace_inputs(node, [node.input[0], separator_node.output[0], new_axis_node.output[0]])
keep_dims = node.get_attr_value("keep_dims")
if keep_dims:
unsqueeze_node = GraphBuilder(ctx).make_unsqueeze(
{'data': node.output[0], 'axes': [-1]},
name=node.name + '/Unsqueeze'
)
ctx.insert_node_on_output(ctx.get_node_by_output(unsqueeze_node))

@tf_op(["Equal", "NotEqual"], domain=constants.CONTRIB_OPS_DOMAIN)
class StringEqual:
@classmethod
Expand Down