forked from xdslproject/xdsl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathstablehlo.py
621 lines (470 loc) · 18.8 KB
/
stablehlo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
"""
https://github.com/openxla/stablehlo/blob/main/docs/spec.md
StableHLO is an operation set for high-level operations (HLO) in machine learning (ML) models.
StableHLO works as a portability layer between different ML frameworks and ML compilers:
ML frameworks that produce StableHLO programs are compatible with ML compilers that consume StableHLO programs.
"""
import abc
from collections.abc import Sequence
from typing import Annotated, ClassVar, TypeAlias, cast
from xdsl.dialects.builtin import (
I32,
I64,
AnyTensorType,
AnyTensorTypeConstr,
ArrayAttr,
DenseArrayBase,
IntegerAttr,
IntegerType,
TensorType,
i64,
)
from xdsl.ir import (
Attribute,
Dialect,
EnumAttribute,
ParametrizedAttribute,
Region,
SpacedOpaqueSyntaxAttribute,
SSAValue,
StrEnum,
TypeAttribute,
)
from xdsl.irdl import (
BaseAttr,
ConstraintVar,
IRDLOperation,
ParameterDef,
VarConstraint,
attr_def,
base,
irdl_attr_definition,
irdl_op_definition,
operand_def,
result_def,
traits_def,
var_operand_def,
var_region_def,
var_result_def,
)
from xdsl.parser import AttrParser
from xdsl.printer import Printer
from xdsl.traits import IsTerminator
from xdsl.utils.exceptions import VerifyException
# Tensor type whose element type is an `IntegerType`. Used as the operand and
# result constraint of the integer/bitwise element-wise base classes below.
IntegerTensorType: TypeAlias = TensorType[IntegerType]
# region Abstract Base Classes
class ElementwiseBinaryOperation(IRDLOperation, abc.ABC):
    """
    Abstract base for element-wise binary operations on tensors.

    `lhs`, `rhs` and `result` are all bound to the same constraint variable `T`,
    so both operands and the result must carry one and the same tensor type.
    """

    # TODO: Remove this constraint for complex types.
    T: ClassVar = VarConstraint("T", base(AnyTensorType))

    lhs = operand_def(T)
    rhs = operand_def(T)
    result = result_def(T)

    def __init__(
        self, lhs: SSAValue, rhs: SSAValue, result_type: Attribute | None = None
    ):
        """
        Build the operation from its two operands.

        When `result_type` is omitted, the type of `lhs` is reused for the result.
        """
        super().__init__(
            operands=(lhs, rhs),
            result_types=(lhs.type if result_type is None else result_type,),
        )
class IntegerTensorLikeElementwiseBinaryOperation(IRDLOperation, abc.ABC):
    """
    Abstract base for element-wise binary operations on integer tensors.

    Both operands and the result share the constraint variable `T`, which only
    admits `IntegerTensorType`; all three must therefore have the same type.
    """

    # TODO: Remove this constraint for complex types.
    T: ClassVar = VarConstraint("T", base(IntegerTensorType))

    lhs = operand_def(T)
    rhs = operand_def(T)
    result = result_def(T)

    def __init__(
        self, lhs: SSAValue, rhs: SSAValue, result_type: Attribute | None = None
    ):
        """
        Build the operation from `lhs` and `rhs`.

        If no `result_type` is supplied, the result mirrors the type of `lhs`.
        """
        if result_type is not None:
            resolved_type = result_type
        else:
            resolved_type = lhs.type
        super().__init__(operands=[lhs, rhs], result_types=[resolved_type])
class IntegerTensorLikeElementwiseUnaryOperation(IRDLOperation, abc.ABC):
    """
    Abstract base for element-wise unary operations on integer tensors.

    The operand and the result share the constraint variable `T` (restricted to
    `IntegerTensorType`), so they must have identical types.
    """

    # TODO: Remove this constraint for complex types.
    T: ClassVar = VarConstraint("T", base(IntegerTensorType))

    operand = operand_def(T)
    result = result_def(T)

    def __init__(self, operand: SSAValue, result_type: Attribute | None = None):
        """
        Build the operation from `operand`.

        The result defaults to the operand's own type when `result_type` is None.
        """
        chosen_type = operand.type if result_type is None else result_type
        super().__init__(operands=[operand], result_types=[chosen_type])
