Skip to content

Commit f8b7486

Browse files
apaszkejax authors
authored andcommitted
Update the TPU dialect binding extension to follow MLIR guidelines
The way MLIR dialects are allowed to be extended in Python has recently changed (in llvm/llvm-project#68853), so we have to update our bindings. PiperOrigin-RevId: 575796552
1 parent 373c421 commit f8b7486

File tree

2 files changed

+29
-39
lines changed

2 files changed

+29
-39
lines changed

jaxlib/mosaic/python/_tpu_ops_ext.py

Lines changed: 0 additions & 39 deletions
This file was deleted.

jaxlib/mosaic/python/tpu.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,35 @@
1717
# flake8: noqa: F401
1818
# flake8: noqa: F403
1919

20+
from mlir.dialects._ods_common import _cext
21+
2022
# pylint: disable=g-bad-import-order
2123
from ._tpu_gen import * # pylint: disable=wildcard-import
24+
from ._tpu_gen import _Dialect
2225
from jaxlib.mlir._mlir_libs._tpu_ext import * # pylint: disable=wildcard-import
26+
27+
28+
@_cext.register_operation(_Dialect, replace=True)
29+
class TraceOp(TraceOp):
30+
"""An extension to the automatically generated TraceOp bindings."""
31+
32+
def __init__(self, results, message, level, *, loc=None, ip=None):
33+
super().__init__(results, message, level, loc=loc, ip=ip)
34+
self.regions[0].blocks.append(*[]) # Append the block.
35+
36+
@property
37+
def body(self):
38+
return self.regions[0].blocks[0]
39+
40+
41+
@_cext.register_operation(_Dialect, replace=True)
42+
class RegionOp(RegionOp):
43+
"""An extension to the automatically generated RegionOp bindings."""
44+
45+
def __init__(self, *, loc=None, ip=None):
46+
super().__init__([], loc=loc, ip=ip)
47+
self.regions[0].blocks.append() # Append the block.
48+
49+
@property
50+
def body(self):
51+
return self.regions[0].blocks[0]

0 commit comments

Comments
 (0)