Skip to content

Commit

Permalink
Change Equal 11 for string input (#2149)
Browse files Browse the repository at this point in the history
* Change Equal 11 for string input
* Unify the dtype of all of inputs and add backend test

---------

Signed-off-by: Mike Essenmacher <essen@us.ibm.com>
  • Loading branch information
mikeessen authored Apr 14, 2023
1 parent aaaea95 commit 276bdea
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 5 deletions.
8 changes: 8 additions & 0 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1421,6 +1421,14 @@ def func(x1, x2):
return tf.identity(mi, name=_TFOUTPUT)
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2})

def test_equal_string(self):
x_val1 = np.array(['1'], dtype=np.string_)
x_val2 = np.array(['2'], dtype=np.string_)
def func(x1, x2):
mi = tf.equal(x1, x2)
return tf.identity(mi, name=_TFOUTPUT)
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2})

def test_equal(self):
x_val1 = np.array([4, 2, 4, 1], dtype=np.int32).reshape((2, 2))
x_val2 = np.array([2, 4, 4, 1], dtype=np.int32).reshape((2, 2))
Expand Down
32 changes: 27 additions & 5 deletions tf2onnx/onnx_opset/logical.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,21 @@ def _add_cast_to_inputs(graph, node, supported_dtypes, target_dtype):
graph.copy_shape(inp, inp_cast.output[0])
graph.set_dtype(inp_cast.output[0], target_dtype)


def _add_cast_to_same_type_to_inputs(graph, node):
def _add_cast_to_same_type_to_inputs(graph, node, supported_dtypes, target_dtype):
common_dtype = graph.get_dtype(node.input[0])
if common_dtype not in supported_dtypes:
common_dtype = target_dtype

for inp in node.input[1:]:
for inp in node.input:
if graph.get_dtype(inp) != common_dtype:
inp_cast = graph.insert_new_node_on_input(node, "Cast", inp, to=common_dtype)
graph.copy_shape(inp, inp_cast.output[0])
graph.set_dtype(inp_cast.output[0], common_dtype)
if graph.is_const(inp) and graph.get_tensor_value(inp) == '':
# Convert '' string constant to -1 int
# https://github.com/tensorflow/tensorflow/blob/4e7f0185c70faf35e12acbfe381a729d1e6cc38c/tensorflow/python/feature_column/feature_column.py#L2286
const_node = graph.get_node_by_output(inp)
const_node.set_tensor_value(utils.np.array(-1))


@tf_op("LogicalNot", onnx_op="Not")
Expand Down Expand Up @@ -92,8 +98,24 @@ def version_7(cls, ctx, node, **kwargs):

@classmethod
def version_11(cls, ctx, node, **kwargs):
# starting with opset-11, equal supports all types (but both operands must be of the same type)
_add_cast_to_same_type_to_inputs(ctx, node)
# starting with opset-11, equal supports all numerical types (but both operands must be of the same type)
# string type is not supported
supported_dtypes = [
TensorProto.BOOL,
TensorProto.DOUBLE,
TensorProto.FLOAT,
TensorProto.FLOAT16,
TensorProto.INT8,
TensorProto.INT16,
TensorProto.INT32,
TensorProto.INT64,
TensorProto.UINT8,
TensorProto.UINT16,
TensorProto.UINT32,
TensorProto.UINT64
]
target_dtype = TensorProto.INT32
_add_cast_to_same_type_to_inputs(ctx, node, supported_dtypes, target_dtype)
need_not = node.type == "NotEqual"
if need_not:
node.type = "Equal"
Expand Down

0 comments on commit 276bdea

Please sign in to comment.