Skip to content

Commit 0f0636a

Browse files
by thew3i, Google-ML-Automation
authored and committed
[Mosaic TPU][Pallas] Add pl.reciprocal
PiperOrigin-RevId: 734749577
1 parent 4988adc commit 0f0636a

File tree

7 files changed

+76
-0
lines changed

7 files changed

+76
-0
lines changed

jax/_src/pallas/mosaic/lowering.py

+8
Original file line numberDiff line numberDiff line change
@@ -3222,6 +3222,14 @@ def _erf_inv_lowering_rule(ctx: LoweringRuleContext, x):
32223222
lowering_rules[lax.erf_inv_p] = _erf_inv_lowering_rule
32233223

32243224

3225+
def _reciprocal_lowering_rule(ctx: LoweringRuleContext, x, *, approx):
  """Lowers ``reciprocal_p`` to ``tpu.reciprocal``.

  Args:
    ctx: Lowering rule context; this rule does not consult it.
    x: Operand IR value whose ``type.element_type`` must be ``ir.F32Type``.
    approx: Forwarded unchanged as the ``approx`` argument of
      ``tpu.reciprocal``.

  Returns:
    The result of ``tpu.reciprocal(x, approx=approx)``.

  Raises:
    ValueError: If the operand's element type is not float32.
  """
  element_type = x.type.element_type
  if isinstance(element_type, ir.F32Type):
    return tpu.reciprocal(x, approx=approx)
  raise ValueError("Only float32 is supported.")


lowering_rules[primitives.reciprocal_p] = _reciprocal_lowering_rule
3232+
32253233
def _bitcast_lowering_rule(ctx: LoweringRuleContext, x, *, ty):
32263234
del ty
32273235
(out_aval,) = ctx.avals_out

jax/_src/pallas/primitives.py

+28
Original file line numberDiff line numberDiff line change
@@ -705,6 +705,34 @@ def dot(a, b, trans_a: bool = False, trans_b: bool = False,
705705
preferred_element_type=out_dtype,
706706
)
707707

708+
# Primitive backing `reciprocal`; it is bound with a single `approx` parameter.
reciprocal_p = jax_core.Primitive("reciprocal")


def reciprocal(x, *, approx=False):
  """Binds ``reciprocal_p`` to ``x``.

  Args:
    x: Operand handed to the primitive.
    approx: Keyword-only flag forwarded as the primitive's ``approx``
      parameter. Defaults to ``False``.

  Returns:
    Whatever ``reciprocal_p.bind`` returns for ``x``.
  """
  params = {"approx": approx}
  return reciprocal_p.bind(x, **params)
713+
714+
715+
@reciprocal_p.def_abstract_eval
def _reciprocal_abstract_eval(x, *, approx):
  """Abstract evaluation rule: the output aval is the input aval, unchanged."""
  # `approx` has no influence on the abstract value, so discard it explicitly.
  del approx
  out_aval = x
  return out_aval
719+
720+
721+
def _reciprocal_lowering_rule(
    ctx: mlir.LoweringRuleContext, x, *, approx=False
):
  """Default lowering of ``reciprocal_p`` via ``jnp.reciprocal``.

  Args:
    ctx: Lowering rule context, passed through to the lowered function.
    x: Operand IR value.
    approx: When true, the reciprocal is computed on a bfloat16 copy of the
      operand; otherwise it is computed directly on the operand.

  Returns:
    The value produced by lowering the helper with ``mlir.lower_fun``
    (``multiple_results=False``).
  """

  def _reciprocal(x, *, approx=False):
    if approx:
      # Cast back to the operand's own dtype rather than unconditionally to
      # float32: the abstract-eval rule for this primitive returns the input
      # aval, so the lowered result must keep the input dtype. For float32
      # operands this is identical to the previous behavior.
      return jnp.reciprocal(x.astype(jnp.bfloat16)).astype(x.dtype)
    return jnp.reciprocal(x)

  return mlir.lower_fun(_reciprocal, multiple_results=False)(
      ctx, x, approx=approx
  )


mlir.register_lowering(reciprocal_p, _reciprocal_lowering_rule)
735+
708736

709737
class PrintEffect(effects.Effect):
710738
__str__ = lambda self: "Print"

jax/experimental/jax2tf/jax2tf.py

+1
Original file line numberDiff line numberDiff line change
@@ -1545,6 +1545,7 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs):
15451545
"symmetric_product",
15461546
"from_edtype",
15471547
"to_edtype",
1548+
"reciprocal",
15481549
# Pallas TPU primitives
15491550
"bitcast",
15501551
"repeat",