# endregion
# region Attributes
class Precision(StrEnum):
    """
    XLA precision for an operand. Has backend specific meaning.

    Wrapped as an attribute by `PrecisionAttr`.
    """

    # Members mirror the StableHLO precision enum
    # (stablehlo/dialect/StablehloEnums.td).
    DEFAULT = "DEFAULT"
    HIGH = "HIGH"
    HIGHEST = "HIGHEST"
@irdl_attr_definition
class PrecisionAttr(EnumAttribute[Precision], SpacedOpaqueSyntaxAttribute):
    """
    XLA precision for an operand. Has backend specific meaning.

    The attribute's value is one member of the `Precision` enum. Its textual form
    is determined by the `SpacedOpaqueSyntaxAttribute` mixin (see xDSL for the
    exact syntax).

    https://github.com/openxla/stablehlo/blob/b075e948092d8a27ed0be48f4f8dbaa6df7e2e3e/stablehlo/dialect/StablehloEnums.td#L46
    """

    name = "stablehlo.precision"
@irdl_attr_definition
class TokenType(TypeAttribute, ParametrizedAttribute):
    """
    Token types represent tokens, i.e. opaque values produced and consumed by some operations.
    Tokens are used for imposing execution order on operations as described in the Execution section.

    The type has no parameters.

    E.g.,
    // %input0: !stablehlo.token
    // %input1: !stablehlo.token
    %result = "stablehlo.after_all"(%input0, %input1) : (!stablehlo.token, !stablehlo.token) -> !stablehlo.token
    """

    name = "stablehlo.token"
@irdl_attr_definition
class DotAttr(ParametrizedAttribute):
    """
    Attribute that models the dimension information for dot.

    It holds four lists of i64 dimension indices: the batching and the
    contracting dimensions of each of the two operands. In the custom syntax the
    parameters appear inside angle brackets. They are given in this fixed order,
    each on its own line and separated by commas:

        lhs_batching_dimensions = [...],
        rhs_batching_dimensions = [...],
        lhs_contracting_dimensions = [...],
        rhs_contracting_dimensions = [...]

    https://github.com/openxla/stablehlo/blob/b075e948092d8a27ed0be48f4f8dbaa6df7e2e3e/stablehlo/dialect/StablehloAttrs.td#L82
    """

    name = "stablehlo.dot"

    # Declaration order defines the parameter order; parse_parameters must
    # return its values in this same order.
    lhs_batching_dimensions: ParameterDef[ArrayAttr[IntegerAttr[I64]]]
    rhs_batching_dimensions: ParameterDef[ArrayAttr[IntegerAttr[I64]]]
    lhs_contracting_dimensions: ParameterDef[ArrayAttr[IntegerAttr[I64]]]
    rhs_contracting_dimensions: ParameterDef[ArrayAttr[IntegerAttr[I64]]]

    @staticmethod
    def _print_parameter(
        name: str, value: ArrayAttr[IntegerAttr[I64]], printer: Printer
    ) -> None:
        """
        Print one parameter as `name = [...]`, preceded by a newline.

        Only the raw integer values are printed (no `: i64` type suffix).
        """
        printer.print_string(f"\n{name} = [")
        printer.print_list(
            value.data,
            lambda dim: printer.print_string(f"{dim.value.data}"),
        )
        printer.print_string("]")

    @staticmethod
    def _parse_parameter(name: str, parser: AttrParser) -> ArrayAttr[IntegerAttr[I64]]:
        """
        Parse `name = [int, int, ...]`, the counterpart of `_print_parameter`.

        Each parsed integer is wrapped as an i64 `IntegerAttr`.
        """
        parser.parse_characters(name)
        parser.parse_punctuation("=")
        value = parser.parse_comma_separated_list(
            AttrParser.Delimiter.SQUARE,
            lambda: IntegerAttr(parser.parse_integer(), i64),
        )
        return ArrayAttr(value)

    def print_parameters(self, printer: Printer) -> None:
        """Print the four dimension lists inside `<...>`, one per indented line."""
        with printer.in_angle_brackets():
            with printer.indented():
                DotAttr._print_parameter(
                    "lhs_batching_dimensions", self.lhs_batching_dimensions, printer
                )
                printer.print_string(",")
                DotAttr._print_parameter(
                    "rhs_batching_dimensions", self.rhs_batching_dimensions, printer
                )
                printer.print_string(",")
                DotAttr._print_parameter(
                    "lhs_contracting_dimensions",
                    self.lhs_contracting_dimensions,
                    printer,
                )
                printer.print_string(",")
                DotAttr._print_parameter(
                    "rhs_contracting_dimensions",
                    self.rhs_contracting_dimensions,
                    printer,
                )
            # Newline before the closing `>`.
            printer.print_string("\n")

    @classmethod
    def parse_parameters(cls, parser: AttrParser) -> Sequence[Attribute]:
        """
        Parse the four comma-separated dimension lists inside `<...>`.

        The names must appear in exactly the order that `print_parameters` emits
        them. The returned tuple follows the parameter declaration order.
        """
        with parser.in_angle_brackets():
            lhs_batching_dimensions = DotAttr._parse_parameter(
                "lhs_batching_dimensions", parser
            )
            parser.parse_punctuation(",")
            rhs_batching_dimensions = DotAttr._parse_parameter(
                "rhs_batching_dimensions", parser
            )
            parser.parse_punctuation(",")
            lhs_contracting_dimensions = DotAttr._parse_parameter(
                "lhs_contracting_dimensions", parser
            )
            parser.parse_punctuation(",")
            rhs_contracting_dimensions = DotAttr._parse_parameter(
                "rhs_contracting_dimensions", parser
            )
        return (
            lhs_batching_dimensions,
            rhs_batching_dimensions,
            lhs_contracting_dimensions,
            rhs_contracting_dimensions,
        )
