@@ -1543,6 +1543,60 @@ def _reduce_max_lowering_rule(ctx: LoweringRuleContext, x, *, axes):
1543
1543
raise NotImplementedError (f"Unsupported layout { x .layout } " )
1544
1544
1545
1545
1546
def _reduce_lowering_rule_wg(
    kind: vector_dialect.CombiningKind,
    acc: object,
    ctx: LoweringRuleContext,
    x,
    *,
    axes,
) -> ir.OpView:
  """Shared warpgroup-semantics lowering for reduction primitives.

  Args:
    kind: the ``vector`` dialect combining kind (ADD, MAXSI, ...).
    acc: the identity element used to seed the accumulator for
      multi-dimensional reductions.
    ctx: the lowering context; must carry exactly one input and one
      output aval.
    x: the operand to reduce.
    axes: the dimensions to reduce over.

  Returns:
    The created ``vector.reduction`` or ``vector.multi_reduction`` op.
  """
  [src_aval] = ctx.avals_in
  [res_aval] = ctx.avals_out
  operand = _ensure_ir_value(x, src_aval.dtype)
  elem_type = mgpu_utils.dtype_to_ir_type(res_aval.dtype)
  if res_aval.shape:
    # Vector result: splat the identity element into an accumulator of the
    # output shape and reduce over ``axes``.
    init = vector_dialect.splat(
        ir.VectorType.get(res_aval.shape, elem_type),
        _ensure_ir_value(acc, res_aval.dtype),
    )
    return vector_dialect.MultiDimReductionOp(kind, operand, init, axes)
  # Scalar result.
  # TODO(slebedev): Flatten to 1D, since vector.reduction only supports
  # 1D inputs.
  if src_aval.ndim != 1:
    raise NotImplementedError("Only 1D inputs are supported")
  return vector_dialect.ReductionOp(elem_type, kind, operand)
1572
@register_lowering_rule(lax.reduce_sum_p, mgpu.ThreadSemantics.Warpgroup)
def _reduce_sum_lowering_rule_wg(ctx: LoweringRuleContext, x, *, axes):
  """Lowers ``lax.reduce_sum`` under warpgroup thread semantics."""
  reduction = _reduce_lowering_rule_wg(
      vector_dialect.CombiningKind.ADD, 0, ctx, x, axes=axes
  )
  # Tag the op with the first free SMEM byte.
  # NOTE(review): presumably a later mosaic-gpu pass uses this to place the
  # reduction's scratch buffer — confirm against the dialect lowering.
  i32 = ir.IntegerType.get_signless(32)
  reduction.attributes["offset"] = ir.IntegerAttr.get(
      i32, ctx.module_ctx.smem_used_bytes
  )
  return reduction.result
1583
@register_lowering_rule(lax.reduce_max_p, mgpu.ThreadSemantics.Warpgroup)
def _reduce_max_lowering_rule_wg(ctx: LoweringRuleContext, x, *, axes):
  """Lowers ``lax.reduce_max`` under warpgroup thread semantics.

  Selects the ``vector`` dialect combining kind and the identity element
  for the accumulator based on the input dtype.

  Raises:
    NotImplementedError: if the dtype is neither floating point nor
      a signed/unsigned integer.
  """
  [x_aval] = ctx.avals_in
  if jnp.issubdtype(x_aval.dtype, jnp.floating):
    kind = vector_dialect.CombiningKind.MAXIMUMF
    acc = float("-inf")
  elif jnp.issubdtype(x_aval.dtype, jnp.signedinteger):
    kind = vector_dialect.CombiningKind.MAXSI
    # The identity for max is the *smallest* representable value; seeding
    # with iinfo(...).max would make every reduction return the dtype max.
    acc = np.iinfo(x_aval.dtype).min
  elif jnp.issubdtype(x_aval.dtype, jnp.unsignedinteger):
    kind = vector_dialect.CombiningKind.MAXUI
    acc = np.iinfo(x_aval.dtype).min  # 0 for unsigned dtypes.
  else:
    raise NotImplementedError(f"Unsupported dtype {x_aval.dtype}")
  return _reduce_lowering_rule_wg(kind, acc, ctx, x, axes=axes).result
1599
+
1546
1600
@register_lowering_rule (lax .axis_index_p , mgpu .ThreadSemantics .Lane )
1547
1601
def _axis_index_rule (ctx : LoweringRuleContext , * , axis_name : Hashable ):
1548
1602
i32 = ir .IntegerType .get_signless (32 )
0 commit comments