diff --git a/python/tvm/expr.py b/python/tvm/expr.py index 00e85200b3c6..1ff069720ed7 100644 --- a/python/tvm/expr.py +++ b/python/tvm/expr.py @@ -113,12 +113,21 @@ def __rshift__(self, other): def __and__(self, other): return _make.bitwise_and(self, other) + def __rand__(self, other): + return _make.bitwise_and(other, self) + def __or__(self, other): return _make.bitwise_or(self, other) + def __ror__(self, other): + return _make.bitwise_or(other, self) + def __xor__(self, other): return _make.bitwise_xor(self, other) + def __rxor__(self, other): + return _make.bitwise_xor(other, self) + def __invert__(self): return _make.Call(self.dtype, "bitwise_not", [self], Call.PureIntrinsic, None, 0) diff --git a/tests/python/unittest/test_lang_basic.py b/tests/python/unittest/test_lang_basic.py index 8b54ef9534d6..0015d6d2cd8d 100644 --- a/tests/python/unittest/test_lang_basic.py +++ b/tests/python/unittest/test_lang_basic.py @@ -175,6 +175,9 @@ def test_bitwise(): assert str(x & y) == 'bitwise_and(x, y)' assert str(x | y) == 'bitwise_or(x, y)' assert str(x ^ y) == 'bitwise_xor(x, y)' + assert str(10 & x) == 'bitwise_and(10, x)' + assert str(10 | x) == 'bitwise_or(10, x)' + assert str(10 ^ x) == 'bitwise_xor(10, x)' assert str(~x) == 'bitwise_not(x)' assert(tvm.const(1, "int8x2") >> 1).dtype == "int8x2" assert(x >> tvm.const(1, "int32x2")).dtype == "int32x2"