# endregion
@irdl_op_definition
class AbsOp(IRDLOperation):
    """
    Performs element-wise abs operation on operand tensor and produces a result
    tensor. Depending on the element type, does the following:

    * For signed integers: integer modulus.
    * For floats: abs from IEEE-754.
    * For complex numbers: complex modulus.
    * For quantized types: dequantize_op_quantize(abs, operand, type(result)).

    Operand and result are bound to the same constraint variable `T`, so they
    currently must have the same tensor type.

    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#abs
    """

    name = "stablehlo.abs"

    # TODO: Remove this constraint for complex types.
    T: ClassVar = VarConstraint("T", base(AnyTensorType))

    operand = operand_def(T)
    result = result_def(T)

    def __init__(self, operand: SSAValue, result_type: Attribute | None = None):
        """
        Build `stablehlo.abs` on `operand`.

        If `result_type` is None, the operand's type is reused for the result.
        """
        # TODO: Constraints for complex types (the modulus of a complex value is
        # presumably real, so mirroring the operand type would not fit there).
        effective_type = result_type if result_type is not None else operand.type
        super().__init__(operands=[operand], result_types=[effective_type])
@irdl_op_definition
class AddOp(ElementwiseBinaryOperation):
    """
    Performs element-wise addition of two tensors `lhs` and `rhs` and produces a
    `result` tensor. Depending on the element type, does the following:

    * For booleans: logical OR.
    * For integers: integer addition.
    * For floats: `addition` from IEEE-754.
    * For complex numbers: complex addition.
    * For quantized types: `dequantize_op_quantize(add, lhs, rhs, type(result))`.

    Operands, result and constructor come from `ElementwiseBinaryOperation`.

    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#add
    """

    name = "stablehlo.add"
@irdl_op_definition
class AfterAllOp(IRDLOperation):
    """
    Ensures that the operations producing the inputs are executed before any
    operations that depend on result.

    Execution of this operation does nothing; it only exists to establish data
    dependencies from result to inputs.

    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#after_all
    """

    name = "stablehlo.after_all"

    # Any number of input tokens, joined into a single output token.
    inputs = var_operand_def(TokenType)
    result = result_def(TokenType)

    def __init__(self, inputs: Sequence[SSAValue]):
        """Join the token values in `inputs` into one fresh `!stablehlo.token`."""
        joined_token = TokenType()
        super().__init__(operands=(inputs,), result_types=(joined_token,))
