def _add_cast_to_same_type_to_inputs(graph, node, supported_dtypes, target_dtype):
    """Make every input of `node` share one common dtype.

    The common dtype is taken from the node's first input; when that dtype is
    not in `supported_dtypes` (e.g. STRING for opset-11 Equal), `target_dtype`
    is used instead.  Inputs already at the common dtype are left untouched;
    all others get a Cast inserted in front of them.

    A constant ''-string input is special-cased: TF's feature-column code uses
    '' as the "missing" default key and compares it against -1 ids, so the
    constant is rewritten in place to -1 instead of being cast (a STRING->int
    Cast of '' would be invalid anyway).  See
    https://github.com/tensorflow/tensorflow/blob/4e7f0185c70faf35e12acbfe381a729d1e6cc38c/tensorflow/python/feature_column/feature_column.py#L2286

    Args:
        graph: tf2onnx Graph that owns `node`.
        node: node whose inputs are harmonized.
        supported_dtypes: TensorProto dtype enums the target op accepts.
        target_dtype: fallback dtype when the first input's dtype is unsupported.
    """
    import numpy as np  # local import keeps this block self-contained

    common_dtype = graph.get_dtype(node.input[0])
    if common_dtype not in supported_dtypes:
        common_dtype = target_dtype

    for inp in node.input:
        if graph.get_dtype(inp) == common_dtype:
            continue
        # Fix vs. the patch as posted: the '' -> -1 rewrite must REPLACE the
        # Cast for that input, not run after a Cast was already inserted
        # (which left a STRING->int Cast fed by an int tensor).  Also accept
        # b'' — numpy string tensors yield bytes — and build the -1 tensor
        # with numpy directly (the patch used utils.np.array, reaching numpy
        # transitively through utils) in the dtype the consumer now expects.
        if graph.is_const(inp) and graph.get_tensor_value(inp) in ('', b''):
            const_node = graph.get_node_by_output(inp)
            # NOTE(review): map_onnx_to_numpy_type lives in tf2onnx.utils;
            # confirm against the file's import of `utils`.
            const_node.set_tensor_value(
                np.array(-1, dtype=utils.map_onnx_to_numpy_type(common_dtype)))
        else:
            inp_cast = graph.insert_new_node_on_input(node, "Cast", inp, to=common_dtype)
            graph.copy_shape(inp, inp_cast.output[0])
            graph.set_dtype(inp_cast.output[0], common_dtype)
onnx_op="Not") @@ -92,8 +98,24 @@ def version_7(cls, ctx, node, **kwargs): @classmethod def version_11(cls, ctx, node, **kwargs): - # starting with opset-11, equal supports all types (but both operands must be of the same type) - _add_cast_to_same_type_to_inputs(ctx, node) + # starting with opset-11, equal supports all numerical types (but both operands must be of the same type) + # string type is not supported + supported_dtypes = [ + TensorProto.BOOL, + TensorProto.DOUBLE, + TensorProto.FLOAT, + TensorProto.FLOAT16, + TensorProto.INT8, + TensorProto.INT16, + TensorProto.INT32, + TensorProto.INT64, + TensorProto.UINT8, + TensorProto.UINT16, + TensorProto.UINT32, + TensorProto.UINT64 + ] + target_dtype = TensorProto.INT32 + _add_cast_to_same_type_to_inputs(ctx, node, supported_dtypes, target_dtype) need_not = node.type == "NotEqual" if need_not: node.type = "Equal"