diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 32b2c508f009..f6bde18b8db5 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -564,15 +564,14 @@ def _impl_v18(cls, bb, inputs, attr, params): return cls.base_impl(bb, inputs, attr, params) -class BitwiseNot(BitwiseBase): +class BitwiseNot(OnnxOpConverter): """Converts an onnx BitwiseNot node into an equivalent Relax expression.""" - numpy_op = _np.bitwise_not - relax_op = relax.op.bitwise_not - @classmethod def _impl_v18(cls, bb, inputs, attr, params): - return cls.base_impl(bb, inputs, attr, params) + if isinstance(inputs[0], relax.Constant): + return relax.const(_np.bitwise_not(inputs[0].data.numpy()), inputs[0].struct_info.dtype) + return relax.op.bitwise_not(inputs[0]) class BitShift(BitwiseBase): @@ -3117,13 +3116,13 @@ def _get_convert_map(): "BitwiseAnd": BitwiseAnd, "BitwiseOr": BitwiseOr, "BitwiseXor": BitwiseXor, - "BitwiseNot": BitwiseNot, "BitShift": BitShift, "And": And, "Or": Or, "Xor": Xor, "Not": Not, # Unary operators + "BitwiseNot": BitwiseNot, "Log": Log, "Exp": Exp, "Acos": Acos, diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index b7305d4810ed..6c3334f64d12 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -402,13 +402,11 @@ def test_binary_bool(op_name: str): verify_binary(op_name, [32, 32], [32, 32], [32, 32], dtype=TensorProto.BOOL) -@pytest.mark.skip(reason="opset 18 is not supported in CI") @pytest.mark.parametrize("op_name", ["BitwiseAnd", "BitwiseOr", "BitwiseXor"]) def test_bitwise(op_name: str): verify_binary(op_name, [32, 32], [32, 32], [32, 32], dtype=TensorProto.UINT64, opset=18) -@pytest.mark.skip(reason="opset 18 is not supported in CI") def test_bitwise_not(): verify_unary( "BitwiseNot", @@ -945,7 +943,6 @@ def test_selu(): verify_unary("Selu", [3, 32, 32], attrs={"alpha": 0.25, "gamma": 0.3}) -@pytest.mark.skip(reason="opset 18 is not supported in CI") def test_mish(): verify_unary("Mish", [3, 32, 32], opset=18)