@irdl_op_definition
class CountLeadingZerosOp(IntegerTensorLikeElementwiseUnaryOperation):
    """
    Performs element-wise count of the number of leading zero bits in the operand tensor and produces a result tensor.

    Operand, result and constructor come from
    `IntegerTensorLikeElementwiseUnaryOperation` (integer tensors, same type in
    and out).

    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#count_leading_zeros
    """

    name = "stablehlo.count_leading_zeros"
@irdl_op_definition
class PopcntOp(IntegerTensorLikeElementwiseUnaryOperation):
    """
    Performs element-wise count of the number of bits set in the operand tensor and produces a result tensor.

    Operand, result and constructor come from
    `IntegerTensorLikeElementwiseUnaryOperation` (integer tensors, same type in
    and out).

    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#popcnt
    """

    name = "stablehlo.popcnt"
@irdl_op_definition
class NotOp(IntegerTensorLikeElementwiseUnaryOperation):
    """
    Performs element-wise NOT of tensor operand and produces a result tensor.
    Depending on the element type, does the following:

    * For booleans: logical NOT.
    * For integers: bitwise NOT.

    Operand, result and constructor come from
    `IntegerTensorLikeElementwiseUnaryOperation`.

    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#not
    """

    name = "stablehlo.not"
@irdl_op_definition
class AndOp(IntegerTensorLikeElementwiseBinaryOperation):
    """
    Performs element-wise AND of two tensors lhs and rhs and produces a result tensor.
    Depending on the element type, does the following:

    * For booleans: logical AND.
    * For integers: bitwise AND.

    Operands, result and constructor come from
    `IntegerTensorLikeElementwiseBinaryOperation`.

    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#and
    """

    name = "stablehlo.and"
@irdl_op_definition
class OrOp(IntegerTensorLikeElementwiseBinaryOperation):
    """
    Performs element-wise OR of two tensors lhs and rhs and produces a result tensor.
    Depending on the element type, does the following:

    * For booleans: logical OR.
    * For integers: bitwise OR.

    Operands, result and constructor come from
    `IntegerTensorLikeElementwiseBinaryOperation`.

    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#or
    """

    name = "stablehlo.or"
@irdl_op_definition
class XorOp(IntegerTensorLikeElementwiseBinaryOperation):
    """
    Performs element-wise XOR of two tensors lhs and rhs and produces a result tensor.
    Depending on the element type, does the following:

    * For booleans: logical XOR.
    * For integers: bitwise XOR.

    Operands, result and constructor come from
    `IntegerTensorLikeElementwiseBinaryOperation`.

    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#xor
    """

    name = "stablehlo.xor"
@irdl_op_definition
class ShiftLeftOp(IntegerTensorLikeElementwiseBinaryOperation):
    """
    Performs element-wise left-shift operation on the lhs tensor by rhs number of bits and produces a result tensor.

    Operands, result and constructor come from
    `IntegerTensorLikeElementwiseBinaryOperation`, so `lhs`, `rhs` and `result`
    share one integer tensor type.

    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#shift_left
    """

    name = "stablehlo.shift_left"
@irdl_op_definition
class ShiftRightArithmeticOp(IntegerTensorLikeElementwiseBinaryOperation):
    """
    Performs element-wise arithmetic right-shift operation on the lhs tensor by rhs number of bits and produces a result tensor.

    Operands, result and constructor come from
    `IntegerTensorLikeElementwiseBinaryOperation`, so `lhs`, `rhs` and `result`
    share one integer tensor type.

    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#shift_right_arithmetic
    """

    name = "stablehlo.shift_right_arithmetic"
@irdl_op_definition
class ShiftRightLogicalOp(IntegerTensorLikeElementwiseBinaryOperation):
    """
    Performs element-wise logical right-shift operation on the lhs tensor by rhs number of bits and produces a result tensor.

    Operands, result and constructor come from
    `IntegerTensorLikeElementwiseBinaryOperation`, so `lhs`, `rhs` and `result`
    share one integer tensor type.

    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#shift_right_logical
    """

    name = "stablehlo.shift_right_logical"