jax/experimental/pallas/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
from jax._src.pallas.primitives import multiple_of as multiple_of
5454
from jax._src.pallas.primitives import num_programs as num_programs
5555
from jax._src.pallas.primitives import program_id as program_id
56+
from jax._src.pallas.primitives import reciprocal as reciprocal
5657
from jax._src.pallas.primitives import run_scoped as run_scoped
5758
from jax._src.pallas.primitives import store as store
5859
from jax._src.pallas.primitives import swap as swap

jaxlib/mosaic/dialect/tpu/tpu.td

+10
Original file line numberDiff line numberDiff line change
@@ -538,6 +538,16 @@ def TPU_WeirdOp : TPU_Op<"weird", [Pure, ElementwiseMappable]> {
538538
let hasVerifier = 1;
539539
}
540540

541+
// Reciprocal op on vectors of non-zero rank. The result type matches the
// operand type (SameOperandsAndResultType) and the op is marked
// ElementwiseMappable.
def TPU_ReciprocalOp
    : TPU_Op<"reciprocal",
             [Pure, SameOperandsAndResultType, ElementwiseMappable]> {
  // `approx` is a boolean attribute that defaults to false when omitted.
  let arguments = (ins AnyVectorOfNonZeroRank:$input,
                       DefaultValuedAttr<BoolAttr, "false">:$approx);
  let results = (outs AnyVectorOfNonZeroRank:$output);
  let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }];
  // A hand-written verify() is declared for this op.
  let hasVerifier = 1;
}
550+
541551
def TPU_RollVectorsOp : TPU_Op<"roll_vectors", [Pure]> {
542552
let arguments = (ins Variadic<AnyVectorOfNonZeroRank>:$input);
543553
let results = (outs AnyVectorOfNonZeroRank:$output);

jaxlib/mosaic/dialect/tpu/tpu_ops.cc

+8
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ limitations under the License.
3434
#include "mlir/include/mlir/IR/Builders.h"
3535
#include "mlir/include/mlir/IR/BuiltinTypeInterfaces.h"
3636
#include "mlir/include/mlir/IR/BuiltinTypes.h"
37+
#include "mlir/include/mlir/IR/Diagnostics.h"
3738
#include "mlir/include/mlir/IR/IRMapping.h"
3839
#include "mlir/include/mlir/IR/OperationSupport.h"
3940
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
@@ -1164,6 +1165,13 @@ LogicalResult LogBufferOp::verify() {
11641165
return success();
11651166
}
11661167

1168+
// Verifies that the op's result element type is f32; any other element type
// is reported as not implemented.
LogicalResult ReciprocalOp::verify() {
  const bool is_f32 = getType().getElementType().isF32();
  if (is_f32) {
    return success();
  }
  return emitOpError("Not implemented: Reciprocal op for non-f32 dtypes");
}
1174+
11671175
void PackSubelementsOp::build(OpBuilder &builder, OperationState &state,
11681176
const VectorType output_type,
11691177
const ArrayRef<Value> padded_sources,

tests/pallas/ops_test.py

+20
Original file line numberDiff line numberDiff line change
@@ -2459,6 +2459,26 @@ def body(x_ref):
24592459
wrap_init(body, 1), [state.shaped_array_ref((4, 3, 2), jnp.int32)])
24602460
self.assertIn(expected, jaxpr.pretty_print(use_color=False))
24612461

2462+
@parameterized.product(approx=[False, True])
def test_reciprocal(self, approx):
  """Compares ``pl.reciprocal`` with ``jax.lax.reciprocal`` on float32 input."""
  # Guard clauses: the test only runs on TPU with a recent enough libtpu.
  if not jtu.test_device_matches(["tpu"]):
    self.skipTest("Not implemented on non-TPU devices")
  if not jtu.if_cloud_tpu_at_least(2025, 3, 8):
    self.skipTest("Test requires libtpu from 2025/3/8 or later")

  shape = (32, 256)
  x = jnp.arange(np.prod(shape), dtype=jnp.float32).reshape(shape)

  def kernel(x_ref, o_ref):
    o_ref[...] = pl.reciprocal(x_ref[...], approx=approx)

  out_shape = jax.ShapeDtypeStruct(shape, jnp.float32)
  out = self.pallas_call(kernel, out_shape=out_shape)(x)

  # Approximate mode is checked with explicit tolerances; exact mode relies
  # on the defaults of `assert_allclose`.
  tolerances = dict(atol=2e-5, rtol=2e-5) if approx else {}
  np.testing.assert_allclose(out, jax.lax.reciprocal(x), **tolerances)
2481+
24622482

24632483
class PallasPrimitivesInterpretTest(PallasPrimitivesTest):
24642484
INTERPRET = True

0 commit comments

Comments
 (0)