diff --git a/tests/filecheck/dialects/stablehlo/ops.mlir b/tests/filecheck/dialects/stablehlo/ops.mlir index 274ea4a77b..ccb59814f4 100644 --- a/tests/filecheck/dialects/stablehlo/ops.mlir +++ b/tests/filecheck/dialects/stablehlo/ops.mlir @@ -33,6 +33,15 @@ // [[2,8], [4,10], [6,12]] // ] +// CHECK: %count_leading_zeros = "stablehlo.count_leading_zeros"(%t0) : (tensor) -> tensor +%count_leading_zeros = "stablehlo.count_leading_zeros"(%t0) : (tensor) -> tensor + +// CHECK: %popcnt = "stablehlo.popcnt"(%t0) : (tensor) -> tensor +%popcnt = "stablehlo.popcnt"(%t0) : (tensor) -> tensor + +// CHECK: %not = "stablehlo.not"(%t0) : (tensor) -> tensor +%not = "stablehlo.not"(%t0) : (tensor) -> tensor + // CHECK: %and = "stablehlo.and"(%t0, %t0) : (tensor, tensor) -> tensor %and = "stablehlo.and"(%t0, %t0) : (tensor, tensor) -> tensor @@ -42,6 +51,15 @@ // CHECK: %xor = "stablehlo.xor"(%t0, %t0) : (tensor, tensor) -> tensor %xor = "stablehlo.xor"(%t0, %t0) : (tensor, tensor) -> tensor +// CHECK: %shift_left = "stablehlo.shift_left"(%t0, %t0) : (tensor, tensor) -> tensor +%shift_left = "stablehlo.shift_left"(%t0, %t0) : (tensor, tensor) -> tensor + +// CHECK: %shift_right_arithmetic = "stablehlo.shift_right_arithmetic"(%t0, %t0) : (tensor, tensor) -> tensor +%shift_right_arithmetic = "stablehlo.shift_right_arithmetic"(%t0, %t0) : (tensor, tensor) -> tensor + +// CHECK: %shift_right_logical = "stablehlo.shift_right_logical"(%t0, %t0) : (tensor, tensor) -> tensor +%shift_right_logical = "stablehlo.shift_right_logical"(%t0, %t0) : (tensor, tensor) -> tensor + // %bitcast = "stablehlo.bitcast_convert"(%t0) : (tensor) -> tensor<2xi16> %bitcast = "stablehlo.bitcast_convert"(%t0) : (tensor) -> tensor<2xi16> diff --git a/xdsl/dialects/stablehlo.py b/xdsl/dialects/stablehlo.py index 9c744e7f01..dd2f4dee7a 100644 --- a/xdsl/dialects/stablehlo.py +++ b/xdsl/dialects/stablehlo.py @@ -94,6 +94,19 @@ def __init__( super().__init__(operands=(lhs, rhs), result_types=(result_type,)) +class IntegerTensorLikeElementwiseUnaryOperation(IRDLOperation, abc.ABC): + # TODO: Remove this constraint for complex types. + T: ClassVar = VarConstraint("T", base(IntegerTensorType)) + + operand = operand_def(T) + result = result_def(T) + + def __init__(self, operand: SSAValue, result_type: Attribute | None = None): + if result_type is None: + result_type = operand.type + super().__init__(operands=(operand,), result_types=(result_type,)) + + # endregion # region Attributes @@ -290,6 +303,43 @@ def __init__(self, inputs: Sequence[SSAValue]): super().__init__(operands=[inputs], result_types=(TokenType(),)) +@irdl_op_definition +class CountLeadingZerosOp(IntegerTensorLikeElementwiseUnaryOperation): + """ + Performs element-wise count of the number of leading zero bits in the operand tensor and produces a result tensor. + + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#count_leading_zeros + """ + + name = "stablehlo.count_leading_zeros" + + +@irdl_op_definition +class PopcntOp(IntegerTensorLikeElementwiseUnaryOperation): + """ + Performs element-wise count of the number of bits set in the operand tensor and produces a result tensor. + + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#popcnt + """ + + name = "stablehlo.popcnt" + + +@irdl_op_definition +class NotOp(IntegerTensorLikeElementwiseUnaryOperation): + """ + Performs element-wise NOT of tensor operand and produces a result tensor. + Depending on the element type, does the following: + + For booleans: logical NOT. + For integers: bitwise NOT. + + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#not + """ + + name = "stablehlo.not" + + @irdl_op_definition class AndOp(IntegerTensorLikeElementwiseBinaryOperation): """ @@ -335,6 +385,39 @@ class XorOp(IntegerTensorLikeElementwiseBinaryOperation): name = "stablehlo.xor" +@irdl_op_definition +class ShiftLeftOp(IntegerTensorLikeElementwiseBinaryOperation): + """ + Performs element-wise left-shift operation on the lhs tensor by rhs number of bits and produces a result tensor. + + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#shift_left + """ + + name = "stablehlo.shift_left" + + +@irdl_op_definition +class ShiftRightArithmeticOp(IntegerTensorLikeElementwiseBinaryOperation): + """ + Performs element-wise arithmetic right-shift operation on the lhs tensor by rhs number of bits and produces a result tensor. + + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#shift_right_arithmetic + """ + + name = "stablehlo.shift_right_arithmetic" + + +@irdl_op_definition +class ShiftRightLogicalOp(IntegerTensorLikeElementwiseBinaryOperation): + """ + Performs element-wise logical right-shift operation on the lhs tensor by rhs number of bits and produces a result tensor. + + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#shift_right_logical + """ + + name = "stablehlo.shift_right_logical" + + # TODO: Change to SI32 once StableHLO adopts signful integer semantics # See: https://github.com/openxla/stablehlo/issues/22 # https://github.com/openxla/stablehlo/issues/2489 @@ -514,9 +597,15 @@ def verify_(self) -> None: AbsOp, AddOp, AfterAllOp, + CountLeadingZerosOp, + PopcntOp, + NotOp, AndOp, OrOp, XorOp, + ShiftLeftOp, + ShiftRightArithmeticOp, + ShiftRightLogicalOp, BitcastConvertOp, CaseOp, MultiplyOp,