# TODO: Change to SI32 once StableHLO adopts signful integer semantics
# See: https://github.com/openxla/stablehlo/issues/22
# https://github.com/openxla/stablehlo/issues/2489
# Until then, the "si32" tensors the spec asks for (e.g. `CaseOp.index`) are
# modelled with the plain `I32` element type.
SI32TensorType: TypeAlias = TensorType[I32]
@irdl_op_definition
class CaseOp(IRDLOperation):
    """
    Produces the output from executing exactly one function from `branches`
    depending on the value of `index`.

    More formally, `result = selected_branch()` where:

    * `selected_branch = branches[index]` if `0 <= index < size(branches)`.
    * `selected_branch = branches[-1]` otherwise.

    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#case
    """

    name = "stablehlo.case"

    # i32 tensor holding the branch selector.
    index = operand_def(SI32TensorType)
    # Any number of branch regions, each restricted to a single block.
    branches = var_region_def("single_block")
    # Results may be tensors or tokens.
    _results = var_result_def(AnyTensorTypeConstr | BaseAttr(TokenType))

    def __init__(
        self,
        index: SSAValue,
        branches: Sequence[Region],
        result_types: Sequence[AnyTensorType | TokenType],
    ):
        """
        Build a case op selecting among `branches` with `index`.

        `result_types` lists the types of all results produced by the op.
        """
        super().__init__(
            operands=[index],
            result_types=[result_types],
            regions=[branches],
        )
@irdl_op_definition
class BitcastConvertOp(IRDLOperation):
    """
    Performs a bitcast operation on operand tensor and produces a result tensor
    where the bits of the entire operand tensor are reinterpreted using the type
    of the result tensor.

    More formally, given E = element_type(operand), E' = element_type(result),
    and R = rank(operand):

    * If num_bits(E') < num_bits(E),
      bits(result[i0, ..., iR-1, :]) = bits(operand[i0, ..., iR-1]).
    * If num_bits(E') > num_bits(E),
      bits(result[i0, ..., iR-2]) = bits(operand[i0, ..., iR-2, :]).
    * If num_bits(E') = num_bits(E),
      bits(result[i0, ..., iR-1]) = bits(operand[i0, ..., iR-1]).

    `bits` returns the in-memory representation of a given value. Its behavior is
    implementation-defined, because both the exact representation of tensors
    and the exact representation of element types are implementation-defined.

    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#bitcast_convert
    """

    name = "stablehlo.bitcast_convert"

    input = operand_def(AnyTensorType)
    result = result_def(AnyTensorType)

    def __init__(self, input: SSAValue, result: Attribute):
        """Reinterpret `input` as a tensor of type `result` (an explicit type)."""
        super().__init__(result_types=[result], operands=[input])
@irdl_op_definition
class MultiplyOp(ElementwiseBinaryOperation):
    """
    Performs element-wise product of two tensors `lhs` and `rhs` and produces a
    `result` tensor. Depending on the element type, does the following:

    * For booleans: logical AND.
    * For integers: integer multiplication.
    * For floats: `multiplication` from IEEE-754.
    * For complex numbers: complex multiplication.
    * For quantized types:
      * `dequantize_op_quantize(multiply, lhs, rhs, type(result))`.

    Operands, result and constructor come from `ElementwiseBinaryOperation`.

    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#multiply
    """

    name = "stablehlo.multiply"
@irdl_op_definition
class SubtractOp(ElementwiseBinaryOperation):
    """
    Performs element-wise subtraction of two tensors `lhs` and `rhs` and produces a
    `result` tensor. Depending on the element type, does the following:

    * For integers: integer subtraction.
    * For floats: `subtraction` from IEEE-754.
    * For complex numbers: complex subtraction.
    * For quantized types:
      * `dequantize_op_quantize(subtract, lhs, rhs, type(result))`.

    Operands, result and constructor come from `ElementwiseBinaryOperation`.

    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#subtract
    """

    name = "stablehlo.subtract"
