Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ONNX] Support Bitwise operations #13888

Merged
merged 5 commits into from
Feb 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 75 additions & 8 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -5578,13 +5578,31 @@ def _impl_v10(cls, inputs, attr, params):
)


class BitShift(OnnxOpConverter):
"""Operator converter for NonZero"""
class BitwiseBase(OnnxOpConverter):
"""Base class of operator converter for Bitwise operations"""

name = ""

@classmethod
def check_inputs(cls, inputs, num=2, use_int=True):
assert len(inputs) == num, "{} takes {} inputs, {} given".format(cls.name, num, len(inputs))

valid_types = ["uint8", "uint16", "uint32", "uint64"]
if use_int:
valid_types += ["int8", "int16", "int32", "int64"]
for i in range(num):
in_dtype = infer_type(inputs[i]).checked_type.dtype
assert in_dtype in valid_types, "Wrong dtype of the {}-th input: {}".format(i, in_dtype)


class BitShift(BitwiseBase):
"""Operator converter for BitShift"""

name = "BitShift"

@classmethod
def _impl_v11(cls, inputs, attr, params):
if len(inputs) != 2:
raise ValueError("Bitshift expects 2 inputs")
cls.check_inputs(inputs, use_int=False)

direction = attr.get("direction", "LEFT").decode("ascii")
if direction == "LEFT":
Expand All @@ -5596,6 +5614,54 @@ def _impl_v11(cls, inputs, attr, params):
return out


class BitwiseAnd(BitwiseBase):
"""Operator converter for BitwiseAnd"""

name = "BitwiseAnd"

@classmethod
def _impl_v18(cls, inputs, attr, params):
cls.check_inputs(inputs)

return _op.bitwise_and(*inputs)


class BitwiseNot(BitwiseBase):
"""Operator converter for BitwiseNot"""

name = "BitwiseNot"

@classmethod
def _impl_v18(cls, inputs, attr, params):
cls.check_inputs(inputs, num=1)

return _op.bitwise_not(*inputs)


class BitwiseOr(BitwiseBase):
"""Operator converter for BitwiseOr"""

name = "BitwiseOr"

@classmethod
def _impl_v18(cls, inputs, attr, params):
cls.check_inputs(inputs)

return _op.bitwise_or(*inputs)


class BitwiseXor(BitwiseBase):
"""Operator converter for BitwiseXor"""

name = "BitwiseXor"

@classmethod
def _impl_v18(cls, inputs, attr, params):
cls.check_inputs(inputs)

return _op.bitwise_xor(*inputs)


class Unique(OnnxOpConverter):
"""Operator converter for unique"""

Expand Down Expand Up @@ -6319,7 +6385,12 @@ def _get_convert_map(opset):
"OptionalHasElement": OptionalHasElement.get_converter(opset),
"OptionalGetElement": OptionalGetElement.get_converter(opset),
"Affine": Affine.get_converter(opset),
# Bitwise operators
"BitShift": BitShift.get_converter(opset),
"BitwiseAnd": BitwiseAnd.get_converter(opset),
"BitwiseNot": BitwiseNot.get_converter(opset),
"BitwiseOr": BitwiseOr.get_converter(opset),
"BitwiseXor": BitwiseXor.get_converter(opset),
"ThresholdedRelu": ThresholdedRelu.get_converter(opset),
"ScaledTanh": ScaledTanh.get_converter(opset),
"ParametricSoftplus": ParametricSoftPlus.get_converter(opset),
Expand All @@ -6337,10 +6408,6 @@ def _get_convert_map(opset):
"Upsample": Upsample.get_converter(opset),
"SpatialBN": BatchNorm.get_converter(opset),
# defs/generator
# 'RandomUniform'
# 'RandomNormal'
# 'RandomUniformLike'
# 'RandomNormalLike'
# defs/logical
# defs/math
"Add": Add.get_converter(opset),
Expand Down
150 changes: 150 additions & 0 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -7506,6 +7506,156 @@ def repeat(num, dims):
)


@tvm.testing.parametrize_targets
def test_bitshift(target, dev):
"""test_bitshift"""

def verify_bitshift(in_shape, shift_shape, high=1000000000, in_dtype="uint64"):
in_shape = list(in_shape)
shift_shape = list(shift_shape)

