@@ -94,11 +94,67 @@ def is_strided_fragmented_layout(attr: ir.Attribute) -> bool:
94
94
return bool (_strided_fragmented_layout_attr_pattern .search (str (attr )))
95
95
96
96
97
+ _tiled_layout_attr_pattern = re .compile (
98
+ r"^#mosaic_gpu.TiledLayout<\[(?P<tiling>.*)\],"
99
+ r" warp_dim\s*=\s*(?P<warp_dim>[-\d]+),"
100
+ r" lane_dims\s*=\s*\[(?P<lane_dims>.*)\],"
101
+ r" vector_dim\s*=\s*(?P<vector_dim>[-\d]+)>$"
102
+ )
103
+
104
+
105
+ def to_tiled_layout_attr (
106
+ layout : fa .TiledLayout ,
107
+ ) -> ir .Attribute :
108
+ """Constructs a #mosaic_gpu.TiledLayout attribute from a TiledLayout."""
109
+
110
+ tile_str = lambda tile : "[" + ", " .join (str (d ) for d in tile ) + "]"
111
+ tiling = "[" + ", " .join (tile_str (tile ) for tile in layout .tiling .tiles ) + "]"
112
+ return ir .Attribute .parse (
113
+ f"#mosaic_gpu.TiledLayout<{ tiling } , warp_dim={ layout .warp_dim } ,"
114
+ f" lane_dims={ list (layout .lane_dims )} , vector_dim={ layout .vector_dim } >"
115
+ )
116
+
117
+
118
+ _list_of_lists_delimiter = re .compile (r"\]\s*,\s*\[" )
119
+
120
+
121
+ def from_tiled_layout_attr (
122
+ attr : ir .Attribute ,
123
+ ) -> fa .TiledLayout :
124
+ """Constructs a TiledLayout from a #mosaic_gpu.TiledLayout attribute.
125
+
126
+ Raises:
127
+ ValueError: If the attribute is not a #mosaic_gpu.TiledLayout
128
+ attribute.
129
+ """
130
+ match = _tiled_layout_attr_pattern .fullmatch (str (attr ))
131
+ if not match :
132
+ raise ValueError (
133
+ f"Expected a #mosaic_gpu.TiledLayout attribute, got { attr } "
134
+ )
135
+
136
+ tiling_str = match .group ("tiling" )
137
+ tile_strings = []
138
+ if len (tiling_str ) > 2 :
139
+ tile_strings = _list_of_lists_delimiter .split (tiling_str [1 :- 1 ])
140
+ tiles = tuple (tuple (map (int , ts .split ("," ))) for ts in tile_strings )
141
+ return fa .TiledLayout (
142
+ tiling = fa .Tiling (tiles ),
143
+ warp_dim = int (match .group ("warp_dim" )),
144
+ lane_dims = tuple (int (s ) for s in match .group ("lane_dims" ).split ("," )),
145
+ vector_dim = int (match .group ("vector_dim" ))
146
+ )
147
+
148
+
149
+ def is_tiled_layout (attr : ir .Attribute ) -> bool :
150
+ return bool (_tiled_layout_attr_pattern .search (str (attr )))
151
+
152
+
97
153
def to_layout_attr (
98
154
layout : (
99
155
fa .WGSplatFragLayout
100
156
| fa .WGStridedFragLayout
101
- | fa .WGMMAFragLayout
157
+ | fa .TiledLayout
102
158
| fa .WGMMARowFragLayout
103
159
),
104
160
) -> ir .Attribute :
@@ -108,8 +164,8 @@ def to_layout_attr(
108
164
return to_splat_fragmented_layout_attr (layout )
109
165
case fa .WGStridedFragLayout ():
110
166
return to_strided_fragmented_layout_attr (layout )
111
- case fa .WGMMAFragLayout ():
112
- return ir . Attribute . parse ( "#mosaic_gpu.WGMMAFragLayout" )
167
+ case fa .TiledLayout ():
168
+ return to_tiled_layout_attr ( layout )
113
169
case fa .WGMMARowFragLayout ():
114
170
return ir .Attribute .parse ("#mosaic_gpu.WGMMARowFragLayout" )
115
171
case _:
@@ -118,15 +174,6 @@ def to_layout_attr(
118
174
)
119
175
120
176
121
- _wgmma_fragmented_layout_attr_pattern = re .compile (
122
- r"^#mosaic_gpu.WGMMAFragLayout$"
123
- )
124
-
125
-
126
- def is_wgmma_fragmented_layout (attr : ir .Attribute ) -> bool :
127
- return bool (_wgmma_fragmented_layout_attr_pattern .search (str (attr )))
128
-
129
-
130
177
_wgmma_row_fragmented_layout_attr_pattern = re .compile (
131
178
r"^#mosaic_gpu.WGMMARowFragLayout$"
132
179
)
@@ -141,16 +188,16 @@ def from_layout_attr(
141
188
) -> (
142
189
fa .WGSplatFragLayout
143
190
| fa .WGStridedFragLayout
144
- | fa .WGMMAFragLayout
191
+ | fa .TiledLayout
145
192
| fa .WGMMARowFragLayout
146
193
):
147
194
"""Constructs a layout from an MLIR attribute."""
148
195
if is_splat_fragmented_layout (attr ):
149
196
return from_splat_fragmented_layout_attr (attr )
150
197
elif is_strided_fragmented_layout (attr ):
151
198
return from_strided_fragmented_layout_attr (attr )
152
- elif is_wgmma_fragmented_layout (attr ):
153
- return fa . WGMMAFragLayout ( )
199
+ elif is_tiled_layout (attr ):
200
+ return from_tiled_layout_attr ( attr )
154
201
elif is_wgmma_row_fragmented_layout (attr ):
155
202
return fa .WGMMARowFragLayout ()
156
203
else :
0 commit comments