@irdl_op_definition
class ReturnOp(IRDLOperation):
    """This op is un-documented.

    StableHLO's return is used inside of the bodies of StableHLO ops.
    It behaves like func.return but for StableHLO ops.
    The func.return op is used inside of func.func op.

    https://discord.com/channels/999073994483433573/1259494021269688360/1259992088565645312
    """

    name = "stablehlo.return"

    # Any number of tensor values handed back to the enclosing StableHLO op.
    input = var_operand_def(AnyTensorType)

    # Must be the last operation of its block.
    traits = traits_def(IsTerminator())

    def __init__(self, input: Sequence[SSAValue]):
        """
        Build a return of the values in `input`.

        Any sequence (list, tuple, ...) of SSA values is accepted, consistent
        with the other variadic constructors in this dialect.
        """
        super().__init__(operands=(input,))
@irdl_op_definition
class TransposeOp(IRDLOperation):
    """
    Permutes the dimensions of `operand` tensor using `permutation` and produces a
    `result` tensor. More formally, `result[result_index] = operand[operand_index]`
    where `result_index[d] = operand_index[permutation[d]]`.

    https://github.com/openxla/stablehlo/blob/main/docs/spec.md#transpose
    """

    name = "stablehlo.transpose"

    # Shared constraint variable: operand and result must have the same
    # element type (their shapes are checked in `verify_`).
    ElementType = Annotated[Attribute, ConstraintVar("ElementType")]

    operand = operand_def(TensorType[ElementType])
    result = result_def(TensorType[ElementType])
    permutation = attr_def(DenseArrayBase)

    def __init__(
        self, operand: SSAValue, permutation: DenseArrayBase, result_type: Attribute
    ):
        """Build a transpose of `operand` by `permutation` yielding `result_type`."""
        super().__init__(
            operands=(operand,),
            result_types=(result_type,),
            attributes={"permutation": permutation},
        )

    def get_permutation(self) -> tuple[int, ...]:
        """Return the `permutation` attribute as a tuple of Python ints."""
        return cast(tuple[int, ...], self.permutation.get_values())

    def verify_(self) -> None:
        """
        Check the transpose-specific shape constraints.

        Raises:
            VerifyException: if `permutation` is not a permutation of
                `range(rank(operand))`, if the result rank differs from the
                operand rank, or if any result dimension differs from the
                permuted operand dimension.
        """
        # Operand and result types are checked before the custom `verify_`
        o_type = cast(TensorType[Attribute], self.operand.type)
        r_type = cast(TensorType[Attribute], self.result.type)

        o_shape = o_type.get_shape()
        r_shape = r_type.get_shape()

        # TODO: Quantization constraints
        # `permutation` is a permutation of `range(rank(operand))`
        permutation = self.get_permutation()
        if sorted(permutation) != list(range(len(o_shape))):
            raise VerifyException(
                f"Permutation {permutation} of transpose must be a permutation of "
                f"range({len(o_shape)})"
            )

        # `rank(result) = rank(operand)`. Without this check, a lower-rank
        # result would raise IndexError below, and a higher-rank one would be
        # accepted silently.
        if len(r_shape) != len(o_shape):
            raise VerifyException(
                f"Result rank {len(r_shape)} of transpose must equal operand "
                f"rank {len(o_shape)}"
            )

        # `shape(result) = dim(operand, permutation...)`
        for i, dim in enumerate(permutation):
            if r_shape[i] != o_shape[dim]:
                raise VerifyException(
                    f"Permutation mismatch at dimension {i}, expected {o_shape[dim]}"
                )
# The `stablehlo` dialect: the first list holds every operation defined in this
# module, the second every attribute/type (including the `TokenType` type).
# A new op or attribute class defined above must also be added here.
StableHLO = Dialect(
    "stablehlo",
    [
        AbsOp,
        AddOp,
        AfterAllOp,
        CountLeadingZerosOp,
        PopcntOp,
        NotOp,
        AndOp,
        OrOp,
        XorOp,
        ShiftLeftOp,
        ShiftRightArithmeticOp,
        ShiftRightLogicalOp,
        BitcastConvertOp,
        CaseOp,
        MultiplyOp,
        ReturnOp,
        SubtractOp,
        TransposeOp,
    ],
    [
        DotAttr,
        PrecisionAttr,
        TokenType,
    ],
)