# Create an input for each tensor.
tensor_values = [
np.random.randint(high, size=in_shape).astype(in_dtype),
np.random.randint(16, size=shift_shape).astype(in_dtype),
np.random.randint(16, size=shift_shape).astype(in_dtype),
]

bitshift_left_node = helper.make_node(
"BitShift",
inputs=["input", "shift_left"],
outputs=["shifted"],
direction="LEFT",
)

bitshift_right_node = helper.make_node(
"BitShift",
inputs=["shifted", "shift_right"],
outputs=["output"],
direction="RIGHT",
)

# Create input and output tensors.
proto_type = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(in_dtype)]
graph_inputs = [
helper.make_tensor_value_info("input", proto_type, in_shape),
helper.make_tensor_value_info("shift_left", proto_type, shift_shape),
helper.make_tensor_value_info("shift_right", proto_type, shift_shape),
]

graph_outputs = [helper.make_tensor_value_info("output", proto_type, in_shape)]

graph_nodes = [bitshift_left_node, bitshift_right_node]

graph = helper.make_graph(
graph_nodes,
"BitShift_test",
inputs=graph_inputs,
outputs=graph_outputs,
)
model = helper.make_model(
graph,
producer_name="BitShift_test",
)

verify_with_ort_with_inputs(model, tensor_values, target=target, dev=dev)

shape = (100, 4, 2)
broadcast_shape = (100, 1, 1)
# Common bitwise test
verify_bitshift(shape, shape)
# Bitwise test with broadcasting
verify_bitshift(shape, broadcast_shape)


# TODO(vvchernov): return test back than ONNX Runtime in CI will support domain version of 18
@pytest.mark.skip("Currently ONNX Runtime in CI does not support domain version of 18")
@tvm.testing.parametrize_targets
def test_bitwise(target, dev):
"""test_bitwise"""

def verify_bitwise_ops(A_shape, B_shape, C_shape, D_shape, high=128, in_dtype="int32"):
A_shape = list(A_shape)
B_shape = list(B_shape)
C_shape = list(C_shape)
D_shape = list(D_shape)

# Create an input for each tensor.
tensor_values = [
np.random.randint(high, size=A_shape).astype(in_dtype),
np.random.randint(high, size=B_shape).astype(in_dtype),
np.random.randint(high, size=C_shape).astype(in_dtype),
np.random.randint(high, size=D_shape).astype(in_dtype),
]

or_node = helper.make_node(
"BitwiseOr",
inputs=["A", "B"],
outputs=["OR"],
)

and_node = helper.make_node(
"BitwiseAnd",
inputs=["OR", "C"],
outputs=["AND"],
)

xor_node = helper.make_node(
"BitwiseXor",
inputs=["AND", "D"],
outputs=["XOR"],
)

not_node = helper.make_node(
"BitwiseNot",
inputs=["XOR"],
outputs=["output"],
)

# Create input and output tensors.
proto_type = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(in_dtype)]
graph_inputs = [
helper.make_tensor_value_info("A", proto_type, A_shape),
helper.make_tensor_value_info("B", proto_type, B_shape),
helper.make_tensor_value_info("C", proto_type, C_shape),
helper.make_tensor_value_info("D", proto_type, D_shape),
]

graph_outputs = [
helper.make_tensor_value_info("output", proto_type, A_shape),
]

graph_nodes = [
or_node,
and_node,
xor_node,
not_node,
]

graph = helper.make_graph(
graph_nodes,
"Bitwise_test",
inputs=graph_inputs,
outputs=graph_outputs,
)
model = helper.make_model(
graph,
producer_name="Bitwise_test",
)

verify_with_ort_with_inputs(model, tensor_values, target=target, dev=dev)

shape = (100, 4, 2)
broadcast_shape = (100, 1, 1)
dtypes = ["int8", "uint8", "int32", "uint32"]
high_vals = [128, 128, 2147483648, 2147483648]
for high, dtype in zip(high_vals, dtypes):
# Common bitwise test
verify_bitwise_ops(shape, shape, shape, shape, high, dtype)
# Bitwise test with broadcasting
verify_bitwise_ops(shape, broadcast_shape, broadcast_shape, broadcast_shape, high, dtype)


@tvm.testing.parametrize_targets
def test_scan(target, dev):
"""test_scan"""
Expand Down