From d613b99883b9057e2cb956a499f643ab35f9d657 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Tue, 31 Jan 2023 23:00:37 +0300 Subject: [PATCH 1/5] add base class for bitwise operations. BitwiseAnd, BitwiseNot, BitwiseOr and BitwiseXor were implemented --- python/tvm/relay/frontend/onnx.py | 83 ++++++++++++++++++++++++++++--- 1 file changed, 75 insertions(+), 8 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 8b4a0cc5e8d3..8b3ee8374fbd 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -5578,13 +5578,31 @@ def _impl_v10(cls, inputs, attr, params): ) -class BitShift(OnnxOpConverter): - """Operator converter for NonZero""" +class BitwiseBase(OnnxOpConverter): + """Base class of operator converter for Bitwise operations""" + + name = "" + + @classmethod + def check_inputs(cls, inputs, num=2, use_int=True): + assert len(inputs) == num, "{} takes {} inputs, {} given".format(cls.name, num, len(inputs)) + + valid_types = ["uint8", "uint16","uint32", "uint64"] + if use_int: + valid_types += ["int8", "int16","int32", "int64"] + for i in range(num): + in_dtype = infer_type(inputs[i]).checked_type.dtype + assert in_dtype in valid_types, "Wrong dtype of the {}-th input: {}".format(i, in_dtype) + + +class BitShift(BitwiseBase): + """Operator converter for BitShift""" + + name = "BitShift" @classmethod def _impl_v11(cls, inputs, attr, params): - if len(inputs) != 2: - raise ValueError("Bitshift expects 2 inputs") + cls.check_inputs(inputs, use_int=False) direction = attr.get("direction", "LEFT").decode("ascii") if direction == "LEFT": @@ -5596,6 +5614,54 @@ def _impl_v11(cls, inputs, attr, params): return out +class BitwiseAnd(BitwiseBase): + """Operator converter for BitwiseAnd""" + + name = "BitwiseAnd" + + @classmethod + def _impl_v18(cls, inputs, attr, params): + cls.check_inputs(inputs) + + return _op.bitwise_and(*inputs) + + +class BitwiseNot(BitwiseBase): + """Operator converter for BitwiseNot""" + + name = "BitwiseNot" + + @classmethod + def _impl_v18(cls, inputs, attr, params): + cls.check_inputs(inputs, num=1) + + return _op.bitwise_not(*inputs) + + +class BitwiseOr(BitwiseBase): + """Operator converter for BitwiseOr""" + + name = "BitwiseOr" + + @classmethod + def _impl_v18(cls, inputs, attr, params): + cls.check_inputs(inputs) + + return _op.bitwise_or(*inputs) + + +class BitwiseXor(BitwiseBase): + """Operator converter for BitwiseXor""" + + name = "BitwiseXor" + + @classmethod + def _impl_v18(cls, inputs, attr, params): + cls.check_inputs(inputs) + + return _op.bitwise_xor(*inputs) + + class Unique(OnnxOpConverter): """Operator converter for unique""" @@ -6319,7 +6385,12 @@ def _get_convert_map(opset): "OptionalHasElement": OptionalHasElement.get_converter(opset), "OptionalGetElement": OptionalGetElement.get_converter(opset), "Affine": Affine.get_converter(opset), + # Bitwise operators "BitShift": BitShift.get_converter(opset), + "BitwiseAnd": BitwiseAnd.get_converter(opset), + "BitwiseNot": BitwiseNot.get_converter(opset), + "BitwiseOr": BitwiseOr.get_converter(opset), + "BitwiseXor": BitwiseXor.get_converter(opset), "ThresholdedRelu": ThresholdedRelu.get_converter(opset), "ScaledTanh": ScaledTanh.get_converter(opset), "ParametricSoftplus": ParametricSoftPlus.get_converter(opset), @@ -6337,10 +6408,6 @@ def _get_convert_map(opset): "Upsample": Upsample.get_converter(opset), "SpatialBN": BatchNorm.get_converter(opset), # defs/generator - # 'RandomUniform' - # 'RandomNormal' - # 'RandomUniformLike' - # 'RandomNormalLike' # defs/logical # defs/math "Add": Add.get_converter(opset), From a66a845846e48418b7eeed012246c5d099d5f185 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Wed, 1 Feb 2023 09:59:13 +0300 Subject: [PATCH 2/5] add test for BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor operations to ONNX front-end --- python/tvm/relay/frontend/onnx.py | 4 +- tests/python/frontend/onnx/test_forward.py | 86 ++++++++++++++++++++++ 2 files changed, 88 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 8b3ee8374fbd..8de5e0e08bd8 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -5587,9 +5587,9 @@ class BitwiseBase(OnnxOpConverter): def check_inputs(cls, inputs, num=2, use_int=True): assert len(inputs) == num, "{} takes {} inputs, {} given".format(cls.name, num, len(inputs)) - valid_types = ["uint8", "uint16","uint32", "uint64"] + valid_types = ["uint8", "uint16", "uint32", "uint64"] if use_int: - valid_types += ["int8", "int16","int32", "int64"] + valid_types += ["int8", "int16", "int32", "int64"] for i in range(num): in_dtype = infer_type(inputs[i]).checked_type.dtype assert in_dtype in valid_types, "Wrong dtype of the {}-th input: {}".format(i, in_dtype) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index dd172d1ddea6..89d8db33d9f5 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -7506,6 +7506,92 @@ def repeat(num, dims): ) +@tvm.testing.parametrize_targets +def test_bitwise(target, dev): + """test_bitwise""" + + def verify_bitwise_ops(A_shape, B_shape, C_shape, D_shape, high=128, in_dtype="int32"): + A_shape = list(A_shape) + B_shape = list(B_shape) + C_shape = list(C_shape) + D_shape = list(D_shape) + + # Create an input for each tensor. + tensor_values = [ + np.random.randint(high, size=A_shape).astype(in_dtype), + np.random.randint(high, size=B_shape).astype(in_dtype), + np.random.randint(high, size=C_shape).astype(in_dtype), + np.random.randint(high, size=D_shape).astype(in_dtype), + ] + + or_node = helper.make_node( + "BitwiseOr", + inputs=["A", "B"], + outputs=["OR"], + ) + + and_node = helper.make_node( + "BitwiseAnd", + inputs=["OR", "C"], + outputs=["AND"], + ) + + xor_node = helper.make_node( + "BitwiseXor", + inputs=["AND", "D"], + outputs=["XOR"], + ) + + not_node = helper.make_node( + "BitwiseNot", + inputs=["XOR"], + outputs=["output"], + ) + + # Create input and output tensors. + proto_type = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)] + graph_inputs = [ + helper.make_tensor_value_info("A", proto_type, A_shape), + helper.make_tensor_value_info("B", proto_type, B_shape), + helper.make_tensor_value_info("C", proto_type, C_shape), + helper.make_tensor_value_info("D", proto_type, D_shape), + ] + + graph_outputs = [ + helper.make_tensor_value_info("output", proto_type, A_shape), + ] + + graph_nodes = [ + or_node, + and_node, + xor_node, + not_node, + ] + + graph = helper.make_graph( + graph_nodes, + "Bitwise_test", + inputs=graph_inputs, + outputs=graph_outputs, + ) + model = helper.make_model( + graph, + producer_name="Bitwise_test", + ) + + verify_with_ort_with_inputs(model, tensor_values, target=target, dev=dev) + + shape = (100, 4, 2,) + broadcast_shape = (100, 1, 1,) + dtypes = ["int8", "uint8", "int32", "uint32"] + high_vals = [128, 128, 2147483648, 2147483648] + for high, dtype in zip(high_vals, dtypes): + # Common bitwise test + verify_bitwise_ops(shape, shape, shape, shape, high, dtype) + # Bitwise test with broadcasting + verify_bitwise_ops(shape, broadcast_shape, broadcast_shape, broadcast_shape, high, dtype) + + @tvm.testing.parametrize_targets def test_scan(target, dev): """test_scan""" From 25dad0ab7ea98ae1cb78324eb6178884377a0b99 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Wed, 1 Feb 2023 10:32:30 +0300 Subject: [PATCH 3/5] add test of BitShift for ONNX front-end --- tests/python/frontend/onnx/test_forward.py | 68 +++++++++++++++++++++- 1 file changed, 65 insertions(+), 3 deletions(-) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 89d8db33d9f5..e09e83218f92 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -7506,6 +7506,68 @@ def repeat(num, dims): ) +@tvm.testing.parametrize_targets +def test_bitshift(target, dev): + """test_bitshift""" + + def verify_bitshift(in_shape, shift_shape, high=1000000000, in_dtype="int64"): + in_shape = list(in_shape) + shift_shape = list(shift_shape) + + # Create an input for each tensor. + tensor_values = [ + np.random.randint(high, size=in_shape).astype(in_dtype), + np.random.randint(16, size=shift_shape).astype(in_dtype), + np.random.randint(16, size=shift_shape).astype(in_dtype), + ] + + bitshift_left_node = helper.make_node( + "BitShift", + inputs=["input", "shift_left"], + outputs=["shifted"], + direction="LEFT", + ) + + bitshift_right_node = helper.make_node( + "BitShift", + inputs=["shifted", "shift_right"], + outputs=["output"], + direction="RIGHT", + ) + + # Create input and output tensors. + proto_type = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(in_dtype)] + graph_inputs = [ + helper.make_tensor_value_info("input", proto_type, in_shape), + helper.make_tensor_value_info("shift_left", proto_type, shift_shape), + helper.make_tensor_value_info("shift_right", proto_type, shift_shape), + ] + + graph_outputs = [helper.make_tensor_value_info("output", proto_type, in_shape)] + + graph_nodes = [bitshift_left_node, bitshift_right_node] + + graph = helper.make_graph( + graph_nodes, + "BitShift_test", + inputs=graph_inputs, + outputs=graph_outputs, + ) + model = helper.make_model( + graph, + producer_name="BitShift_test", + ) + + verify_with_ort_with_inputs(model, tensor_values, target=target, dev=dev) + + shape = (100, 4, 2) + broadcast_shape = (100, 1, 1) + # Common bitwise test + verify_bitshift(shape, shape) + # Bitwise test with broadcasting + verify_bitshift(shape, broadcast_shape) + + @tvm.testing.parametrize_targets def test_bitwise(target, dev): """test_bitwise""" @@ -7549,7 +7611,7 @@ def verify_bitwise_ops(A_shape, B_shape, C_shape, D_shape, high=128, in_dtype="i ) # Create input and output tensors. - proto_type = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)] + proto_type = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(in_dtype)] graph_inputs = [ helper.make_tensor_value_info("A", proto_type, A_shape), helper.make_tensor_value_info("B", proto_type, B_shape), @@ -7581,8 +7643,8 @@ def verify_bitwise_ops(A_shape, B_shape, C_shape, D_shape, high=128, in_dtype="i verify_with_ort_with_inputs(model, tensor_values, target=target, dev=dev) - shape = (100, 4, 2,) - broadcast_shape = (100, 1, 1,) + shape = (100, 4, 2) + broadcast_shape = (100, 1, 1) dtypes = ["int8", "uint8", "int32", "uint32"] high_vals = [128, 128, 2147483648, 2147483648] for high, dtype in zip(high_vals, dtypes): From 6b00601d2fbe5f2ca8835739f7359a71e59b1ef9 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Wed, 1 Feb 2023 12:59:59 +0300 Subject: [PATCH 4/5] fix dtype for test --- tests/python/frontend/onnx/test_forward.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index e09e83218f92..ee66657b616d 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -7510,7 +7510,7 @@ def repeat(num, dims): def test_bitshift(target, dev): """test_bitshift""" - def verify_bitshift(in_shape, shift_shape, high=1000000000, in_dtype="int64"): + def verify_bitshift(in_shape, shift_shape, high=1000000000, in_dtype="uint64"): in_shape = list(in_shape) shift_shape = list(shift_shape) From b4896350b10b1e6702c73646ffd6deeb93546e40 Mon Sep 17 00:00:00 2001 From: Valery Chernov Date: Thu, 2 Feb 2023 12:53:54 +0300 Subject: [PATCH 5/5] skip test due to old version of ORT --- tests/python/frontend/onnx/test_forward.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index ee66657b616d..0a032843267a 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -7568,6 +7568,8 @@ def verify_bitshift(in_shape, shift_shape, high=1000000000, in_dtype="uint64"): verify_bitshift(shape, broadcast_shape) +# TODO(vvchernov): return test back than ONNX Runtime in CI will support domain version of 18 +@pytest.mark.skip("Currently ONNX Runtime in CI does not support domain version of 18") @tvm.testing.parametrize_targets def test_bitwise(target, dev): """test_bitwise"""