Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change Equal 11 for string input #2149

Merged
merged 4 commits into from
Apr 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1421,6 +1421,14 @@ def func(x1, x2):
return tf.identity(mi, name=_TFOUTPUT)
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2})

def test_equal_string(self):
x_val1 = np.array(['1'], dtype=np.string_)
x_val2 = np.array(['2'], dtype=np.string_)
def func(x1, x2):
mi = tf.equal(x1, x2)
return tf.identity(mi, name=_TFOUTPUT)
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val1, _INPUT1: x_val2})

def test_equal(self):
x_val1 = np.array([4, 2, 4, 1], dtype=np.int32).reshape((2, 2))
x_val2 = np.array([2, 4, 4, 1], dtype=np.int32).reshape((2, 2))
Expand Down
32 changes: 27 additions & 5 deletions tf2onnx/onnx_opset/logical.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,21 @@ def _add_cast_to_inputs(graph, node, supported_dtypes, target_dtype):
graph.copy_shape(inp, inp_cast.output[0])
graph.set_dtype(inp_cast.output[0], target_dtype)


def _add_cast_to_same_type_to_inputs(graph, node):
def _add_cast_to_same_type_to_inputs(graph, node, supported_dtypes, target_dtype):
common_dtype = graph.get_dtype(node.input[0])
if common_dtype not in supported_dtypes:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My bad, I misunderstood the issue before.

Actually, the issue is: The op Equal doesn't support 'string' with version 11, meaning we should fail the conversion once we detect this case. Even the content is empty, we should not convert it to an integer value which might confuse users.

So it'd better to set up an unsupported op list which may only contain string right now. If we detect current type of input[0] is in it, we fail the conversion with a reasonable message instead of making tricky things to work around it.

Make sense?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, the issue is: The op Equal doesn't support 'string' with version 11, meaning we should fail the conversion once we detect this case.

In the OpSet 7 version of this, the Equal Op works fine assuming the strings can be converted to the desired type. So this change makes OpSet11 match the OpSet7 behavior.

So if we made this an error, then OpSet 7 Equal and Opset 11 would behave differently but both have the same stance on strings. Making them behave different didn't seem right. Rather than remove what was working in OpSet 7, we decided to enable the same setup for OpSet 11. It could be argued to make OpSet7 (and earlier) also automatically error in this case instead. If that's truely preferred we could look into it. However our thought was to expand 11 rather than restrict 7.

Even the content is empty, we should not convert it to an integer value which might confuse users.

I'll agree that the empty string was a special case that almost lead us to restrict OpSet 7. However after digging into the TF code, we found they explicitly handle this case. When creating a feature column, if an entry is "" then it's explicitly changed to -1 (https://github.com/tensorflow/tensorflow/blob/4e7f0185c70faf35e12acbfe381a729d1e6cc38c/tensorflow/python/feature_column/feature_column.py#L2286). Since it's explicitly handled, we matched the TF behavior over just erroring out. Otherwise it is also confusing to have a model work in TF but fail to convert.

For reference we have a model encountering this scenario that indeed works fine in TF but then fails to convert. With these changes, it converts fine. Unfortunately it's not a model we can share but it is a real world scenario.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your detailed explanations.

common_dtype = target_dtype

for inp in node.input[1:]:
for inp in node.input:
if graph.get_dtype(inp) != common_dtype:
inp_cast = graph.insert_new_node_on_input(node, "Cast", inp, to=common_dtype)
graph.copy_shape(inp, inp_cast.output[0])
graph.set_dtype(inp_cast.output[0], common_dtype)
if graph.is_const(inp) and graph.get_tensor_value(inp) == '':
# Convert '' string constant to -1 int
# https://github.com/tensorflow/tensorflow/blob/4e7f0185c70faf35e12acbfe381a729d1e6cc38c/tensorflow/python/feature_column/feature_column.py#L2286
const_node = graph.get_node_by_output(inp)
const_node.set_tensor_value(utils.np.array(-1))
fatcat-z marked this conversation as resolved.
Show resolved Hide resolved


@tf_op("LogicalNot", onnx_op="Not")
Expand Down Expand Up @@ -92,8 +98,24 @@ def version_7(cls, ctx, node, **kwargs):

@classmethod
def version_11(cls, ctx, node, **kwargs):
# starting with opset-11, equal supports all types (but both operands must be of the same type)
_add_cast_to_same_type_to_inputs(ctx, node)
# starting with opset-11, equal supports all numerical types (but both operands must be of the same type)
# string type is not supported
supported_dtypes = [
TensorProto.BOOL,
TensorProto.DOUBLE,
TensorProto.FLOAT,
TensorProto.FLOAT16,
TensorProto.INT8,
TensorProto.INT16,
TensorProto.INT32,
TensorProto.INT64,
TensorProto.UINT8,
TensorProto.UINT16,
TensorProto.UINT32,
TensorProto.UINT64
]
target_dtype = TensorProto.INT32
_add_cast_to_same_type_to_inputs(ctx, node, supported_dtypes, target_dtype)
need_not = node.type == "NotEqual"
if need_not:
node.type = "Equal"
Expand Down